4.2. Benchmarking: In-silico Models
📘 Overview
This notebook benchmarks our DILI prediction model against state-of-the-art in-silico models: DILIPredictor, DILIGeNN, and the recently released TxGemma Predict LLMs from Google DeepMind (2B, 9B, and 27B variants).
Important note: To run this notebook, download the reproducibility directory, which contains all required files and scripts.
Outputs: Figure 6B and the benchmark summary table.
[1]:
%%capture
!pip install torch-geometric captum rdkit matplotlib_venn
[2]:
# Auto-reload edited local modules (e.g. insilico_benchmarks scripts) without
# restarting the kernel
%load_ext autoreload
%autoreload 2
[3]:
import pandas as pd
import numpy as np
import random
import dilimap as dmap
from sklearn.metrics import (
confusion_matrix,
balanced_accuracy_score,
recall_score,
)
import torch
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import inchi
from rdkit.Chem import AllChem, DataStructs
/opt/anaconda3/envs/py310/lib/python3.10/site-packages/threadpoolctl.py:1226: RuntimeWarning:
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md
warnings.warn(msg, RuntimeWarning)
[4]:
# Record the RDKit version used for parsing, fingerprints and InChIKeys
import rdkit

rdkit.__version__
[4]:
'2023.09.2'
[5]:
from insilico_benchmarks.graph_gen_diligenn import process_dili_data, clean_dili_data
from insilico_benchmarks.run_diligenn import (
diligenn_predict_outer_folds,
diligenn_predict_outer_folds_warm_starts,
Graph_custom,
GNNModel,
)
[6]:
from sklearn.metrics import roc_auc_score
from matplotlib_venn import venn3, venn2
1. Utility functions
[7]:
def smiles_to_inchikey14(smiles):
    """Return the first 14 characters of the InChIKey derived from a SMILES.

    Returns None when RDKit cannot parse the SMILES string.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return None
    # Keep only the leading 14-character block of the full InChIKey
    return inchi.MolToInchiKey(parsed)[:14]
[8]:
def compute_tanimoto_similarity_matrix(
    list_of_smiles_1,
    list_of_smiles_2,
) -> np.ndarray:
    """
    Compute the pairwise Tanimoto similarity matrix between two sets of SMILES.

    Uses Morgan fingerprints of radius 2. SMILES that RDKit cannot parse are
    skipped; their rows/columns remain NaN in the output.

    Args:
        list_of_smiles_1: Sequence of SMILES strings (rows of the matrix).
        list_of_smiles_2: Sequence of SMILES strings (columns of the matrix).

    Returns:
        np.ndarray: Similarity matrix of shape
        (len(list_of_smiles_1), len(list_of_smiles_2)) with NaN for invalid
        SMILES, or an empty (0, 0) array if either side has no valid molecule.
    """
    # Convert SMILES to RDKit Mol objects (None for unparseable SMILES)
    mols1 = [Chem.MolFromSmiles(s) for s in list_of_smiles_1]
    mols2 = [Chem.MolFromSmiles(s) for s in list_of_smiles_2]
    # Keep only valid molecules, remembering their original positions
    valid1 = [(i, m) for i, m in enumerate(mols1) if m is not None]
    valid2 = [(j, m) for j, m in enumerate(mols2) if m is not None]
    if not valid1 or not valid2:
        return np.empty((0, 0))  # Return empty matrix if no valid inputs
    idx1, mols1_valid = zip(*valid1)
    idx2, mols2_valid = zip(*valid2)
    # Compute Morgan fingerprints (radius 2)
    fps1 = [AllChem.GetMorganFingerprintAsBitVect(m, 2) for m in mols1_valid]
    fps2 = [AllChem.GetMorganFingerprintAsBitVect(m, 2) for m in mols2_valid]
    # Pairwise Tanimoto similarities between the valid fingerprints
    sim_matrix = np.array(
        [DataStructs.BulkTanimotoSimilarity(fp1, fps2) for fp1 in fps1]
    )
    # Scatter similarities back to full-size matrix; invalid entries stay NaN.
    # np.ix_ builds the (row, column) mesh in one vectorized assignment instead
    # of the original O(n*m) Python double loop.
    full_matrix = np.full((len(list_of_smiles_1), len(list_of_smiles_2)), np.nan)
    full_matrix[np.ix_(idx1, idx2)] = sim_matrix
    return full_matrix
[9]:
def calculate_comprehensive_metrics(df, true_label_col, prediction_columns):
    """
    Calculate balanced accuracy, specificity and sensitivity for multiple models.

    Parameters:
    - df: DataFrame containing predictions
    - true_label_col: column name for true labels (values are compared against
      the string 'True' after str conversion)
    - prediction_columns: list of column names containing binary predictions;
      columns absent from df are silently skipped

    Returns:
    - DataFrame with one row of metrics per prediction column
    """
    y_true_bool = (df[true_label_col].astype(str) == 'True').astype(bool)
    metrics_data = []
    for pred_col in prediction_columns:
        if pred_col not in df.columns:
            continue
        # Normalize predictions to booleans regardless of storage format;
        # object columns may hold the strings 'True'/'False', for which a
        # plain astype(bool) would be wrong ('False' is truthy).
        if df[pred_col].dtype == 'bool':
            y_pred = df[pred_col]
        elif df[pred_col].dtype == 'object':
            y_pred = (df[pred_col].astype(str) == 'True').astype(bool)
        else:
            y_pred = df[pred_col].astype(bool)
        balanced_acc = balanced_accuracy_score(y_true_bool, y_pred)
        recall = recall_score(y_true_bool, y_pred)
        # Specificity = TN / (TN + FP). confusion_matrix().ravel() raises
        # ValueError when it cannot produce a 2x2 matrix; report 0 then.
        try:
            tn, fp, fn, tp = confusion_matrix(y_true_bool, y_pred).ravel()
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        except ValueError:
            specificity = 0
        metrics_data.append(
            {
                'Model': pred_col.replace('_', ' ').title(),
                'Balanced Accuracy': round(balanced_acc, 3),
                'Specificity': round(specificity, 3),
                'Sensitivity': round(recall, 3),
            }
        )
    return pd.DataFrame(metrics_data)
[10]:
def plot_roc_curves(
    df, true_label_col, probability_columns, model_names=None, colors=None
):
    """
    Plot ROC curves for multiple models on a single figure.

    Parameters:
    - df: DataFrame containing predictions
    - true_label_col: column name for true labels (values compared against the
      string 'True' after str conversion)
    - probability_columns: list of column names containing probability scores
    - model_names: list of model names for legend (optional)
    - colors: list of colors for each curve (optional)
    """
    y_true_bool = (df[true_label_col].astype(str) == 'True').astype(bool)
    if colors is None:
        colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray']
    if model_names is None:
        model_names = [col.replace('_', ' ').title() for col in probability_columns]
    # Repeat the palette so zip() does not silently drop curves when there are
    # more probability columns than colors.
    if len(colors) < len(probability_columns):
        colors = list(colors) * (len(probability_columns) // len(colors) + 1)
    for i, (prob_col, model_name, color) in enumerate(
        zip(probability_columns, model_names, colors)
    ):
        if prob_col not in df.columns:
            continue
        try:
            y_prob = df[prob_col].astype(float)
            # Render the figure only once, on the final curve
            show_plot = i == len(probability_columns) - 1
            dmap.pl.roc_curve(
                y_true_bool, y_prob, label=model_name, color=color, show=show_plot
            )
        except Exception as e:
            print(f'Could not plot ROC curve for {model_name}: {e}')
[11]:
def print_confusion_matrix_analysis(
    df, true_label_col='True_label', pred_label_col='Predicted_label'
):
    """
    Print the true-label distribution plus TP/TN counts for one prediction column.

    Parameters:
    df: DataFrame containing the predictions
    true_label_col: Column name for true labels (values compared against the
        string 'True' after str conversion)
    pred_label_col: Column name for predicted labels (bool, or 'True'/'False'
        strings in an object column)
    """
    # Print value counts for the true labels
    print('True_label distribution:')
    print(df[true_label_col].value_counts())
    y_true = (df[true_label_col].astype(str) == 'True').astype(bool)
    # Normalize predictions: a plain astype(bool) maps the *string* 'False' to
    # True, so treat object columns as 'True'/'False' strings (consistent with
    # calculate_comprehensive_metrics).
    if df[pred_label_col].dtype == 'object':
        y_pred = (df[pred_label_col].astype(str) == 'True').astype(bool)
    else:
        y_pred = df[pred_label_col].astype(bool)
    # Confusion matrix components
    tp = (y_true & y_pred).sum()
    tn = (~y_true & ~y_pred).sum()
    # .get(..., 0) avoids a KeyError when one of the classes is absent
    label_counts = df[true_label_col].astype(str).value_counts()
    n_dili = label_counts.get('True', 0)
    n_no_dili = label_counts.get('False', 0)
    print(f'\nConfusion matrix: for {pred_label_col}')
    print(f'True Positives (TP): {tp} / Total DILI {n_dili}')
    print(f'True Negatives (TN): {tn}/ Total Non-DILI {n_no_dili}')
[12]:
def set_random_seeds(seed=42):
    """
    Seed the stdlib, numpy and torch RNGs for reproducible runs.

    Parameters:
    seed: Random seed value (default: 42)
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Force deterministic cuDNN kernels (disables autotuned benchmarking)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Report the state actually in effect so runs are auditable
    print('Seeds currently in use:')
    print(f'PyTorch manual seed: {torch.initial_seed()}')
    print(f'NumPy random seed: {np.random.get_state()[1][0]}')
    print(f'CUDA deterministic: {torch.backends.cudnn.deterministic}')
    print(f'CUDA benchmark: {torch.backends.cudnn.benchmark}')


set_random_seeds(seed=42)
Seeds currently in use:
PyTorch manual seed: 42
NumPy random seed: 42
CUDA deterministic: True
CUDA benchmark: False
2. DILImap results
[13]:
# Pull the trained ToxPredictor model (fetched from the public DILImap S3
# packages; requires AWS credentials in the environment or a .env file)
model = dmap.models.ToxPredictor()
2025-07-28 12:01:25,474 - INFO - Found credentials in environment variables.
2025-07-28 12:01:26,548 - INFO - Found credentials in environment variables.
2025-07-28 12:01:27,406 - INFO - Found credentials in environment variables.
Package: s3://dilimap/public/models. Top hash: b119d5a238
2025-07-28 12:01:31,520 - INFO - Found credentials in environment variables.
2025-07-28 12:01:32,466 - INFO - Found credentials in environment variables.
2025-07-28 12:01:33,203 - INFO - Found credentials in environment variables.
Package: s3://dilimap/public/data. Top hash: e5bf3de9d2
Package: s3://dilimap/public/models. Top hash: b119d5a238
AWS credentials not found in environment or .env file.
Package: s3://dilimap/public/data. Top hash: e5bf3de9d2
[14]:
# Get results from validation data
adata_val = dmap.datasets.DILImap_validation_data()
dmap.utils.map_dili_labels_and_cmax(adata_val)
df_res_margins_val = model.compute_safety_margin(adata_val)
2025-07-28 12:01:34,667 - INFO - Found credentials in environment variables.
2025-07-28 12:01:35,774 - INFO - Found credentials in environment variables.
2025-07-28 12:01:36,509 - INFO - Found credentials in environment variables.
Package: s3://dilimap/public/data. Top hash: e5bf3de9d2
2025-07-28 12:01:38,316 - INFO - Found credentials in environment variables.
2025-07-28 12:01:39,443 - INFO - Found credentials in environment variables.
2025-07-28 12:01:40,154 - INFO - Found credentials in environment variables.
Package: s3://dilimap/public/data. Top hash: e5bf3de9d2
2025-07-28 12:01:41,498 - INFO - Found credentials in environment variables.
2025-07-28 12:01:42,588 - INFO - Found credentials in environment variables.
2025-07-28 12:01:43,301 - INFO - Found credentials in environment variables.
Package: s3://dilimap/public/data. Top hash: e5bf3de9d2
2025-07-28 12:01:44,615 - INFO - Found credentials in environment variables.
2025-07-28 12:01:45,713 - INFO - Found credentials in environment variables.
2025-07-28 12:01:46,445 - INFO - Found credentials in environment variables.
Package: s3://dilimap/public/data. Top hash: e5bf3de9d2
283 out of 469 features in your data are not present in the training data. These features will not impact predictions. You can access the features available in the training data via `model.features`.
[15]:
# Concatenate cross-val and validation results
df_res_all = pd.concat(
[
model.cross_val_results.assign(Batch='cross-validation'),
df_res_margins_val.assign(Batch='validation'),
],
ignore_index=False,
)
[16]:
# Map DILI label and smiles
dmap.utils.map_dili_labels_and_cmax(
df_res_all,
labels=['DILIrank', 'label_section', 'livertox_score', 'DILI_label', 'smiles'],
insert_at_front=True,
)
df_res_all.loc['Auranofin', 'smiles'] = (
'CCP(CC)CC.CC(=O)OC[C@@H]1[C@H]([C@@H]([C@H]([C@@H](O1)[S-])OC(=O)C)OC(=O)C)OC(=O)C.[Au+]'
)
df_res_all.loc['FIRU', 'smiles'] = (
'O[C@H]1[C@@H](F)[C@H](N2C(NC(C(I)=C2)=O)=O)O[C@@H]1CO'
)
2025-07-28 12:01:48,700 - INFO - Found credentials in environment variables.
2025-07-28 12:01:49,875 - INFO - Found credentials in environment variables.
2025-07-28 12:01:50,540 - INFO - Found credentials in environment variables.
Package: s3://dilimap/public/data. Top hash: e5bf3de9d2
Package: s3://dilimap/public/data. Top hash: e5bf3de9d2
[17]:
# True vs. predicted labels
df_res_all['True_label'] = np.where(
df_res_all['DILI_label'].isin(
['DILI (withdrawn)', 'DILI (known)', 'DILI (likely)']
),
True,
np.where(df_res_all['DILI_label'].isin(['No DILI']), False, ''),
)
df_res_all['Predicted_label'] = df_res_all['Classification'].map(
{'+': True, '-': False}
)
df_res_all['compound_name'] = df_res_all.index.tolist()
[18]:
len(df_res_all)
[18]:
300
[19]:
# Compounds with an unambiguous DILI annotation ('True' or 'False', not '')
truefalse_mask = df_res_all['True_label'].isin(['True', 'False'])
indices_true_false = df_res_all.index[truefalse_mask].tolist()
len(indices_true_false)
[19]:
210
[20]:
df_res_all.head()
[20]:
| DILIrank | label_section | livertox_score | smiles | DILI_label | Cmax_uM | First_DILI_uM | MOS_Cytotoxicity | MOS_ToxPredictor | Primary_DILI_driver | Classification | Batch | True_label | Predicted_label | compound_name | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| compound_name | |||||||||||||||
| Abacavir | Most-DILI-Concern | Warnings and precautions | C | Nc1nc(NC2CC2)c2ncn([C@H]3C=C[C@@H](CO)C3)c2n1 | DILI (likely) | 14.850000 | 333.333333 | 300.00000 | 22.446689 | Transcriptomics | + | cross-validation | True | True | Abacavir |
| Acarbose | Most-DILI-Concern | Warnings and precautions | B | C[C@H]1O[C@H](O[C@H]2[C@H](O)[C@@H](O)[C@@H](O... | DILI (likely) | 0.112949 | 333.333333 | 300.00000 | 300.000000 | none | - | cross-validation | True | False | Acarbose |
| Acebutolol | Less-DILI-Concern | Adverse reactions | C | CCCC(=O)Nc1ccc(OCC(O)CNC(C)C)c(C(C)=O)c1 | DILI (few cases) | 2.737568 | 333.333000 | 300.00000 | 121.762455 | none | - | cross-validation | False | Acebutolol | |
| Aceclofenac | Less-DILI-Concern | No match | NaN | O=C(O)COC(=O)Cc1ccccc1Nc1c(Cl)cccc1Cl | DILI (few cases) | 30.443705 | 100.000000 | 20.01169 | 3.284751 | Transcriptomics | + | cross-validation | True | Aceclofenac | |
| Acetaminophen | Most-DILI-Concern | Warnings and precautions | A [HD] | CC(=O)Nc1ccc(O)cc1 | DILI (known) | 132.310000 | NaN | 300.00000 | 15.116015 | Cmax | + | cross-validation | True | True | Acetaminophen |
Load DILI SMILES dataframe
[21]:
# Load the full compound -> DILI label table from S3
all_cmpds = dmap.s3.read('compound_DILI_labels.csv')
# Manual SMILES overrides for two compounds (same patches applied to df_res_all)
all_cmpds.loc['Auranofin', 'smiles'] = (
    'CCP(CC)CC.CC(=O)OC[C@@H]1[C@H]([C@@H]([C@H]([C@@H](O1)[S-])OC(=O)C)OC(=O)C)OC(=O)C.[Au+]'
)
all_cmpds.loc['FIRU', 'smiles'] = (
    'O[C@H]1[C@@H](F)[C@H](N2C(NC(C(I)=C2)=O)=O)O[C@@H]1CO'
)
# Keep compounds that have both a binary DILI label and a SMILES string
cellarity_smiles_dili_dataset = all_cmpds[
    (all_cmpds['DILI_label_binary'].notna()) & (all_cmpds['smiles'].notna())
].copy()
cellarity_smiles_dili_dataset.DILI_label_binary.value_counts()
2025-07-28 12:01:52,160 - INFO - Found credentials in environment variables.
2025-07-28 12:01:53,308 - INFO - Found credentials in environment variables.
2025-07-28 12:01:54,056 - INFO - Found credentials in environment variables.
Package: s3://dilimap/public/data. Top hash: e5bf3de9d2
[21]:
DILI_label_binary
No DILI 712
DILI 339
Name: count, dtype: int64
[21]:
DILI_label_binary
No DILI 712
DILI 339
Name: count, dtype: int64
3. Run DILIGeNN on validation set
3.1 Load data and identify unseen subset using InChiKey14, Compound Name and SMILES
[22]:
# DILIGeNN training set (1,167 "seen" compounds), used below to identify which
# of our compounds the model has never encountered
df_dg = pd.read_csv('./insilico_benchmarks/smiles_data/DILIGeNN_seen_1167.csv')
df_dg.head()
[22]:
| DILIST_ID | name | label | Routs of Administration | smiles | smiles_pcp | error_msg | smiles_std | |
|---|---|---|---|---|---|---|---|---|
| 0 | 1 | mercaptopurine | 1 | Oral | C1=NC2=C(N1)C(=S)N=CN2 | C1=NC2=C(N1)C(=S)N=CN2 | Success | S=c1[nH]cnc2[nH]cnc12 |
| 1 | 2 | acetaminophen | 1 | Oral | CC(=O)NC1=CC=C(C=C1)O | CC(=O)NC1=CC=C(C=C1)O | Success | CC(=O)Nc1ccc(O)cc1 |
| 2 | 3 | azathioprine | 1 | Oral | CN1C=NC(=C1SC2=NC=NC3=C2NC=N3)[N+](=O)[O-] | CN1C=NC(=C1SC2=NC=NC3=C2NC=N3)[N+](=O)[O-] | Success | Cn1cnc([N+](=O)[O-])c1Sc1ncnc2[nH]cnc12 |
| 3 | 4 | chlorpheniramine | 0 | Oral | CN(C)CCC(C1=CC=C(C=C1)Cl)C2=CC=CC=N2 | CN(C)CCC(C1=CC=C(C=C1)Cl)C2=CC=CC=N2 | Success | CN(C)CCC(c1ccc(Cl)cc1)c1ccccn1 |
| 4 | 5 | clofibrate | 1 | Oral | CCOC(=O)C(C)(C)OC1=CC=C(C=C1)Cl | CCOC(=O)C(C)(C)OC1=CC=C(C=C1)Cl | Success | CCOC(=O)C(C)(C)Oc1ccc(Cl)cc1 |
[23]:
cellarity_smiles_dili_dataset.head()
[23]:
| LTKBID | compound_name | DILIrank | label_section | severity_class | DILIst | roa | source | livertox_score | livertox_primary_classification | ... | pathway | information | CAS_number | DMSO_mM_solubility | DILI_label | DILI_label_binary | DILI_label_section | livertox_updated | livertox_iDILI | livertox_mechanism_summary | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 4-aminosalicylic acid | LT00505 | 4-aminosalicylic acid | Most-DILI-Concern | Warnings and precautions | 5.0 | NaN | Oral | DILIrank | NaN | NaN | ... | Microbiology | 4-Aminosalicylic acid (Para-aminosalicylic aci... | 65-49-6 | 65.3 | DILI (likely) | DILI | Likely | NaN | NaN | NaN |
| Abacavir | LT00040 | Abacavir | Most-DILI-Concern | Warnings and precautions | 8.0 | NaN | Oral | DILIrank | C | Antimicrobial | ... | Microbiology | Abacavir (1592U89, ABC) is a powerful nucleosi... | 136470-78-5 | 199.07 | DILI (likely) | DILI | Likely | 2016 Jan 4 | NaN | Hepatitis, Cholestasis/Biliary, Hypersensitivi... |
| Acarbose | LT01034 | Acarbose | Most-DILI-Concern | Warnings and precautions | 8.0 | NaN | Oral | DILIrank | B | Endocrine | ... | Metabolism | Acarbose(BAY g 5421,Prandase, Precose, Glucoba... | 56180-94-0 | 154.89 | DILI (likely) | DILI | Likely | 2021 Jan 10 | NaN | Hepatitis, Cholestasis/Biliary, Immune-mediate... |
| Acetaminophen | LT00004 | Acetaminophen | Most-DILI-Concern | Warnings and precautions | 5.0 | NaN | Oral | DILIrank | A [HD] | Analgesic | ... | Neuronal Signaling | Acetaminophen is a COX inhibitor for COX-1 and... | 103-90-2 | 198.47 | DILI (known) | DILI | Known | 2016 Jan 28 | NaN | CYP, Hypersensitivity |
| Acetazolamide | LT00498 | Acetazolamide | Most-DILI-Concern | Warnings and precautions | 8.0 | NaN | Oral | DILIrank | D | Cardiovascular | ... | Metabolism | Acetazolamide (Diamox), a potent carbonic anhy... | 59-66-5 | 197.98 | DILI (likely) | DILI | Likely | NaN | NaN | NaN |
5 rows × 60 columns
[24]:
# Derive the 14-character InChIKey prefix for both datasets so overlap can be
# matched on structure rather than SMILES text
cellarity_smiles_dili_dataset['InChiKey14'] = cellarity_smiles_dili_dataset[
    'smiles'
].map(smiles_to_inchikey14)
df_dg['InChiKey14'] = df_dg['smiles_std'].map(smiles_to_inchikey14)
[25]:
def _overlap_mask(values: pd.Series, reference: pd.Series) -> pd.Series:
    """Boolean mask: True where a value in `values` also occurs in `reference`."""
    shared = set(values).intersection(set(reference))
    return values.isin(shared)


# 1. Direct SMILES matching (intersection)
smiles_overlap_mask = _overlap_mask(
    cellarity_smiles_dili_dataset['smiles'], df_dg['smiles']
)
# 1b. Direct SMILES matching against DILIGeNN's standardized SMILES
if 'smiles_std' in df_dg.columns:
    smiles_std_overlap_mask = _overlap_mask(
        cellarity_smiles_dili_dataset['smiles'], df_dg['smiles_std']
    )
else:
    smiles_std_overlap_mask = pd.Series(
        False, index=cellarity_smiles_dili_dataset.index
    )
# 2. Name matching (case-insensitive)
name_overlap_mask = _overlap_mask(
    cellarity_smiles_dili_dataset['compound_name'].astype(str).str.lower(),
    df_dg['name'].astype(str).str.lower(),
)
# 3. InChiKey14-based overlap (case-insensitive)
inchikey14_overlap_mask = _overlap_mask(
    cellarity_smiles_dili_dataset['InChiKey14'].astype(str).str.lower(),
    df_dg['InChiKey14'].astype(str).str.lower(),
)
# A compound counts as "seen" if ANY identifier matches
overlap_mask = (
    smiles_overlap_mask
    | smiles_std_overlap_mask
    | name_overlap_mask
    | inchikey14_overlap_mask
)
# Keep only compounds DILIGeNN has never seen.
# NOTE(review): variable name keeps the original 'unseeen' spelling because
# later cells reference it.
diligenn_unseeen_smiles = cellarity_smiles_dili_dataset[~overlap_mask].copy()
print(f'Unseen compounds: {len(diligenn_unseeen_smiles)}')
Unseen compounds: 349
[26]:
diligenn_unseeen_smiles.head()
[26]:
| LTKBID | compound_name | DILIrank | label_section | severity_class | DILIst | roa | source | livertox_score | livertox_primary_classification | ... | information | CAS_number | DMSO_mM_solubility | DILI_label | DILI_label_binary | DILI_label_section | livertox_updated | livertox_iDILI | livertox_mechanism_summary | InChiKey14 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Anagrelide | LT00964 | Anagrelide | Ambiguous DILI-concern | Warnings and precautions | 3.0 | NaN | NaN | DILIrank | E* | Hematologic | ... | NaN | NaN | NaN | No DILI | No DILI | No DILI | 2017 Jul 5 | NaN | NaN | OTBXOEAOVRKTNQ |
| Aprepitant | LT01195 | Aprepitant | Ambiguous DILI-concern | Adverse reactions | 3.0 | NaN | NaN | DILIrank | E | Gastrointestinal | ... | Aprepitant (MK-0869, L-754030, Emend) is a pot... | 170729-80-3 | 200.21 | No DILI | No DILI | No DILI | 2024 Feb 28 | NaN | NaN | ATALOFNDEOCMKK |
| Bacitracin | LT02059 | Bacitracin | No-DILI-Concern | No match | 0.0 | NaN | Intramuscular | DILIrank | NaN | NaN | ... | Bacitracin is a mixture of related cyclic poly... | 1405-87-4 | NaN | No DILI | No DILI | No DILI | NaN | NaN | NaN | CLKOFPXJLQSYAH |
| Bortezomib | LT01202 | Bortezomib | Most-DILI-Concern | Warnings and precautions | 7.0 | NaN | Intravenous | DILIrank | C | Antineoplastic | ... | Bortezomib (PS-341, Velcade, LDP-341, MLM341, ... | 179324-69-7 | 197.79 | DILI (likely) | DILI | Likely | 2017 Sep 30 | NaN | Hepatitis, Necrosis, CYP, Cholestasis/Biliary | GXJABQQUPOEUTA |
| Butabarbital | LT00707 | Butabarbital | Ambiguous DILI-concern | Adverse reactions | 3.0 | NaN | NaN | DILIrank | E | CNS | ... | NaN | NaN | NaN | No DILI | No DILI | No DILI | NaN | NaN | NaN | ZRIHAIZYIMGOAB |
5 rows × 61 columns
[27]:
# Binary numeric label expected by the DILIGeNN pipeline (1 = DILI, 0 = No DILI)
diligenn_unseeen_smiles['label'] = diligenn_unseeen_smiles['DILI_label_binary'].map(
    {'DILI': 1, 'No DILI': 0}
)
# diligenn_unseeen_smiles.to_csv('./insilico_benchmarks/smiles_data/unseen_smiles_diligenn_349.csv')
[28]:
df_res_all[df_res_all.index.isin(diligenn_unseeen_smiles.index)]
[28]:
| DILIrank | label_section | livertox_score | smiles | DILI_label | Cmax_uM | First_DILI_uM | MOS_Cytotoxicity | MOS_ToxPredictor | Primary_DILI_driver | Classification | Batch | True_label | Predicted_label | compound_name | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| compound_name | |||||||||||||||
| Hydroxyzine | No-DILI-Concern | No match | E | OCCOCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1 | No DILI | 0.196691 | 1000.000000 | 300.000000 | 300.000000 | none | - | cross-validation | False | False | Hydroxyzine |
| AKN-028 | NaN | NaN | NaN | Nc1ncc(-c2ccncc2)nc1Nc1ccc2[nH]ccc2c1 | DILI (withdrawn) | 1.000000 | 16.666667 | 300.000000 | 16.666667 | Transcriptomics | + | validation | True | True | AKN-028 |
| BMS-986142 | NaN | NaN | NaN | Cc1c(-c2c(F)cc(C(N)=O)c3[nH]c4c(c23)CC[C@H](C(... | DILI (withdrawn) | 1.887880 | 3.703704 | 300.000000 | 1.961832 | Transcriptomics | + | validation | True | True | BMS-986142 |
| Evobrutinib | NaN | NaN | NaN | C=CC(=O)N1CCC(CNc2ncnc(N)c2-c2ccc(Oc3ccccc3)cc... | DILI (withdrawn) | 1.292170 | 1.851000 | 300.000000 | 1.432474 | Transcriptomics | + | validation | True | True | Evobrutinib |
| Orelabrutinib | NaN | NaN | NaN | C=CC(=O)N1CCC(c2ccc(C(N)=O)c(-c3ccc(Oc4ccccc4)... | DILI (withdrawn) | 4.509942 | 33.333333 | 221.732365 | 7.391079 | Transcriptomics | + | validation | True | True | Orelabrutinib |
| TAK-875 | NaN | NaN | NaN | Cc1cc(OCCCS(C)(=O)=O)cc(C)c1-c1cccc(COc2ccc3c(... | DILI (withdrawn) | 2.382632 | 1.851852 | 83.940800 | 1.000000 | Transcriptomics | + | validation | True | True | TAK-875 |
| Tofacitinib | NaN | NaN | E* | C[C@@H]1CCN(C(=O)CC#N)C[C@@H]1N(C)c1ncnc2[nH]c... | No DILI | 0.137977 | 33.333333 | 300.000000 | 241.585460 | none | - | validation | False | False | Tofacitinib |
| Upadacitinib | NaN | NaN | D | CC[C@@H]1CN(C(=O)NCC(F)(F)F)C[C@@H]1c1cnc2cnc3... | No DILI | 0.215563 | 100.000000 | 300.000000 | 300.000000 | none | - | validation | False | False | Upadacitinib |
[29]:
# diligenn_truly_unseen.to_csv('./insilico_benchmarks/smiles_data/diligenn_unseen_std_cleaned.csv')
3.2 Graph pre-processing SMILES, Loading best models and running predictions using reproducible script
Standardize Smiles and Run predictions
[30]:
## Standardize the Smiles before graph generation
standardized_df = process_dili_data(diligenn_unseeen_smiles, smiles_col='smiles')
df_std_filtered = clean_dili_data(
standardized_df, smiles_col='smiles', label_col='label'
)
# Location of saved data object: insilico_benchmark/data/diligenn
Initial dataset size: 349
Compounds with missing SMILES: 0 (0.0%)
After removing missing SMILES: 349
Processing molecules: 100%|███████████████████| 349/349 [00:43<00:00, 7.97it/s]
Standardised dataset size: 349
Compounds failed standardization: 18 (5.2%)
After standardization filtering: 331
Compounds failed conformer generation: 0 (0.0%)
Error indices: []
Final cleaned dataset size: 331
Total compounds filtered out: 18 (5.2%)
[31]:
# Re-seed immediately before prediction so fold results are reproducible
set_random_seeds(seed=42)
Seeds currently in use:
PyTorch manual seed: 42
NumPy random seed: 42
CUDA deterministic: True
CUDA benchmark: False
Predictions with warm starts
[32]:
# DILIGeNN GraphSAGE predictions with warm-started models; per-outer-fold
# models are chosen via the last inner fold of each outer fold
diligenn_true_unseen_smiles_warmstart = diligenn_predict_outer_folds_warm_starts(
    df_std_filtered,
    dataset_class=Graph_custom,
    model_name='GraphSAGE_Optimised',
    model_class=GNNModel,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    model_selection_strategy='last_inner_fold',
)
For model GCN, mean val_auc of optimal hyperparameters across 20 inner folds: 0.627 ± 0.034
For model GAT, mean val_auc of optimal hyperparameters across 20 inner folds: 0.627 ± 0.031
For model GraphSAGE, mean val_auc of optimal hyperparameters across 20 inner folds: 0.624 ± 0.032
For model GIN, mean val_auc of optimal hyperparameters across 20 inner folds: 0.626 ± 0.032
100%|█████████████████████████████████████████| 331/331 [03:04<00:00, 1.79it/s]
Using models from last significant run: 3
Selected last inner fold 4 for outer fold 0
Selected last inner fold 4 for outer fold 1
Selected last inner fold 4 for outer fold 2
Selected last inner fold 4 for outer fold 3
Predictions with base model
[33]:
# DILIGeNN GraphSAGE predictions with the base (non-warm-started) models
diligenn_true_unseen_smiles_base = diligenn_predict_outer_folds(
    df_std_filtered,
    dataset_class=Graph_custom,
    model_name='GraphSAGE_Optimised',
    model_class=GNNModel,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
For model GCN, mean val_auc of optimal hyperparameters across 20 inner folds: 0.627 ± 0.034
For model GAT, mean val_auc of optimal hyperparameters across 20 inner folds: 0.627 ± 0.031
For model GraphSAGE, mean val_auc of optimal hyperparameters across 20 inner folds: 0.624 ± 0.032
For model GIN, mean val_auc of optimal hyperparameters across 20 inner folds: 0.626 ± 0.032
100%|█████████████████████████████████████████| 331/331 [02:59<00:00, 1.84it/s]
[34]:
# Sanity check: both variants should cover the same standardized compounds
print(
    f'DILIGeNN warm start filtered shape: {diligenn_true_unseen_smiles_warmstart.shape}'
)
print(f'DILIGeNN base filtered shape: {diligenn_true_unseen_smiles_base.shape}')
DILIGeNN warm start filtered shape: (331, 72)
DILIGeNN base filtered shape: (331, 72)
[35]:
# Define columns to merge
prob_cols = ['prob_outer0', 'prob_outer1', 'prob_outer2', 'prob_outer3']
label_cols = ['label_outer0', 'label_outer1', 'label_outer2', 'label_outer3']
cols = prob_cols + label_cols
# Suffix base-model and warm-start columns so both can coexist in one frame
base_renamed = {col: f'{col}_base' for col in cols}
diligenn_base_renamed = diligenn_true_unseen_smiles_base.rename(columns=base_renamed)
warmstart_renamed = {col: f'{col}_warmstart' for col in cols}
diligenn_warmstart_renamed = diligenn_true_unseen_smiles_warmstart[cols].rename(
    columns=warmstart_renamed
)
# Merge on index (assumes same order and index)
combined_df = pd.concat([diligenn_base_renamed, diligenn_warmstart_renamed], axis=1)
# Ensemble the 4 outer folds: mean probability plus majority-vote label.
# Loop variable renamed from `model` to avoid shadowing the ToxPredictor
# `model` object defined earlier in the notebook.
for model_tag in ['base', 'warmstart']:
    combined_df[f'DILIGeNN_avg_prob_{model_tag}'] = combined_df[
        [f'prob_outer{i}_{model_tag}' for i in range(4)]
    ].mean(axis=1)
    label_cols_model = [f'label_outer{i}_{model_tag}' for i in range(4)]
    # mode() may return ties; column 0 takes the first (lowest) mode
    combined_df[f'DILIGeNN_mode_pred_{model_tag}'] = combined_df[
        label_cols_model
    ].mode(axis=1)[0]
# Boolean and numeric encodings of the ground-truth label
combined_df['True_label'] = combined_df['DILI_label_binary'].map(
    {'DILI': True, 'No DILI': False}
)
label_mapping = {'DILI': 1, 'No DILI': 0}
combined_df['True_label_numeric'] = combined_df['DILI_label_binary'].map(label_mapping)
# combined_df.to_csv('./insilico_benchmarks/predictions/benchmark_diligenn_331_predictions.csv')
Load predictions
[36]:
# Reload the saved predictions so the evaluation below is reproducible without
# re-running the slow prediction cells; note this overwrites the combined_df
# computed in the previous cell.
combined_df = pd.read_csv(
    './insilico_benchmarks/predictions/benchmark_diligenn_331_predictions.csv',
    index_col=0,
)
[37]:
combined_df.DILI_label_binary.value_counts()
[37]:
DILI_label_binary
No DILI 284
DILI 47
Name: count, dtype: int64
3.3 Evaluate DILIGeNN GraphSAGE (warm-start) models on (potentially) unseen data (331 compounds)
[38]:
def get_performance(df, model_prefix):
    """Per-outer-fold metrics for one DILIGeNN variant ('warmstart' or 'base')."""
    label_cols = [f'label_outer{i}_{model_prefix}' for i in range(4)]
    return calculate_comprehensive_metrics(df, 'True_label', label_cols)


def summarize_performance(performance, model_label):
    """One-row 'mean ± std' summary across the outer folds of one model.

    Factored out of the previously duplicated warm-start/base summary code.
    """
    numeric_cols = performance.select_dtypes(include=[np.number]).columns
    return pd.DataFrame(
        {
            'Model': [model_label],
            **{
                col: [
                    f'{performance[col].mean():.3f} ± {performance[col].std():.3f}'
                ]
                for col in numeric_cols
            },
        }
    )


warmstart_performance = get_performance(combined_df, 'warmstart')
base_performance = get_performance(combined_df, 'base')
print('\nWarm Start Model Performance Across Outer Folds:')
display(warmstart_performance)
print('\nBase Model Performance Across Outer Folds:')
display(base_performance)
warmstart_summary = summarize_performance(
    warmstart_performance, 'DILIGeNN Warm Start (GraphSAGE)'
)
base_summary = summarize_performance(base_performance, 'DILIGeNN Base (GraphSAGE)')
final_comparison_df = pd.concat([warmstart_summary, base_summary], ignore_index=True)
print('\nFinal Performance Comparison:')
display(final_comparison_df)
# Plot ROC curves for both models (fold-averaged probabilities)
plot_roc_curves(
    combined_df,
    'True_label',
    ['DILIGeNN_avg_prob_warmstart', 'DILIGeNN_avg_prob_base'],
    model_names=['DILIGeNN Warm Start', 'DILIGeNN Base'],
)
Warm Start Model Performance Across Outer Folds:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | Label Outer0 Warmstart | 0.595 | 0.254 | 0.936 |
| 1 | Label Outer1 Warmstart | 0.467 | 0.296 | 0.638 |
| 2 | Label Outer2 Warmstart | 0.516 | 0.116 | 0.915 |
| 3 | Label Outer3 Warmstart | 0.593 | 0.292 | 0.894 |
Base Model Performance Across Outer Folds:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | Label Outer0 Base | 0.552 | 0.254 | 0.851 |
| 1 | Label Outer1 Base | 0.470 | 0.366 | 0.574 |
| 2 | Label Outer2 Base | 0.498 | 0.165 | 0.830 |
| 3 | Label Outer3 Base | 0.520 | 0.401 | 0.638 |
Final Performance Comparison:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | DILIGeNN Warm Start (GraphSAGE) | 0.543 ± 0.062 | 0.239 ± 0.084 | 0.846 ± 0.140 |
| 1 | DILIGeNN Base (GraphSAGE) | 0.510 ± 0.035 | 0.296 ± 0.108 | 0.723 ± 0.138 |
3.4 Benchmarking the DILIGeNN GraphSAGE (warm-start) model against ToxPredictor (8 compounds)
[39]:
# Compounds common to the DILIGeNN benchmark set and our results table
overlap_with_df_res_all = set(combined_df['compound_name']) & set(
    df_res_all['compound_name']
)
len(overlap_with_df_res_all)
[39]:
8
[40]:
# Drop the cross-validation entry for Chlorpromazine — presumably it appears
# in both batches and only one entry per compound should remain (TODO confirm)
df_res_all_subset = df_res_all[
    ~(
        (df_res_all.compound_name == 'Chlorpromazine')
        & (df_res_all.Batch == 'cross-validation')
    )
].copy()
# Keep only compounds shared with the DILIGeNN benchmark set
df_res_all_subset = df_res_all_subset[
    df_res_all_subset.index.isin(overlap_with_df_res_all)
].copy()
df_res_all_subset.head()
[40]:
| DILIrank | label_section | livertox_score | smiles | DILI_label | Cmax_uM | First_DILI_uM | MOS_Cytotoxicity | MOS_ToxPredictor | Primary_DILI_driver | Classification | Batch | True_label | Predicted_label | compound_name | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| compound_name | |||||||||||||||
| Hydroxyzine | No-DILI-Concern | No match | E | OCCOCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1 | No DILI | 0.196691 | 1000.000000 | 300.000000 | 300.000000 | none | - | cross-validation | False | False | Hydroxyzine |
| AKN-028 | NaN | NaN | NaN | Nc1ncc(-c2ccncc2)nc1Nc1ccc2[nH]ccc2c1 | DILI (withdrawn) | 1.000000 | 16.666667 | 300.000000 | 16.666667 | Transcriptomics | + | validation | True | True | AKN-028 |
| BMS-986142 | NaN | NaN | NaN | Cc1c(-c2c(F)cc(C(N)=O)c3[nH]c4c(c23)CC[C@H](C(... | DILI (withdrawn) | 1.887880 | 3.703704 | 300.000000 | 1.961832 | Transcriptomics | + | validation | True | True | BMS-986142 |
| Evobrutinib | NaN | NaN | NaN | C=CC(=O)N1CCC(CNc2ncnc(N)c2-c2ccc(Oc3ccccc3)cc... | DILI (withdrawn) | 1.292170 | 1.851000 | 300.000000 | 1.432474 | Transcriptomics | + | validation | True | True | Evobrutinib |
| Orelabrutinib | NaN | NaN | NaN | C=CC(=O)N1CCC(c2ccc(C(N)=O)c(-c3ccc(Oc4ccccc4)... | DILI (withdrawn) | 4.509942 | 33.333333 | 221.732365 | 7.391079 | Transcriptomics | + | validation | True | True | Orelabrutinib |
[41]:
# Restrict DILIGeNN predictions to the compounds shared with ToxPredictor
diligenn_dmap = combined_df[
    combined_df.compound_name.isin(overlap_with_df_res_all)
].copy()
# Pull ToxPredictor columns across via compound_name
toxpredictor_lookup = df_res_all_subset.set_index('compound_name')
for col in ['True_label', 'Predicted_label', 'MOS_ToxPredictor']:
    diligenn_dmap[col] = diligenn_dmap['compound_name'].map(toxpredictor_lookup[col])
# Inverse margin of safety as a continuous risk score
diligenn_dmap['ToxPredictor_Inverse_MOS'] = 1 / diligenn_dmap['MOS_ToxPredictor']
diligenn_dmap.head()
[41]:
| LTKBID | compound_name | DILIrank | label_section | severity_class | DILIst | roa | source | livertox_score | livertox_primary_classification | ... | label_outer1_warmstart | label_outer2_warmstart | label_outer3_warmstart | DILIGeNN_avg_prob_warmstart | DILIGeNN_mode_pred_warmstart | True_label | True_label_numeric | Predicted_label | MOS_ToxPredictor | ToxPredictor_Inverse_MOS | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 26 | LT01437 | Hydroxyzine | No-DILI-Concern | No match | 0.0 | NaN | NaN | DILIrank | E | Respiratory | ... | 0 | 1 | 0 | 0.370794 | 0.0 | False | 0 | False | 300.000000 | 0.003333 |
| 49 | NaN | AKN-028 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 1 | 1 | 1 | 0.991106 | 1.0 | True | 1 | True | 16.666667 | 0.060000 |
| 69 | NaN | BMS-986142 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 1 | 1 | 1 | 0.998426 | 1.0 | True | 1 | True | 1.961832 | 0.509728 |
| 143 | NaN | Evobrutinib | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 1 | 1 | 1 | 0.992263 | 1.0 | True | 1 | True | 1.432474 | 0.698093 |
| 220 | NaN | Orelabrutinib | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 1 | 1 | 1 | 0.995653 | 1.0 | True | 1 | True | 7.391079 | 0.135298 |
5 rows × 89 columns
[42]:
diligenn_dmap.True_label.value_counts()
[42]:
True_label
True 5
False 3
Name: count, dtype: int64
[43]:
def get_performance(df, model_prefix, n_folds=4):
    """Score the per-fold binary predictions of a DILIGeNN variant.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a 'True_label' column plus one
        'label_outer{i}_{model_prefix}' column per outer fold.
    model_prefix : str
        Suffix identifying the model variant (e.g. 'warmstart').
    n_folds : int, optional
        Number of outer folds; the default of 4 reproduces the original
        hard-coded label_outer0 .. label_outer3 behaviour.

    Returns
    -------
    pd.DataFrame
        One row per fold with the metrics produced by
        calculate_comprehensive_metrics.
    """
    label_cols = [f'label_outer{i}_{model_prefix}' for i in range(n_folds)]
    return calculate_comprehensive_metrics(df, 'True_label', label_cols)


warmstart_performance = get_performance(diligenn_dmap, 'warmstart')
print('\nWarm Start Model Performance Across Outer Folds:')
display(warmstart_performance)

# Calculate ToxPredictor performance using Predicted_label
toxpredictor_performance = calculate_comprehensive_metrics(
    diligenn_dmap, 'True_label', ['Predicted_label']
)
print('\nToxPredictor Performance:')
display(toxpredictor_performance)

# Summaries: fold mean ± std for DILIGeNN, single value for ToxPredictor
warmstart_numeric_cols = warmstart_performance.select_dtypes(
    include=[np.number]
).columns
warmstart_summary = pd.DataFrame(
    {
        'Model': ['DILIGeNN Warm Start (GraphSAGE)'],
        **{
            col: [
                f'{warmstart_performance[col].mean():.3f} ± {warmstart_performance[col].std():.3f}'
            ]
            for col in warmstart_numeric_cols
        },
    }
)
toxpredictor_summary = pd.DataFrame(
    {
        'Model': ['ToxPredictor'],
        **{
            col: [f'{toxpredictor_performance[col].iloc[0]:.3f}']
            for col in warmstart_numeric_cols
        },
    }
)
final_comparison_df = pd.concat(
    [toxpredictor_summary, warmstart_summary], ignore_index=True
)
print('\nFinal Performance Comparison:')
display(final_comparison_df)
Warm Start Model Performance Across Outer Folds:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | Label Outer0 Warmstart | 0.667 | 0.333 | 1.0 |
| 1 | Label Outer1 Warmstart | 0.667 | 0.333 | 1.0 |
| 2 | Label Outer2 Warmstart | 0.667 | 0.333 | 1.0 |
| 3 | Label Outer3 Warmstart | 1.000 | 1.000 | 1.0 |
ToxPredictor Performance:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | Predicted Label | 1.0 | 1.0 | 1.0 |
Final Performance Comparison:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | ToxPredictor | 1.000 | 1.000 | 1.000 |
| 1 | DILIGeNN Warm Start (GraphSAGE) | 0.750 ± 0.166 | 0.500 ± 0.334 | 1.000 ± 0.000 |
[44]:
# Confusion-matrix breakdown: first for the DILIGeNN warm-start consensus
# prediction, then the default analysis (ToxPredictor 'Predicted_label',
# as shown in the printed output) for comparison.
print_confusion_matrix_analysis(
    diligenn_dmap,
    true_label_col='True_label',
    pred_label_col='DILIGeNN_mode_pred_warmstart',
)
print_confusion_matrix_analysis(diligenn_dmap)
True_label distribution:
True_label
True 5
False 3
Name: count, dtype: int64
Confusion matrix: for DILIGeNN_mode_pred_warmstart
True Positives (TP): 5 / Total DILI 5
True Negatives (TN): 2/ Total Non-DILI 3
True_label distribution:
True_label
True 5
False 3
Name: count, dtype: int64
Confusion matrix: for Predicted_label
True Positives (TP): 5 / Total DILI 5
True Negatives (TN): 3/ Total Non-DILI 3
4. DILIPredictor
4.1 Load data and identify unseen subset using InChiKey14, Compound Name and SMILES
[45]:
# Compounds present in DILIPredictor's training data (n=1111); used below to
# exclude already-seen compounds from the benchmark.
dilirank_dp = pd.read_csv(
    './insilico_benchmarks/smiles_data/DILIPredictor_seen_1111.csv'
)
dilirank_dp.head()
[45]:
| smiles_r | TOXICITY | Source_rank | Source | Data | InChIKey | InChIKey14 | protonated_smiles_r | |
|---|---|---|---|---|---|---|---|---|
| 0 | C#Cc1cccc(N=c2[nH]cnc3cc(OCCOC)c(OCCOC)cc23)c1 | 1 | 1 | DILIst Classification Oral | DILI | AAKJLRGGTJKAMG-UHFFFAOYSA-N | AAKJLRGGTJKAMG | C#Cc1cccc(N=c2[nH]cnc3cc(OCCOC)c(OCCOC)cc23)c1 |
| 1 | COC1(NC(=O)CSC(F)F)C(=O)N2C(C(=O)O)=C(CSc3nnnn... | 1 | 1 | DILIst Classification | DILI | UHRBTBZOWWGKMK-UHFFFAOYSA-N | UHRBTBZOWWGKMK | COC1(NC(=O)CSC(F)F)C(=O)[NH+]2C(C(=O)[O-])=C(C... |
| 2 | CC(C)COCC(CN(Cc1ccccc1)c1ccccc1)N1CCCC1 | 1 | 1 | DILIst Classification | DILI | UIEATEWHFDRYRU-UHFFFAOYSA-N | UIEATEWHFDRYRU | CC(C)COCC(CN(Cc1ccccc1)c1ccccc1)[NH+]1CCCC1 |
| 3 | Cc1onc(-c2c(F)cccc2Cl)c1C(=O)NC1C(=O)N2C1SC(C)... | 1 | 1 | DILIst Classification oral | DILI | UIOFUWFRIANQPC-UHFFFAOYSA-N | UIOFUWFRIANQPC | Cc1onc(-c2c(F)cccc2Cl)c1C(=O)NC1C(=O)[NH+]2C1S... |
| 4 | CC1OC1[P](=O)(=O)O | 1 | 1 | DILIst Classification Oral | DILI | UJNUDOLLRRCQDH-UHFFFAOYSA-N | UJNUDOLLRRCQDH | CC1OC1[P](=O)(=O)O |
[46]:
# Flag compounds present in DILIPredictor's training data using three
# complementary identity checks; only genuinely unseen compounds remain.

def _overlap_mask(query_series, reference_series):
    """Boolean mask over `query_series`: True where the value also occurs in `reference_series`."""
    shared = set(query_series) & set(reference_series)
    return query_series.isin(shared)

# 1. Exact SMILES match against the training SMILES
smiles_overlap_mask = _overlap_mask(
    cellarity_smiles_dili_dataset['smiles'], dilirank_dp['smiles_r']
)
# 1b. Exact match against the protonated SMILES variants
smiles_proto_overlap_mask = _overlap_mask(
    cellarity_smiles_dili_dataset['smiles'], dilirank_dp['protonated_smiles_r']
)
# 2. Structure-level match via the 14-character InChIKey connectivity prefix
inchikey14_overlap_mask = _overlap_mask(
    cellarity_smiles_dili_dataset['InChiKey14'].astype(str),
    dilirank_dp['InChIKey14'].astype(str),
)

# A compound counts as "seen" if any of the three checks fires
overlap_mask = smiles_overlap_mask | smiles_proto_overlap_mask | inchikey14_overlap_mask
dilipr_unseeen_smiles = cellarity_smiles_dili_dataset[~overlap_mask].copy()
print(f'Unseen compounds: {len(dilipr_unseeen_smiles)}')
Unseen compounds: 483
[47]:
# Binary-encode the ground-truth annotation (DILI -> 1, No DILI -> 0)
dilipr_unseeen_smiles['label'] = dilipr_unseeen_smiles['DILI_label_binary'].map(
    {'DILI': 1, 'No DILI': 0}
)
# dilipr_unseeen_smiles.to_csv('./insilico_benchmarks/smiles_data/unseen_smiles_dilipr_483.csv')
4.2 Run predictions using reproducible scripts
Run this command in terminal and follow the instructions as recommended by the code python ./insilico_benchmarks/run_dilipredictor.py -i ./insilico_benchmarks/smiles_data/unseen_smiles_dilipr_483.csv -o ./insilico_benchmarks/predictions/benchmark_dilipr_483_model_output.csv
[48]:
# Load the DILIPredictor model output for the 483 unseen compounds and attach
# it to the compound metadata table.
dilipr_test_pred = pd.read_csv(
    './insilico_benchmarks/predictions/benchmark_dilipr_483_model_output.csv'
)
dilipr_predictions = dilipr_unseeen_smiles.copy()
# NOTE(review): assigning via .tolist() relies on the prediction CSV rows being
# in exactly the same order as dilipr_unseeen_smiles — confirm upstream script.
dilipr_predictions['DILIPredictor_prediction'] = dilipr_test_pred[
    'DILIPRedictor_prediction'
].tolist()
# Boolean form of the DILI/No DILI call
dilipr_predictions['DILIPredictor_binary'] = (
    dilipr_predictions['DILIPredictor_prediction'] == 'DILI'
)
dilipr_predictions['DILIPredictor_probability'] = dilipr_test_pred[
    'DILIPRedictor_probability'
].tolist()
# dilipr_predictions.to_csv('predictions/benchmark_dilipr_483_predictions.csv')
dilipr_predictions.head()
[48]:
| LTKBID | compound_name | DILIrank | label_section | severity_class | DILIst | roa | source | livertox_score | livertox_primary_classification | ... | DILI_label_binary | DILI_label_section | livertox_updated | livertox_iDILI | livertox_mechanism_summary | InChiKey14 | label | DILIPredictor_prediction | DILIPredictor_binary | DILIPredictor_probability | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Acetylcholine chloride | LT02096 | Acetylcholine chloride | No-DILI-Concern | No match | 0.0 | NaN | Oral | DILIrank | NaN | NaN | ... | No DILI | No DILI | NaN | NaN | NaN | JUGOREOARAHOCO | 0 | No DILI | False | 0.3382606667842065 |
| Alatrofloxacin mesylate | LT00461 | Alatrofloxacin mesylate | Most-DILI-Concern | Withdrawn | 8.0 | NaN | Oral | DILIrank | NaN | NaN | ... | DILI | Withdrawn | NaN | NaN | NaN | CYETUYYEVKNSHZ | 1 | DILI | True | 0.7504044838184715 |
| Allopurinol | LT00043 | Allopurinol | Most-DILI-Concern | Warnings and precautions | 8.0 | NaN | Oral | DILIrank | A | Rheumatologic | ... | DILI | Known | 2020 Dec 26 | NaN | Hepatitis, Cholestasis/Biliary, Immune-mediate... | OFCNXPDARWKPPY | 1 | DILI | True | 0.7283365728907885 |
| Amifostine | LT01067 | Amifostine | No-DILI-Concern | No match | 0.0 | NaN | Intravenous | DILIrank | NaN | Toxicology | ... | No DILI | No DILI | NaN | NaN | NaN | JKOQGQFVAUAYPM | 0 | No DILI | False | 0.49992585412698437 |
| Amikacin | LT01068 | Amikacin | No-DILI-Concern | No match | 0.0 | NaN | Intravenous | DILIrank | E | Antimicrobial | ... | No DILI | No DILI | 2019 Apr 12 | NaN | NaN | LKCWBDHBTVXHDL | 0 | No DILI | False | 0.48674586244480567 |
5 rows × 65 columns
Removing failed predictions
[49]:
# Purge rows whose probability field holds an error string rather than a number.
error_mask = (
    dilipr_predictions['DILIPredictor_probability']
    .astype(str)
    .str.contains('Error', case=False, na=False)
)
error_indices = error_mask[error_mask].index
print(f'Found {len(error_indices)} error entries at indices: {error_indices.tolist()}')
if error_mask.any():
    # Keep only the rows with a usable probability
    dilipr_predictions_clean = dilipr_predictions.loc[~error_mask].copy()
    print(f'\nOriginal dataframe shape: {dilipr_predictions.shape}')
    print(f'Cleaned dataframe shape: {dilipr_predictions_clean.shape}')
    dilipr_predictions = dilipr_predictions_clean
else:
    print('No error entries found')
# Sanity check: nothing non-numeric should remain
non_numeric_count = (
    pd.to_numeric(dilipr_predictions['DILIPredictor_probability'], errors='coerce')
    .isna()
    .sum()
)
print(f'Error output left : {non_numeric_count}')
Found 6 error entries at indices: ['Exenatide', 'Rubidium chloride rb-82', 'Arsenic', 'Fluoride', 'Lithium', 'Magnesium']
Original dataframe shape: (483, 65)
Cleaned dataframe shape: (477, 65)
Error output left : 0
Removing 6 more compounds seen by DILIPredictor
[50]:
# Six additional compounds were identified as present in the DILIPredictor
# training data; exclude them by their InChIKey14 prefix.
seen_inchikeys = [
    'VOVIALXJUBGFJZ',
    'VPNYRYCIDCJBOM',
    'HYIMSNHJOBLJNT',
    'KVWDHTXUZHCGIO',
    'AFNTWHMDBNQQPX',
    'WYHIICXRPHEJKI',
]
unseen_mask = ~dilipr_predictions.InChiKey14.isin(seen_inchikeys)
dilipr_predictions = dilipr_predictions[unseen_mask].copy()
dilipr_predictions.shape
[50]:
(471, 65)
[51]:
# dilipr_predictions.to_csv('./insilico_benchmarks/predictions/benchmark_dilipr_471_predictions.csv')
4.3 Evaluate performance on (potentially) unseen smiles (471 compounds)
[52]:
# To resume from here, reload the cached 471-compound table instead:
# dilipr_predictions = pd.read_csv('./insilico_benchmarks/predictions/benchmark_dilipr_471_predictions.csv',index_col=0)
dilipr_predictions.head()
[52]:
| LTKBID | compound_name | DILIrank | label_section | severity_class | DILIst | roa | source | livertox_score | livertox_primary_classification | ... | DILI_label_binary | DILI_label_section | livertox_updated | livertox_iDILI | livertox_mechanism_summary | InChiKey14 | label | DILIPredictor_prediction | DILIPredictor_binary | DILIPredictor_probability | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Acetylcholine chloride | LT02096 | Acetylcholine chloride | No-DILI-Concern | No match | 0.0 | NaN | Oral | DILIrank | NaN | NaN | ... | No DILI | No DILI | NaN | NaN | NaN | JUGOREOARAHOCO | 0 | No DILI | False | 0.3382606667842065 |
| Alatrofloxacin mesylate | LT00461 | Alatrofloxacin mesylate | Most-DILI-Concern | Withdrawn | 8.0 | NaN | Oral | DILIrank | NaN | NaN | ... | DILI | Withdrawn | NaN | NaN | NaN | CYETUYYEVKNSHZ | 1 | DILI | True | 0.7504044838184715 |
| Allopurinol | LT00043 | Allopurinol | Most-DILI-Concern | Warnings and precautions | 8.0 | NaN | Oral | DILIrank | A | Rheumatologic | ... | DILI | Known | 2020 Dec 26 | NaN | Hepatitis, Cholestasis/Biliary, Immune-mediate... | OFCNXPDARWKPPY | 1 | DILI | True | 0.7283365728907885 |
| Amifostine | LT01067 | Amifostine | No-DILI-Concern | No match | 0.0 | NaN | Intravenous | DILIrank | NaN | Toxicology | ... | No DILI | No DILI | NaN | NaN | NaN | JKOQGQFVAUAYPM | 0 | No DILI | False | 0.49992585412698437 |
| Amikacin | LT01068 | Amikacin | No-DILI-Concern | No match | 0.0 | NaN | Intravenous | DILIrank | E | Antimicrobial | ... | No DILI | No DILI | 2019 Apr 12 | NaN | NaN | LKCWBDHBTVXHDL | 0 | No DILI | False | 0.48674586244480567 |
5 rows × 65 columns
[53]:
dilipr_predictions.DILI_label_binary.value_counts()
[53]:
DILI_label_binary
No DILI 373
DILI 98
Name: count, dtype: int64
[54]:
# Calculate performance for DILIPredictor on the full 471-compound unseen set
df = dilipr_predictions.copy()
# Define label mapping if not already defined
label_mapping = {'DILI': 1, 'No DILI': 0}
df['True_label'] = df['DILI_label_binary'].map({'DILI': True, 'No DILI': False})
df['True_label_numeric'] = df['DILI_label_binary'].map(label_mapping)
# Remove any NaN values for AUROC calculations
valid_mask = ~(df['True_label_numeric'].isna() | df['DILIPredictor_probability'].isna())
df_clean = df[valid_mask]
# DILIPredictor (new)
dilipr_performance = calculate_comprehensive_metrics(
    df_clean, 'True_label', ['DILIPredictor_binary']
)
dilipr_numeric_cols = dilipr_performance.select_dtypes(include=[np.number]).columns
# NOTE(review): dilipr_auroc is computed here but never printed or displayed in
# this cell; it is recomputed later for the 314-compound comparison.
dilipr_auroc = roc_auc_score(
    df_clean['True_label_numeric'], df_clean['DILIPredictor_probability']
)
# One-row summary with metrics formatted to three decimal places
dilipr_summary = pd.DataFrame(
    {
        'Model': ['DILIPredictor'],
        **{
            col: [f'{dilipr_performance[col].iloc[0]:.3f}']
            for col in dilipr_numeric_cols
        },
    }
)
print('\nDILIPredictor Performance Summary:')
display(dilipr_summary)
DILIPredictor Performance Summary:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | DILIPredictor | 0.646 | 0.405 | 0.888 |
4.4 Benchmark DILIPredictor against ToxPredictor (30 compounds)
[55]:
# Compounds that have both DILIPredictor predictions and ToxPredictor results
overlap_with_df_res_all = set(dilipr_predictions['compound_name']).intersection(
    set(df_res_all['compound_name'])
)
len(overlap_with_df_res_all)
[55]:
30
[56]:
# Build the ToxPredictor reference table, dropping the duplicate
# cross-validation entry for Chlorpromazine so each compound appears once.
df_res_all_subset = df_res_all[
    ~(
        (df_res_all.compound_name == 'Chlorpromazine')
        & (df_res_all.Batch == 'cross-validation')
    )
].copy()
# NOTE(review): filtering the index against compound names assumes df_res_all
# is indexed by compound name — confirm upstream construction.
df_res_all_subset = df_res_all_subset[
    df_res_all_subset.index.isin(overlap_with_df_res_all)
].copy()
[57]:
# Keep only the 30 compounds shared with the ToxPredictor result set and
# attach its labels / margin-of-safety values by compound name.
# (dilipr_predictions is indexed by compound name.)
dilipr_dmap = dilipr_predictions[
    dilipr_predictions.index.isin(overlap_with_df_res_all)
].copy()
toxpred_lookup = df_res_all_subset.set_index('compound_name')
for col in ['True_label', 'Predicted_label', 'MOS_ToxPredictor']:
    dilipr_dmap[col] = dilipr_dmap['compound_name'].map(toxpred_lookup[col])
# Inverse MOS: larger value == smaller margin of safety == higher predicted risk
dilipr_dmap['ToxPredictor_Inverse_MOS'] = 1 / dilipr_dmap['MOS_ToxPredictor']
dilipr_dmap.head()
[57]:
| LTKBID | compound_name | DILIrank | label_section | severity_class | DILIst | roa | source | livertox_score | livertox_primary_classification | ... | livertox_mechanism_summary | InChiKey14 | label | DILIPredictor_prediction | DILIPredictor_binary | DILIPredictor_probability | True_label | Predicted_label | MOS_ToxPredictor | ToxPredictor_Inverse_MOS | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Allopurinol | LT00043 | Allopurinol | Most-DILI-Concern | Warnings and precautions | 8.0 | NaN | Oral | DILIrank | A | Rheumatologic | ... | Hepatitis, Cholestasis/Biliary, Immune-mediate... | OFCNXPDARWKPPY | 1 | DILI | True | 0.7283365728907885 | True | True | 27.891861 | 0.035853 |
| Chlorzoxazone | LT00067 | Chlorzoxazone | Most-DILI-Concern | Warnings and precautions | 8.0 | NaN | Oral | DILIrank | B | CNS | ... | Cholestasis/Biliary, Hypersensitivity | TZFWDZFKRBELIQ | 1 | DILI | True | 0.7944024200636947 | True | True | 7.905361 | 0.126496 |
| Dapsone | LT00289 | Dapsone | Less-DILI-Concern | Warnings and precautions | 3.0 | NaN | Oral | DILIrank | A | Antimicrobial | ... | Cholestasis/Biliary, Hypersensitivity | MQJKPEGWNLWLTK | 1 | DILI | True | 0.7104704927300859 | True | False | 300.000000 | 0.003333 |
| Dichlorphenamide | LT00802 | Dichlorphenamide | No-DILI-Concern | No match | 0.0 | NaN | Oral | DILIrank | NaN | NaN | ... | NaN | GJQPMPFPNINLKP | 0 | DILI | True | 0.6486831077610885 | False | False | 100.712706 | 0.009929 |
| Doxycycline | LT00393 | Doxycycline | Less-DILI-Concern | Adverse reactions | 3.0 | NaN | Oral | DILIrank | B | Antimicrobial | ... | Hepatitis, Cholestasis/Biliary, Steatosis, Imm... | XQTWDDCIUJNLTR | 1 | DILI | True | 0.6560019565062682 | True | True | 65.874667 | 0.015180 |
5 rows × 69 columns
[58]:
# Compare ToxPredictor and DILIPredictor on the 30 shared compounds.
# (Removed the unused y_true_bool local and the redundant pd.Series wrap
# around y_true, which is already a Series.)
y_true = dilipr_dmap['True_label'].astype(str)
print('\nOverall class distribution in true labels:')
print(y_true.value_counts())

# Binary prediction columns and their matching continuous scores
prediction_columns = ['Predicted_label', 'DILIPredictor_binary']
probability_columns = ['ToxPredictor_Inverse_MOS', 'DILIPredictor_probability']

# Calculate comprehensive metrics for the combined data
combined_performance_df = calculate_comprehensive_metrics(
    dilipr_dmap, 'True_label', prediction_columns
)
# Human-readable model names for display
combined_performance_df['Model'] = combined_performance_df['Model'].replace(
    {
        'Predicted Label': 'ToxPredictor',
        'Dilipredictor Binary': 'DILIPredictor',
    }
)
print('\nDILIPredictor Model Performance on unseen data')
display(combined_performance_df)
Overall class distribution in true labels:
True_label
True 23
False 7
Name: count, dtype: int64
DILIPredictor Model Performance on unseen data
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | ToxPredictor | 0.792 | 0.714 | 0.87 |
| 1 | DILIPredictor | 0.571 | 0.143 | 1.00 |
[59]:
# Confusion-matrix breakdown for DILIPredictor and ToxPredictor on the 30
# shared compounds.
dilipr_dmap['DILIPredictor_Label'] = dilipr_dmap['DILIPredictor_binary'].astype(bool)
# Cast labels to string so both analyses see the same label representation
dilipr_dmap['True_label'] = dilipr_dmap['True_label'].astype(str)
print_confusion_matrix_analysis(
    dilipr_dmap,
    true_label_col='True_label',
    pred_label_col='DILIPredictor_Label',
)
print_confusion_matrix_analysis(dilipr_dmap)
True_label distribution:
True_label
True 23
False 7
Name: count, dtype: int64
Confusion matrix: for DILIPredictor_Label
True Positives (TP): 23 / Total DILI 23
True Negatives (TN): 1/ Total Non-DILI 7
True_label distribution:
True_label
True 23
False 7
Name: count, dtype: int64
Confusion matrix: for Predicted_label
True Positives (TP): 20 / Total DILI 23
True Negatives (TN): 5/ Total Non-DILI 7
5. Evaluate performance of DILIGeNN and DILIPredictor on unseen smiles (314 compounds)
[60]:
## DILIPR vs DILIGeNN overlap
# Plot overlap between all_dilipr_seen_inchikey14, df_std_filtered['InChiKey14'] and df_dg['InChiKey14']
# Convert to sets for overlap analysis
# NOTE(review): assumes dilipr_predictions and df_res_all are both indexed by
# compound name so their indexes are comparable with combined_df.compound_name
# — confirm both indexes.
dilipr_set = set(dilipr_predictions.index)
dmap_rnaseq = set(df_res_all.index)
diligenn_set = set(combined_df.compound_name)
# Create Venn diagram
plt.figure(figsize=(10, 8))
venn = venn3(
    [dilipr_set, dmap_rnaseq, diligenn_set],
    ('DILIPredictor Unseen', 'Dilimap RNA-seq', 'DILIGeNN Unseen'),
)
plt.title('Overlap of InChiKey14 between datasets')
plt.show()
# Print overlap statistics
print(f'DILIPredictor unseen compounds: {len(dilipr_set)}')
print(f'DMAP transcriptomics smiles: {len(dmap_rnaseq)}')
print(f'DILIGeNN unseen compounds: {len(diligenn_set)}')
print()
# Calculate pairwise and three-way intersections
dilipr_dmap_overlap = dilipr_set & dmap_rnaseq
dilipr_dg_overlap = dilipr_set & diligenn_set
dmap_dg_overlap = dmap_rnaseq & diligenn_set
all_three_overlap = dilipr_set & dmap_rnaseq & diligenn_set
print(f'DILIPredictor Unseen & DMAP overlap: {len(dilipr_dmap_overlap)}')
print(f'DILIPredictor Unseen & DILiGeNN Unseen overlap: {len(dilipr_dg_overlap)}')
print(f'DILiGeNN Unseen & DMAP overlap: {len(dmap_dg_overlap)}')
print(
    f'DILIPredictor Unseen & DILiGeNN Unseen overlap & DMAP: {len(all_three_overlap)}'
)
DILIPredictor unseen compounds: 471
DMAP transcriptomics smiles: 299
DILIGeNN unseen compounds: 331
DILIPredictor Unseen & DMAP overlap: 30
DILIPredictor Unseen & DILiGeNN Unseen overlap: 314
DILiGeNN Unseen & DMAP overlap: 8
DILIPredictor Unseen & DILiGeNN Unseen overlap & DMAP: 7
[61]:
# Compounds unseen by BOTH DILIPredictor (df_clean) and DILIGeNN (combined_df)
overlap_cmpds = set(df_clean.compound_name).intersection(set(combined_df.compound_name))
# NOTE(review): df_clean is indexed by compound name, so index.isin matches the
# compound-name set here — confirm the index. The merge below resets the index
# to a RangeIndex (visible in the displayed output).
overlap_df = df_clean[df_clean.index.isin(overlap_cmpds)].copy()
# Add columns from combined_df for warmstart model predictions
overlap_df = overlap_df.merge(
    combined_df[
        ['compound_name', 'DILIGeNN_avg_prob_warmstart', 'DILIGeNN_mode_pred_warmstart']
    ],
    on='compound_name',
    how='left',
)
overlap_df.head()
[61]:
| LTKBID | compound_name | DILIrank | label_section | severity_class | DILIst | roa | source | livertox_score | livertox_primary_classification | ... | livertox_mechanism_summary | InChiKey14 | label | DILIPredictor_prediction | DILIPredictor_binary | DILIPredictor_probability | True_label | True_label_numeric | DILIGeNN_avg_prob_warmstart | DILIGeNN_mode_pred_warmstart | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | LT00964 | Anagrelide | Ambiguous DILI-concern | Warnings and precautions | 3.0 | NaN | NaN | DILIrank | E* | Hematologic | ... | NaN | OTBXOEAOVRKTNQ | 0 | DILI | True | 0.7810926747670023 | False | 0 | 0.876281 | 1.0 |
| 1 | LT01195 | Aprepitant | Ambiguous DILI-concern | Adverse reactions | 3.0 | NaN | NaN | DILIrank | E | Gastrointestinal | ... | NaN | ATALOFNDEOCMKK | 0 | DILI | True | 0.7342626234579888 | False | 0 | 0.357398 | 0.0 |
| 2 | LT00707 | Butabarbital | Ambiguous DILI-concern | Adverse reactions | 3.0 | NaN | NaN | DILIrank | E | CNS | ... | NaN | ZRIHAIZYIMGOAB | 0 | No DILI | False | 0.43585037248699443 | False | 0 | 0.909950 | 1.0 |
| 3 | LT01048 | Cevimeline | Ambiguous DILI-concern | Adverse reactions | 3.0 | NaN | NaN | DILIrank | E | Opthalmologic | ... | NaN | WUTYZMFRCNBCHQ | 0 | No DILI | False | 0.5858185692685399 | False | 0 | 0.499214 | 0.0 |
| 4 | LT00065 | Chenodiol | Ambiguous DILI-concern | Warnings and precautions | 8.0 | NaN | NaN | DILIrank | E* | Gastrointestinal | ... | NaN | RUDATBOHQWOJDD | 0 | No DILI | False | 0.5862837665853369 | False | 0 | 0.747384 | 1.0 |
5 rows × 69 columns
[62]:
overlap_df.DILI_label_binary.value_counts()
[62]:
DILI_label_binary
No DILI 269
DILI 45
Name: count, dtype: int64
[63]:
# overlap_df.to_csv('./insilico_benchmarks/predictions/benchmark_dilipr_diligenn_314_predictions.csv')
[64]:
# Use overlap_df which contains compounds present in both unseen sets
# (314 compounds). The near-duplicate summary-table construction for the two
# models is factored into a single helper.
df = overlap_df.copy()

# Calculate performance for DILIGeNN Warm Start
diligenn_warmstart_performance = calculate_comprehensive_metrics(
    df, 'True_label', ['DILIGeNN_mode_pred_warmstart']
)
# Calculate performance for DILIPredictor
dilipr_performance = calculate_comprehensive_metrics(
    df, 'True_label', ['DILIPredictor_binary']
)
diligenn_numeric_cols = diligenn_warmstart_performance.select_dtypes(
    include=[np.number]
).columns
dilipr_numeric_cols = dilipr_performance.select_dtypes(include=[np.number]).columns

# AUROC from the continuous scores (rendered by plot_roc_curves below)
diligenn_warmstart_auroc = roc_auc_score(
    df['True_label_numeric'], df['DILIGeNN_avg_prob_warmstart']
)
dilipr_auroc = roc_auc_score(df['True_label_numeric'], df['DILIPredictor_probability'])


def _format_summary(model_name, performance_df, numeric_cols):
    """One-row display table: first-row metric values formatted to 3 d.p."""
    return pd.DataFrame(
        {
            'Model': [model_name],
            **{
                col: [f'{performance_df[col].iloc[0]:.3f}']
                for col in numeric_cols
            },
        }
    )


diligenn_warmstart_summary = _format_summary(
    'DILIGeNN Warm Start (GraphSAGE)',
    diligenn_warmstart_performance,
    diligenn_numeric_cols,
)
dilipr_summary = _format_summary(
    'DILIPredictor', dilipr_performance, dilipr_numeric_cols
)

# Combine summaries
combined_summary = pd.concat(
    [diligenn_warmstart_summary, dilipr_summary], ignore_index=True
)
print('\nModel Performance Comparison:')
display(combined_summary)

# Plot ROC curves for both models
plot_roc_curves(
    df,
    'True_label',
    ['DILIGeNN_avg_prob_warmstart', 'DILIPredictor_probability'],
    model_names=['DILIGeNN Warm Start (GraphSAGE)', 'DILIPredictor'],
)
Model Performance Comparison:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | DILIGeNN Warm Start (GraphSAGE) | 0.562 | 0.279 | 0.844 |
| 1 | DILIPredictor | 0.545 | 0.290 | 0.800 |
6. Run TxGemma on 715 unseen smiles
6.1 Load data and identify unseen subset using InChiKey14 and SMILES
[65]:
df_txg = pd.read_csv('./insilico_benchmarks/smiles_data/txgemma_train.csv')
[66]:
# Derive the 14-character InChIKey connectivity prefix for each training SMILES
df_txg['InChiKey14'] = df_txg['Drug'].map(smiles_to_inchikey14)
df_txg.head()
[66]:
| Drug_ID | Drug | Y | InChiKey14 | |
|---|---|---|---|---|
| 0 | 187.0 | CC(=O)OCC[N+](C)(C)C | 0.0 | OIPILFWXSMYKGL |
| 1 | 247.0 | C[N+](C)(C)CC(=O)[O-] | 0.0 | KWIUHFFTVRNATP |
| 2 | 298.0 | O=C(NC(CO)C(O)c1ccc([N+](=O)[O-])cc1)C(Cl)Cl | 0.0 | WIIZWVCIJKGZOK |
| 3 | 338.0 | O=C(O)c1ccccc1O | 0.0 | YGSDEFSMJLZEOE |
| 4 | 444.0 | CC(NC(C)(C)C)C(=O)c1cccc(Cl)c1 | 0.0 | SNPPWIUOZRMYNY |
[67]:
# A compound counts as "seen" by TxGemma if it matches the training set either
# by exact SMILES or by the 14-character InChIKey connectivity prefix.
# (Series.isin against the full training set gives the same booleans as the
# original intersect-then-isin formulation.)

# 1. Exact SMILES match
train_smiles = set(df_txg['Drug'])
smiles_overlap_mask = cellarity_smiles_dili_dataset['smiles'].isin(train_smiles)

# 2. InChIKey14-based structural match
query_keys = cellarity_smiles_dili_dataset['InChiKey14'].astype(str)
train_keys = set(df_txg['InChiKey14'].astype(str))
inchikey14_overlap_mask = query_keys.isin(train_keys)

overlap_mask = smiles_overlap_mask | inchikey14_overlap_mask
txgemma_unseeen_smiles = cellarity_smiles_dili_dataset[~overlap_mask].copy()
print(f'Number of unseen compounds: {len(txgemma_unseeen_smiles)}')
Number of unseen compounds: 715
[68]:
# Plot overlap between txgemma_unseeen_smiles and df_res_all
# NOTE(review): this reassigns overlap_with_df_res_all, shadowing the value
# computed in the DILIPredictor section above.
overlap_with_df_res_all = set(txgemma_unseeen_smiles['compound_name']).intersection(
    set(df_res_all['compound_name'])
)
print(
    f'Overlap between txgemma_unseeen_smiles and df_res_all: {len(overlap_with_df_res_all)} compounds'
)
# Create a visualization of the overlap
plt.figure(figsize=(8, 6))
venn2(
    [set(txgemma_unseeen_smiles['compound_name']), set(df_res_all['compound_name'])],
    set_labels=('TxGemma Unseen', 'df_res_all'),
)
plt.title('Overlap between TxGemma Unseen Compounds and df_res_all')
plt.show()
Overlap between txgemma_unseeen_smiles and df_res_all: 97 compounds
[69]:
# txgemma_unseeen_smiles.to_csv('./insilico_benchmarks/smiles_data/unseen_smiles_txgemma_715.csv')
6.2 Run predictions using reproducible scripts
See ./insilico_benchmarks/run_txgemma_dili.py and Readme for more details
Results are available here: ./insilico_benchmarks/predictions/benchmark_txgemma_all_models_predictions.csv
[70]:
# Load TxGemma (2B / 9B / 27B) predictions for the 715 unseen compounds and
# derive boolean and numeric ground-truth columns from the DILI annotation.
txgemma_dili_predictions = pd.read_csv(
    './insilico_benchmarks/predictions/benchmark_txgemma_all_variants_715_predictions.csv',
    index_col=0,
)
binary_labels = txgemma_dili_predictions['DILI_label_binary']
txgemma_dili_predictions['True_label'] = binary_labels.map(
    {'DILI': True, 'No DILI': False}
)
label_mapping = {'DILI': 1, 'No DILI': 0}
txgemma_dili_predictions['True_label_numeric'] = binary_labels.map(label_mapping)
txgemma_dili_predictions.head()
[70]:
| LTKBID | compound_name | DILIrank | label_section | severity_class | DILIst | roa | source | livertox_score | livertox_primary_classification | ... | livertox_mechanism_summary | InChiKey14 | TxGemma2B_predicted_DILI_Label_output | TxGemma9B_predicted_DILI_Label_output | TxGemma27B_predicted_DILI_Label_output | TxGemma2B_predicted_DILI_Label | TxGemma9B_predicted_DILI_Label | TxGemma27B_predicted_DILI_Label | True_label | True_label_numeric | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unnamed: 0 | |||||||||||||||||||||
| 4-aminosalicylic acid | LT00505 | 4-aminosalicylic acid | Most-DILI-Concern | Warnings and precautions | 5.0 | NaN | Oral | DILIrank | NaN | NaN | ... | NaN | WUBBRNOQWQTFEX | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | DILI | DILI | DILI | True | 1 |
| Acetaminophen | LT00004 | Acetaminophen | Most-DILI-Concern | Warnings and precautions | 5.0 | NaN | Oral | DILIrank | A [HD] | Analgesic | ... | CYP, Hypersensitivity | RZVAJINKPMORJF | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | DILI | DILI | DILI | True | 1 |
| Acetylcholine chloride | LT02096 | Acetylcholine chloride | No-DILI-Concern | No match | 0.0 | NaN | Oral | DILIrank | NaN | NaN | ... | NaN | JUGOREOARAHOCO | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | No DILI | DILI | No DILI | False | 0 |
| Acetylcysteine | LT01041 | Acetylcysteine | No-DILI-Concern | No match | 0.0 | NaN | Oral | DILIrank | E | Toxicology | ... | NaN | PWKSKIMOESPYIA | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | DILI | DILI | DILI | False | 0 |
| Acitretin | LT01042 | Acitretin | Most-DILI-Concern | Box warning | 5.0 | NaN | Oral | DILIrank | B | Dermatologic | ... | Hepatitis, Cholestasis/Biliary, Hypersensitivity | IHUNBGSDBOWDMA | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | DILI | DILI | DILI | True | 1 |
5 rows × 69 columns
6.3 Evaluate predictions from TxGemma 2B predict, 9B predict and 27B predict variants on unseen smiles (715 compounds)
[71]:
# Evaluate the three TxGemma variants (plus ToxPredictor where available) on
# the 715 unseen compounds. (Removed the unused y_true_bool local and the
# redundant pd.Series wrap around y_true, which is already a Series.)
y_true = txgemma_dili_predictions['True_label'].astype(str)
print('\nClass distribution in true labels:')
print(y_true.value_counts())

# Define prediction columns for TxGemma models and ToxPredictor
prediction_columns = [
    'Predicted_label',
    'TxGemma2B_predicted_DILI_Label',
    'TxGemma9B_predicted_DILI_Label',
    'TxGemma27B_predicted_DILI_Label',
]
# Normalize raw TxGemma 'DILI'/'No DILI' strings to 'True'/'False' strings so
# calculate_comprehensive_metrics treats them like the other label columns.
# NOTE(review): not idempotent — re-running this cell maps the already-converted
# 'True'/'False' values to all-'False' ('TRUE' != 'DILI').
for pred_col in prediction_columns:
    if pred_col in txgemma_dili_predictions.columns and 'TxGemma' in pred_col:
        txgemma_dili_predictions[pred_col] = (
            txgemma_dili_predictions[pred_col].astype(str).str.upper() == 'DILI'
        ).astype(str)

# Calculate comprehensive metrics
performance_df = calculate_comprehensive_metrics(
    txgemma_dili_predictions,
    'True_label',
    prediction_columns,
)
# Human-readable model names for display
performance_df['Model'] = performance_df['Model'].replace(
    {
        'Predicted Label': 'ToxPredictor',
        'Txgemma2B Predicted Dili Label': 'TxGemma 2B predict',
        'Txgemma9B Predicted Dili Label': 'TxGemma 9B predict',
        'Txgemma27B Predicted Dili Label': 'TxGemma 27B predict',
    }
)
print('\nTxGemma Model Performance:')
display(performance_df)
Class distribution in true labels:
True_label
False 537
True 178
Name: count, dtype: int64
TxGemma Model Performance:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | TxGemma 2B predict | 0.494 | 0.184 | 0.803 |
| 1 | TxGemma 9B predict | 0.468 | 0.223 | 0.713 |
| 2 | TxGemma 27B predict | 0.470 | 0.372 | 0.567 |
6.4 Evaluate predictions from TxGemma 2B predict, 9B predict and 27B predict against other insilico methods (303 compounds)
[72]:
# Plot overlap between txgemma_unseeen_smiles, overlap_df and df_res_all
# NOTE(review): reassigns overlap_with_df_res_all again (now TxGemma vs DiliMap)
overlap_with_df_res_all = set(txgemma_dili_predictions['compound_name']).intersection(
    set(df_res_all['compound_name'])
)
print(f'Overlap between TxGemma and DiliMap: {len(overlap_with_df_res_all)} compounds')
# overlap_df holds the compounds unseen by both DILIPredictor and DILIGeNN
overlap_with_overlap_df = set(txgemma_dili_predictions['compound_name']).intersection(
    set(overlap_df['compound_name'])
)
print(
    f'Overlap between TxGemma, DILIPredictor and DILIGeNN unseen smiles: {len(overlap_with_overlap_df)} compounds'
)
# Find overlap between df_res_all and overlap_df
overlap_df_res_all_overlap_df = set(df_res_all['compound_name']).intersection(
    set(overlap_df['compound_name'])
)
print(
    f'Overlap between DILIMap and Unseen smiles overlap of DILIGeNN and DILIPredictor: {len(overlap_df_res_all_overlap_df)} compounds'
)
# Create a visualization of the three-way overlap
plt.figure(figsize=(8, 6))
venn3(
    [
        set(txgemma_dili_predictions['compound_name']),
        set(df_res_all['compound_name']),
        set(overlap_df['compound_name']),
    ],
    set_labels=(
        'TxGemma Unseen',
        'Dilimap',
        'Intersection of DILIGeNN and DILIPredictor Unseen',
    ),
)
plt.title('Overlap between TxGemma Unseen Compounds, df_res_all, and overlap_df')
plt.show()
Overlap between TxGemma and DiliMap: 97 compounds
Overlap between TxGemma, DILIPredictor and DILIGeNN unseen smiles: 303 compounds
Overlap between DILIMap and Unseen smiles overlap of DILIGeNN and DILIPredictor: 7 compounds
[73]:
# Subset overlap_df and txgemma_unseeen_smiles to the 303 compounds in overlap_with_overlap_df
overlap_df_subset = overlap_df[
    overlap_df['compound_name'].isin(overlap_with_overlap_df)
].copy()
txgemma_subset = txgemma_dili_predictions[
    txgemma_dili_predictions['compound_name'].isin(overlap_with_overlap_df)
].copy()
# Create new dataframe starting with overlap_df_subset
benchmark_df = overlap_df_subset.copy()
# Attach the TxGemma predictions by compound name.
# NOTE(review): these columns were converted to 'True'/'False' strings by the
# earlier TxGemma evaluation cell, so this cell depends on having run it first.
txgemma_cols = [
    'TxGemma2B_predicted_DILI_Label',
    'TxGemma9B_predicted_DILI_Label',
    'TxGemma27B_predicted_DILI_Label',
]
for col in txgemma_cols:
    if col in txgemma_subset.columns:
        benchmark_df[col] = benchmark_df['compound_name'].map(
            txgemma_subset.set_index('compound_name')[col]
        )
y_true = benchmark_df['True_label'].astype(str)
# NOTE(review): y_true_bool is never used below — candidate for removal
y_true_bool = (y_true == 'True').astype(bool)
print('\nClass distribution in true labels:')
print(pd.Series(y_true).value_counts())
# Define prediction columns for TxGemma models, DILIPredictor and DILIGeNN
prediction_columns = [
    'DILIPredictor_binary',
    'DILIGeNN_mode_pred_warmstart',
    'TxGemma2B_predicted_DILI_Label',
    'TxGemma9B_predicted_DILI_Label',
    'TxGemma27B_predicted_DILI_Label',
]
# Calculate comprehensive metrics (excluding AUC/AUROC)
performance_df = calculate_comprehensive_metrics(
    benchmark_df,
    'True_label',
    prediction_columns,
)
# Map internal column names to human-readable model names for display
performance_df['Model'] = performance_df['Model'].replace(
    {
        'Dilipredictor Binary': 'DILIPredictor',
        'Diligenn Mode Pred Warmstart': 'DILIGeNN Warm Start (GraphSAGE)',
        'Txgemma2B Predicted Dili Label': 'TxGemma 2B predict',
        'Txgemma9B Predicted Dili Label': 'TxGemma 9B predict',
        'Txgemma27B Predicted Dili Label': 'TxGemma 27B predict',
    }
)
print('\nModel Performance Comparison:')
display(performance_df)
Class distribution in true labels:
True_label
False 261
True 42
Name: count, dtype: int64
Model Performance Comparison:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | DILIPredictor | 0.553 | 0.272 | 0.833 |
| 1 | DILIGeNN Warm Start (GraphSAGE) | 0.576 | 0.272 | 0.881 |
| 2 | TxGemma 2B predict | 0.493 | 0.153 | 0.833 |
| 3 | TxGemma 9B predict | 0.423 | 0.180 | 0.667 |
| 4 | TxGemma 27B predict | 0.410 | 0.272 | 0.548 |
6.5 Benchmark TxGemma 2B predict, 9B predict and 27B predict against ToxPredictor (97 compounds)
[74]:
# Drop the duplicated Chlorpromazine row from the cross-validation batch,
# keeping every other prediction.
is_cv_chlorpromazine = (df_res_all.compound_name == 'Chlorpromazine') & (
    df_res_all.Batch == 'cross-validation'
)
df_res_all_subset = df_res_all[~is_cv_chlorpromazine].copy()
[75]:
# Compounds shared by the TxGemma unseen set and the ToxPredictor results.
overlap_with_df_res_all = set(txgemma_unseeen_smiles['compound_name']).intersection(
    set(df_res_all['compound_name'])
)
# NOTE(review): the filtering below assumes both frames are indexed by
# compound name (consistent with the table rendered underneath) — confirm.
df_res_all_subset = df_res_all_subset[
    df_res_all_subset.index.isin(overlap_with_df_res_all)
].copy()
txgemma_dmap = txgemma_dili_predictions.copy()
txgemma_dmap = txgemma_dmap[txgemma_dmap.index.isin(overlap_with_df_res_all)].copy()
# Bring the ground truth and ToxPredictor outputs over by compound name.
tox_by_name = df_res_all_subset.set_index('compound_name')
for col in ['True_label', 'Predicted_label', 'MOS_ToxPredictor']:
    txgemma_dmap[col] = txgemma_dmap['compound_name'].map(tox_by_name[col])
txgemma_dmap.head()
[75]:
| LTKBID | compound_name | DILIrank | label_section | severity_class | DILIst | roa | source | livertox_score | livertox_primary_classification | ... | TxGemma2B_predicted_DILI_Label_output | TxGemma9B_predicted_DILI_Label_output | TxGemma27B_predicted_DILI_Label_output | TxGemma2B_predicted_DILI_Label | TxGemma9B_predicted_DILI_Label | TxGemma27B_predicted_DILI_Label | True_label | True_label_numeric | Predicted_label | MOS_ToxPredictor | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unnamed: 0 | |||||||||||||||||||||
| Acetaminophen | LT00004 | Acetaminophen | Most-DILI-Concern | Warnings and precautions | 5.0 | NaN | Oral | DILIrank | A [HD] | Analgesic | ... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | True | True | True | True | 1 | True | 15.116015 |
| Alaproclate | LT02213 | Alaproclate | Most-DILI-Concern | Withdrawn | 8.0 | NaN | Oral | DILIrank | NaN | NaN | ... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | False | False | False | True | 1 | True | 1.593458 |
| Almotriptan | LT01053 | Almotriptan | No-DILI-Concern | No match | 0.0 | NaN | Oral | DILIrank | E | CNS | ... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | False | True | True | False | 0 | False | 300.000000 |
| Amiodarone | LT00046 | Amiodarone | Most-DILI-Concern | Box warning | 8.0 | NaN | Oral | DILIrank | A | Cardiovascular | ... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | True | True | False | True | 1 | True | 6.453200 |
| Amoxicillin | LT00507 | Amoxicillin | Less-DILI-Concern | Adverse reactions | 5.0 | NaN | Oral | DILIrank | B | Antimicrobial | ... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | Prompt:\nInstructions: Answer the following qu... | True | True | True | True | 1 | True | 71.093220 |
5 rows × 71 columns
[76]:
# Benchmark ToxPredictor against the three TxGemma variants on the 97 shared
# compounds. Report the class balance of the ground-truth labels first.
# (Removed an unused `y_true_bool` variable that was computed but never read.)
y_true = txgemma_dmap['True_label'].astype(str)
print('\nClass distribution in true labels:')
print(pd.Series(y_true).value_counts())
# Columns holding binary predictions: ToxPredictor plus the TxGemma sizes.
prediction_columns = [
    'Predicted_label',
    'TxGemma2B_predicted_DILI_Label',
    'TxGemma9B_predicted_DILI_Label',
    'TxGemma27B_predicted_DILI_Label',
]
# Balanced accuracy / specificity / sensitivity per model.
performance_df = calculate_comprehensive_metrics(
    txgemma_dmap,
    'True_label',
    prediction_columns,
)
# Replace the column-derived model names with human-readable labels.
performance_df['Model'] = performance_df['Model'].replace(
    {
        'Predicted Label': 'ToxPredictor',
        'Txgemma2B Predicted Dili Label': 'TxGemma 2B predict',
        'Txgemma9B Predicted Dili Label': 'TxGemma 9B predict',
        'Txgemma27B Predicted Dili Label': 'TxGemma 27B predict',
    }
)
print('\nTxGemma Model Performance:')
display(performance_df)
Class distribution in true labels:
True_label
True 62
False 35
Name: count, dtype: int64
TxGemma Model Performance:
| Model | Balanced Accuracy | Specificity | Sensitivity | |
|---|---|---|---|---|
| 0 | ToxPredictor | 0.808 | 0.857 | 0.758 |
| 1 | TxGemma 2B predict | 0.512 | 0.314 | 0.710 |
| 2 | TxGemma 9B predict | 0.579 | 0.400 | 0.758 |
| 3 | TxGemma 27B predict | 0.600 | 0.571 | 0.629 |
[77]:
# Overall prediction counts for TxGemma 27B across all 97 compounds.
txgemma_dmap['TxGemma27B_predicted_DILI_Label'].value_counts()
[77]:
TxGemma27B_predicted_DILI_Label
True 54
False 43
Name: count, dtype: int64
[78]:
# TxGemma 27B predictions on the DILI-positive compounds (TP vs FN counts).
txgemma_dmap.loc[
    txgemma_dmap.True_label == 'True', 'TxGemma27B_predicted_DILI_Label'
].value_counts()
[78]:
TxGemma27B_predicted_DILI_Label
True 39
False 23
Name: count, dtype: int64
[79]:
# ToxPredictor predictions on the DILI-positive compounds (TP vs FN counts).
txgemma_dmap.loc[txgemma_dmap.True_label == 'True', 'Predicted_label'].value_counts()
[79]:
Predicted_label
True 47
False 15
Name: count, dtype: int64
[80]:
# TxGemma 27B predictions on the DILI-negative compounds (TN vs FP counts).
txgemma_dmap.loc[
    txgemma_dmap.True_label == 'False', 'TxGemma27B_predicted_DILI_Label'
].value_counts()
[80]:
TxGemma27B_predicted_DILI_Label
False 20
True 15
Name: count, dtype: int64
[81]:
# ToxPredictor predictions on the DILI-negative compounds (TN vs FP counts).
txgemma_dmap.loc[txgemma_dmap.True_label == 'False', 'Predicted_label'].value_counts()
[81]:
Predicted_label
False 30
True 5
Name: count, dtype: int64
[ ]: