📥 Download all notebooks

2.3. Training: ToxPredictor model

📘 Overview

This notebook trains multiple ML models on the DILImap training data, performs hyperparameter optimization, and selects a final ensemble model (random forests) to be stored as ToxPredictor for downstream prediction tasks. Note that we trained our models with scikit-learn 1.4.0 to ensure reproducibility across environments and broad compatibility for downstream users, while remaining fully compatible with newer scikit-learn versions.

Inputs

- Training data (e.g., pathway-level signatures with binarized DILI labels)
- Model configuration and search space definitions

Outputs

- Trained base models with performance metrics
- Optimal hyperparameters per model type
- Final ensemble model (random forests), saved to S3
- Suppl. Figures S2-S5 with model training details
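Because pickled scikit-learn models are sensitive to version skew, downstream users may want to guard model loading with a version check. A minimal sketch; the `version_tuple` helper is illustrative and not part of dilimap:

```python
def version_tuple(v):
    """Parse a dotted version string like '1.4.0' into a comparable int tuple,
    dropping non-numeric suffixes such as 'rc1' or 'dev0'."""
    parts = []
    for piece in v.split('.'):
        digits = ''
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break  # stop at the first non-digit (e.g. 'rc', 'dev')
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


TRAINED_ON = '1.4.0'  # scikit-learn version used to train ToxPredictor

# e.g. guard before unpickling a stored model:
#   assert version_tuple(sklearn.__version__) >= version_tuple(TRAINED_ON)
assert version_tuple('1.5.2') >= version_tuple(TRAINED_ON)
```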
[1]:
%%capture

!pip install scikit-learn==1.4.0
!pip install xgboost
!pip install lightgbm
!pip install matplotlib_venn
[2]:
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from concurrent.futures import ThreadPoolExecutor, as_completed

import dilimap as dmap
from dilimap.utils import groupby, crosstab
from dilimap.models import ToxPredictor
[3]:
%load_ext autoreload
%autoreload 2
[4]:
dmap.logging.print_version()
Running dilimap 1.0.4.dev0+gcb566f869.d20251007 (python 3.11.8) on 2025-10-07 18:01.

1a. Pull training data and map DILI labels

[5]:
adata = dmap.s3.read('training_data_pathways.h5ad')
Package: s3://dilimap/public/data. Top hash: 155a2b3b63
[6]:
dmap.utils.map_dili_labels_and_cmax(adata)
Package: s3://dilimap/public/data. Top hash: 155a2b3b63
Package: s3://dilimap/public/data. Top hash: 155a2b3b63
Package: s3://dilimap/public/data. Top hash: 155a2b3b63
[7]:
df_groupby = groupby(adata, 'compound_name')

cols = [
    'Most-DILI-Concern',
    'Less-DILI-Concern',
    'Ambiguous DILI-concern',
    'No-DILI-Concern',
]
crosstab(df_groupby, ['livertox_score', 'DILIrank'], np.sum).replace(np.nan, '')[cols]
[7]:
DILIrank         Most-DILI-Concern  Less-DILI-Concern  Ambiguous DILI-concern  No-DILI-Concern
livertox_score
A                               27                  9                       0                0
A [HD]                           1                  2                       0                0
B                               13                 28                       0                1
C                               11                 34                       0                1
D                                3                 27                       0                3
E                                0                 10                       1               31
E*                               0                  7                       0                7
[8]:
print(f'Total no. of training compounds: {len(df_groupby)}')
df_groupby['DILI_label'].value_counts(sort=False)
Total no. of training compounds: 249
[8]:
DILI_label
DILI (likely)         73
DILI (few cases)      69
DILI (known)          25
DILI (withdrawn)      13
No DILI               52
No DILI (unlikely)    17
Name: count, dtype: int64

1b. Define dose-specific DILI training labels

[9]:
## Dose/cmpd-specific DILI annotations

dose_uM, cmax_uM = adata.obs['dose_uM'], adata.obs['Cmax_uM']
is_neg = adata.obs['DILI_label'].isin(['No DILI']) & (
    (dose_uM < 80 * cmax_uM) | (adata.obs['dose_level'] == 'Low')
)
is_pos = adata.obs['DILI_label'].isin(
    ['DILI (withdrawn)', 'DILI (known)', 'DILI (likely)']
) & ((dose_uM > 20 * cmax_uM) | (adata.obs['dose_level'] == 'High'))

adata.obs['DILI'] = np.where(is_neg, False, np.where(is_pos, True, np.nan))
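The same rule can be expressed as a small standalone function, which may help when sanity-checking individual compounds. This is a toy re-implementation mirroring the vectorized logic above (negatives take precedence, as in the nested `np.where`), not part of dilimap:

```python
def dose_specific_dili(dili_label, dose_uM, cmax_uM, dose_level):
    """Toy version of the dose-specific labeling rule.

    Negatives: 'No DILI' compounds dosed below 80x Cmax (or at the Low level).
    Positives: withdrawn/known/likely DILI compounds above 20x Cmax (or High).
    Everything else stays unlabeled (None), mirroring NaN in adata.obs['DILI'].
    """
    is_neg = dili_label == 'No DILI' and (
        dose_uM < 80 * cmax_uM or dose_level == 'Low'
    )
    is_pos = dili_label in {'DILI (withdrawn)', 'DILI (known)', 'DILI (likely)'} and (
        dose_uM > 20 * cmax_uM or dose_level == 'High'
    )
    if is_neg:  # negatives take precedence, as in the nested np.where above
        return False
    if is_pos:
        return True
    return None


assert dose_specific_dili('No DILI', 5.0, 1.0, 'Middle') is False
assert dose_specific_dili('DILI (known)', 50.0, 1.0, 'Middle') is True
assert dose_specific_dili('DILI (likely)', 1.0, 1.0, 'Middle') is None
```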
[10]:
# Visualizing labeling strategy

df_crosstab = crosstab(adata, ['compound_name', 'dose_level', 'DILI'])
df_crosstab['DILI_label'] = df_crosstab.index.map(
    adata.obs.groupby('compound_name')['DILI_label'].first()
)
df_crosstab[:10].style.background_gradient(cmap='RdBu_r', vmin=-2, vmax=2.5).format(
    precision=2
).highlight_null('white')
[10]:
dose_level Low Middle Mid-High High DILI_label
compound_name          
Abacavir nan nan 1.00 1.00 DILI (likely)
Acarbose 1.00 1.00 1.00 1.00 DILI (likely)
Acetaminophen nan nan nan 1.00 DILI (known)
Alaproclate nan nan nan 1.00 DILI (withdrawn)
Albendazole nan nan 1.00 1.00 DILI (likely)
Allopurinol nan nan nan 1.00 DILI (known)
Alvimopan 0.00 nan nan nan No DILI
Ambrisentan 0.00 0.00 nan nan No DILI
Amiodarone nan nan nan 1.00 DILI (known)
Amoxicillin nan nan nan 1.00 DILI (likely)
[11]:
adata.obs['DILI'].value_counts()
[11]:
DILI
1.0    232
0.0     91
Name: count, dtype: int64

1c. Compound filtering

[12]:
df_cmpd_DEGs = adata.obs.groupby('compound_name', observed=True)['n_DEG'].max()
# exclude inactive cmpds with fewer than 10 DEGs at every dose
adata.obs['valid_cmpd'] = ~adata.obs_names.isin(df_cmpd_DEGs.index[df_cmpd_DEGs < 10])

# exclude cmpds with insufficient dose range (top dose < Cmax)
df_cmpd_dose = adata.obs.groupby('compound_name', observed=True)[
    ['dose_uM', 'Cmax_uM']
].max()
adata.obs['valid_cmpd'] &= ~adata.obs_names.isin(
    df_cmpd_dose.index[df_cmpd_dose['dose_uM'] < df_cmpd_dose['Cmax_uM']]
)

# include cmpds with strong signal even if dose range is insufficient
adata.obs['valid_cmpd'] |= adata.obs_names.isin(df_cmpd_DEGs.index[df_cmpd_DEGs > 2000])
adata.obs['valid_cmpd'] |= ~adata.obs['LDH_IC10_uM'].isna()

print(
    'Compounds left post-filtering:',
    len(set(adata.obs['compound_name'][adata.obs['valid_cmpd']])),
    '/',
    len(set(adata.obs['compound_name'])),
)
Compounds left post-filtering: 246 / 249

1d. Feature filtering & data clipping

[13]:
# Remove insignificant pathways or pseudogenes

if True:  ## pathway signature
    adata.var['FDR_min'] = np.nanmin(adata.layers['FDR'], 0)
    print(
        f'Retaining {(adata.var["FDR_min"] < 1e-3).sum()}/{adata.n_vars} significant pathways'
    )
    adata = adata[:, adata.var['FDR_min'] < 1e-3]

else:  ## gene signatures
    adata.var['pseudogene'] = adata.var_names.str.contains(r'\.')
    print(f'Filtering out {adata.var["pseudogene"].sum()} pseudogenes')
    adata = adata[:, ~adata.var['pseudogene']].copy()
Retaining 186/472 significant pathways
[14]:
if False:  # consider clipping data when observing outliers in gene signature
    adata.X = np.nan_to_num(adata.X)  # set nans to zero

    lb, ub = np.percentile(adata.X, [2, 98], axis=0)
    adata.X = np.clip(adata.X, lb, ub)  # clip to [2,98] percentiles

    adata.X = adata.X * (adata.X > 1)  # clip to DES > 1 (pval < 0.1)
    adata.X = np.clip(adata.X, 1, 10)  # clip to DES < 10 (pval > 1e-10)
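For gene-level signatures, the optional clipping above first winsorizes each feature to its [2, 98] percentile range so single extreme values cannot dominate. A minimal numpy sketch of that percentile step on toy data:

```python
import numpy as np

# Winsorize each column to its [2, 98] percentile range (toy data).
X = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [100.0, -50.0]])
lb, ub = np.percentile(X, [2, 98], axis=0)
X_clipped = np.clip(X, lb, ub)

assert X_clipped[3, 0] < 100.0  # outlier pulled down to the 98th percentile
assert X_clipped[3, 1] > -50.0  # outlier pulled up to the 2nd percentile
```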

2a. Model selection and hyperparameter optimization

[15]:
random_state = seed = 42
[16]:
classifiers = {
    # Logistic Regression with L2 regularization
    'LogisticRegression': [
        dict(solver='lbfgs', penalty='l2', max_iter=1000, random_state=random_state),
        dict(C=[0.0001, 0.001, 0.01, 0.1, 1, 10, 100], class_weight=['balanced', None]),
    ],
    # Support Vector Classifier with RBF kernel
    'SVC': [
        dict(kernel='rbf', gamma='scale', probability=True, random_state=random_state),
        dict(C=[0.1, 1, 10]),
    ],
    # Random Forest with max_depth and regularization parameters
    'RandomForestClassifier': [
        dict(n_estimators=100, random_state=random_state),
        dict(
            max_depth=[1, 2, 3, 4, 5, None],
            min_samples_split=[2, 3],
            min_samples_leaf=[1, 2],
            class_weight=['balanced', None],
        ),
    ],
    # Gradient Boosting Trees with learning rate and depth tuning
    'GradientBoostingClassifier': [
        dict(n_estimators=100, random_state=random_state),
        dict(
            learning_rate=[0.01, 0.1],
            max_depth=[1, 2, 3, 4, 5],
            min_samples_split=[2, 3],
            min_samples_leaf=[1, 2],
        ),
    ],
    # Histogram-based Gradient Boosting (faster, scalable)
    'HistGradientBoostingClassifier': [
        dict(random_state=random_state),
        dict(
            learning_rate=[0.01, 0.1],
            max_depth=[1, 2, 3, 4, 5],
            l2_regularization=[0.0, 0.1, 1.0],
        ),
    ],
    # Multi-layer Perceptron with deeper architectures and tuning
    'MLPClassifier': [
        dict(max_iter=1000, early_stopping=True, random_state=random_state),
        dict(
            learning_rate_init=[0.001, 0.0005],
            alpha=[1e-4, 1e-3, 1e-2, 1e-1],
            hidden_layer_sizes=[(30,), (50,), (100,), (100, 50), (100, 100)],
        ),
    ],
    # XGBoost (gradient boosting, often among top performers)
    'XGBClassifier': [
        dict(eval_metric='logloss', random_state=random_state),
        dict(learning_rate=[0.01, 0.1], max_depth=[1, 2, 3, 4, 5]),
    ],
    # LightGBM (fast gradient boosting with histogram-based learning)
    'LGBMClassifier': [
        dict(verbosity=-1, random_state=random_state),
        dict(learning_rate=[0.01, 0.1], max_depth=[1, 2, 3, 4, 5]),
    ],
}
[17]:
# Data Preparation
adata_clean = adata.copy()
adata_clean.var_names = adata_clean.var_names.str.replace(r'[^\w\-]', '_', regex=True)

# Split labeled and unlabeled data
mask_labeled = (~adata_clean.obs['DILI'].isna()) & adata.obs['valid_cmpd']
adata_labeled = adata_clean[mask_labeled].copy()
adata_unlabeled = adata_clean[~mask_labeled].copy()

# Set meaningful obs_names for grouping
adata_labeled.obs_names = (
    adata_labeled.obs['compound_name'].astype(str)
    + '_'
    + adata_labeled.obs['dose_level'].astype(str)
).values

# Prepare inputs for model training
X_train = adata_labeled.to_df()
y_train = adata_labeled.obs['DILI']
group_labels = adata_labeled.obs['compound_name']
compound_dose_index = adata_labeled.obs_names

# For inter-fold correlation test: only high-dose samples
X_high_dose = adata_clean[
    (~adata_clean.obs['DILI'].isna()) & (adata_clean.obs['dose_level'] == 'High')
].to_df()
/opt/miniconda3/envs/py311/lib/python3.11/site-packages/anndata/_core/anndata.py:1758: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
/opt/miniconda3/envs/py311/lib/python3.11/site-packages/anndata/_core/anndata.py:1758: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
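`ToxPredictor.cross_validate` receives `groups=group_labels` so that all doses and replicates of a compound land in the same fold. Whether it uses `GroupKFold` internally is an implementation detail, but the leakage guarantee this grouping should provide can be illustrated with plain scikit-learn on toy data (not the DILImap training set):

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Six toy samples from three compounds, two samples each.
X = np.arange(12).reshape(6, 2)
y = np.array([0, 0, 1, 1, 0, 1])
groups = np.array(['cmpdA', 'cmpdA', 'cmpdB', 'cmpdB', 'cmpdC', 'cmpdC'])

for train_idx, val_idx in GroupKFold(n_splits=3).split(X, y, groups):
    # A compound never appears on both sides of a split.
    assert set(groups[train_idx]).isdisjoint(groups[val_idx])
```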
[18]:
# Model Evaluation Function

import random
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    accuracy_score,
    balanced_accuracy_score,
    average_precision_score,
)


def evaluate_model(
    tune_id,
    model_name,
    model_params,
    X_train,
    y_train,
    group_labels,
    compound_dose_index,
    X_high_dose,
    adata_unlabeled,
    n_splits=5,
):
    np.random.seed(seed)
    random.seed(seed)

    try:
        model_params = model_params.copy()
        model = ToxPredictor(model=model_name, model_params=model_params)

        # Train with cross-validation
        model.cross_validate(X_train, y_train, n_splits=n_splits, groups=group_labels)

        # Predict with CV and ensemble
        df_preds = pd.DataFrame({'DILI': y_train})
        df_preds['cv_pred'] = model.predict_proba_cv()[1].astype(float)
        df_preds['ens_pred'] = model.predict_proba(X_train)[1].astype(float)

        # Aggregate by compound-dose
        df_grouped = df_preds.groupby(compound_dose_index).mean()
        y_true = df_grouped['DILI']
        y_score = df_grouped['cv_pred']
        y_pred_bin = (y_score > 0.5).astype(int)

        # Inter-fold robustness
        y_preds_folds = np.stack(
            [m.predict_proba(X_high_dose)[:, -1] for m in model.estimators]
        )
        is_extreme = (
            y_preds_folds < np.percentile(y_preds_folds, 10, axis=1)[:, None]
        ) | (y_preds_folds > np.percentile(y_preds_folds, 90, axis=1)[:, None])
        interfold_robustness = np.corrcoef(
            y_preds_folds[:, np.any(is_extreme, axis=0)]
        ).min()

        # Dose monotonicity
        adata_unlabeled.obs['cv_pred'] = model.predict_proba(
            adata_unlabeled.to_df()
        ).iloc[:, -1]
        df_dose = crosstab(adata_unlabeled, ['compound_name', 'dose_level', 'cv_pred'])[
            ['Mid-High', 'High']
        ].dropna()
        dose_monotonicity = np.mean(df_dose['High'] > df_dose['Mid-High'])

        # Standard metrics
        metrics = {
            'AUC (validation)': roc_auc_score(y_true, y_score),
            'AUC (training)': roc_auc_score(y_true, df_grouped['ens_pred']),
            'Precision': precision_score(y_true, y_pred_bin),
            'Recall': recall_score(y_true, y_pred_bin),
            'F1 score': f1_score(y_true, y_pred_bin),
            'Accuracy': accuracy_score(y_true, y_pred_bin),
            'Balanced Accuracy': balanced_accuracy_score(y_true, y_pred_bin),
            'Average Precision': average_precision_score(y_true, y_score),
            'Inter-fold correlation': interfold_robustness,
            'Monotonic dose response': dose_monotonicity,
            'Classifier': model_name,
        }

        return tune_id, df_grouped, metrics

    except Exception as e:
        print(f'[{tune_id}] Evaluation failed: {e}')
        return tune_id, None, None
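The inter-fold robustness score in `evaluate_model` stacks each fold model's high-dose probabilities, keeps only samples that are extreme (outside the per-fold 10th/90th percentiles) in at least one fold, and takes the minimum pairwise correlation between folds. A toy numpy illustration with fabricated probabilities, for intuition only:

```python
import numpy as np

# Three "fold models" that largely agree on five toy samples.
y_preds_folds = np.array([
    [0.10, 0.20, 0.80, 0.90, 0.50],
    [0.15, 0.25, 0.75, 0.95, 0.45],
    [0.05, 0.30, 0.85, 0.90, 0.55],
])
lo = np.percentile(y_preds_folds, 10, axis=1)[:, None]
hi = np.percentile(y_preds_folds, 90, axis=1)[:, None]
is_extreme = (y_preds_folds < lo) | (y_preds_folds > hi)

# Minimum pairwise correlation between folds, restricted to extreme samples.
robustness = np.corrcoef(y_preds_folds[:, np.any(is_extreme, axis=0)]).min()
assert robustness > 0.5  # folds that agree yield a score near 1
```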
[19]:
# Hyperparameter Sweep
from itertools import product
from copy import deepcopy

np.random.seed(seed)
random.seed(seed)

# Generate hyperparameter tuning tasks
tasks = []
for model_name, (base_params, tunables) in classifiers.items():
    for combo in product(*tunables.values()):
        params = deepcopy(base_params)
        for k, v in zip(tunables.keys(), combo):
            params[k] = v
        tune_id = f'{model_name}__' + '__'.join(
            f'{k}={str(v).replace(" ", "")}' for k, v in zip(tunables.keys(), combo)
        )
        tasks.append((tune_id, model_name, params))

# Run evaluations in parallel
results = []
predictions = {}

with ThreadPoolExecutor(max_workers=8) as executor:
    futures = {
        executor.submit(
            evaluate_model,
            tune_id,
            model_name,
            model_params,
            X_train,
            y_train,
            group_labels,
            compound_dose_index,
            X_high_dose,
            adata_unlabeled,
        ): tune_id
        for (tune_id, model_name, model_params) in tasks
    }

    for future in as_completed(futures):
        tune_id = futures[future]
        try:
            tune_id, df_preds, metrics = future.result()
            if metrics:
                metrics['tune_id'] = tune_id
                predictions[tune_id] = df_preds
                results.append(metrics)
        except Exception as e:
            print(f'[{tune_id}] Future failed: {e}')

# Results summary
df_scores = pd.DataFrame(results).set_index('tune_id').reset_index()
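The sweep above expands each classifier's tunables dict into a full Cartesian grid via `itertools.product`. In isolation, the expansion logic looks like this (toy search space):

```python
from itertools import product

# A toy search space: 2 learning rates x 2 depths -> 4 configurations.
tunables = {'learning_rate': [0.01, 0.1], 'max_depth': [1, 2]}

configs = [
    dict(zip(tunables.keys(), combo)) for combo in product(*tunables.values())
]

assert len(configs) == 4
assert configs[0] == {'learning_rate': 0.01, 'max_depth': 1}
assert configs[-1] == {'learning_rate': 0.1, 'max_depth': 2}
```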
[20]:
df_scores.to_csv('model_selection_results.csv')
[21]:
print(len(df_scores))
195
[22]:
# Filter failed runs
df_scores = df_scores[(df_scores['Precision'] > 0.1) & (df_scores['Recall'] > 0.1)]

# Filter models with strong performance and cross-validation stability
selected_models = df_scores[
    (df_scores['AUC (validation)'] > 0.7)
    & (df_scores['Precision'] > 0.8)
    & (df_scores['Recall'] > 0.6)
    & (df_scores['Monotonic dose response'] > 0.6)
    & (df_scores['Inter-fold correlation'] > 0.8)
]

# Select top 5 models per classifier based on validation AUC
top_models = (
    selected_models.sort_values('AUC (validation)', ascending=False)
    .groupby('Classifier')
    .head(5)
)
[23]:
import pandas as pd

# Select and rename columns for compact display
df_display = top_models[
    [
        'tune_id',
        'AUC (validation)',
        'AUC (training)',
        'Precision',
        'Recall',
        'F1 score',
        'Accuracy',
        'Balanced Accuracy',
        'Average Precision',
        'Inter-fold correlation',
        'Monotonic dose response',
    ]
].copy()

# Abbreviate column names
df_display.columns = [
    'tune_id',
    'AUC (val)',
    'AUC (train)',
    'Precision',
    'Recall',
    'F1',
    'Accuracy',
    'BA',
    'AP',
    'IC',
    'MD',
]

# Abbreviate model names
model_map = {
    'RandomForestClassifier': 'RF',
    'HistGradientBoostingClassifier': 'HistGB',
    'XGBClassifier': 'XGB',
    'LGBMClassifier': 'LGBM',
    'GradientBoostingClassifier': 'GB',
    'LogisticRegression': 'LogReg',
    'SVC': 'SVC',
    'MLPClassifier': 'MLP',
}
df_display['Model'] = (
    df_display['tune_id'].str.extract(r'([^\_]+Classifier)').iloc[:, 0].map(model_map)
)
df_display['Details'] = df_display['tune_id'].str.replace(
    r'^.*?Classifier__', '', regex=True
)

# Reorder columns for presentation
df_display = df_display[
    [
        'Model',
        'AUC (val)',
        'AUC (train)',
        'Precision',
        'Recall',
        'F1',
        'Accuracy',
        'BA',
        'AP',
        'IC',
    ]
]  #'Details',

# Plot table
fig, ax = plt.subplots(figsize=(11, 0.5 + 0.4 * len(df_display)))
ax.axis('off')

table = ax.table(
    cellText=df_display.round(2).values,
    colLabels=df_display.columns,
    cellLoc='center',
    loc='center',
)

table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.1, 1.4)  # slightly wider column spacing

plt.title('Model performance summary', fontsize=12, pad=10)
plt.tight_layout()
plt.show()
../_images/reproducibility_2.3_Training_ToxPredictor_Model_28_0.png
[24]:
id_selected_model = top_models.query(
    "tune_id == 'RandomForestClassifier__max_depth=2__min_samples_split=2__min_samples_leaf=1__class_weight=None'"
).index[0]

print(id_selected_model, top_models.loc[id_selected_model]['tune_id'].split('__'))
26 ['RandomForestClassifier', 'max_depth=2', 'min_samples_split=2', 'min_samples_leaf=1', 'class_weight=None']
[25]:
def plot_metrics(df, x, y, hue='Classifier', highlight_id=None):
    # Scatter two performance metrics against each other across tuned models
    plt.figure(figsize=(4, 3))
    sns.scatterplot(data=df, x=x, y=y, hue=hue, s=50)

    if highlight_id in df.index:
        row = df.loc[highlight_id]
        plt.scatter(
            row[x], row[y], marker='*', s=100, color='green', label='Selected method'
        )

    plt.xlabel(x)
    plt.ylabel(y)
    plt.grid(True)

    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.show()
[26]:
highlight_id = id_selected_model

plot_metrics(df_scores, 'Recall', 'Precision', highlight_id=highlight_id)
plot_metrics(
    df_scores, 'Inter-fold correlation', 'AUC (validation)', highlight_id=highlight_id
)
../_images/reproducibility_2.3_Training_ToxPredictor_Model_31_0.png
../_images/reproducibility_2.3_Training_ToxPredictor_Model_31_1.png

2b. Selection of number of folds for ensembling

[27]:
np.random.seed(seed)
random.seed(seed)

n_splits_list = [5, 10, 20, 30, 40, 50]
results = []
predictions = []

# Only evaluate RandomForest
model_name = 'RandomForestClassifier'
model_params = dict(
    random_state=random_state,
    n_estimators=100,
    max_depth=2,
    min_samples_split=2,
    min_samples_leaf=1,
)

# Create evaluation tasks
tasks = [
    (f'{model_name}__n_splits={n}', model_name, model_params, n) for n in n_splits_list
]

# Run evaluations in parallel
with ThreadPoolExecutor(max_workers=5) as executor:
    futures = {
        executor.submit(
            evaluate_model,
            tune_id,
            model_name,
            model_params,
            X_train,
            y_train,
            group_labels,
            compound_dose_index,
            X_high_dose,
            adata_unlabeled,
            n_splits,
        ): tune_id
        for (tune_id, model_name, model_params, n_splits) in tasks
    }

    for future in as_completed(futures):
        tune_id = futures[future]
        try:
            tune_id, df_preds, metrics = future.result()
            if metrics:
                metrics['tune_id'] = tune_id
                predictions.append((tune_id, df_preds))
                results.append(metrics)
        except Exception as e:
            print(f'[{tune_id}] Future failed: {e}')

# Convert results to DataFrame
df_scores_splits = pd.DataFrame(results).set_index('tune_id').reset_index()
[28]:
std_dict = {}
for n_splits in n_splits_list:
    model = ToxPredictor(model=model_name, model_params=model_params)
    model.cross_validate(X_train, y_train, n_splits=n_splits, groups=group_labels)
    std_dict[str(n_splits)] = (
        model.predict_proba_across_estimators(adata_unlabeled).std(1).values
    )
[29]:
import pandas as pd
import matplotlib.pyplot as plt

# Setup
df_melted = pd.DataFrame(std_dict).melt(var_name='x', value_name='y')
df_melted['x'] = df_melted['x'].astype(int)
x_vals = sorted(df_melted['x'].unique())
data_per_x = [df_melted[df_melted['x'] == x]['y'].values for x in x_vals]
df_scores_splits['n_splits'] = n_splits_list

# Create plot
fig, ax1 = plt.subplots(figsize=(5, 3))

# Matplotlib boxplot with correct numeric alignment
ax1.boxplot(
    data_per_x,
    positions=x_vals,
    widths=3,
    patch_artist=True,
    boxprops=dict(facecolor='lightgray', color='gray'),
    medianprops=dict(color='black'),
    whiskerprops=dict(color='gray'),
    capprops=dict(color='gray'),
    flierprops=dict(marker='o', markersize=2, linestyle='none', markerfacecolor='gray'),
)

ax1.set_ylabel('Standard deviation', color='black')
ax1.set_xlabel('Number of splits (n_splits)')
ax1.tick_params(axis='y', labelcolor='black')

# AUC line on secondary axis
ax2 = ax1.twinx()
ax2.plot(
    df_scores_splits['n_splits'],
    df_scores_splits['AUC (validation)'],
    marker='o',
    color='blue',
    label='AUC (validation)',
)
ax2.set_ylabel('AUC (validation)', color='blue')
ax2.tick_params(axis='y', labelcolor='blue')
ax2.set_ylim(0.5, 0.8)

ax1.set_title('Standard Deviation vs. AUC across n_splits')
fig.tight_layout()
plt.show()
../_images/reproducibility_2.3_Training_ToxPredictor_Model_35_0.png

Conclusion: We chose the 30-fold ensemble model for its balance of performance and stability. While overall performance is similar across fold counts, 30 folds yield greater consistency and smoother predictions, making this setting best suited for ensembling the fold models.

2c. Selection of optimal MOS threshold

[30]:
from dilimap.models import ToxPredictor
[31]:
idx = (~adata.obs['DILI'].isna()) & adata.obs['valid_cmpd']
data = adata[idx].copy()

cmpd_dose = (
    data.obs['compound_name'].astype(str) + '_' + data.obs['dose_level'].astype(str)
)
X = data.to_df().set_index(cmpd_dose)
y, groups = data.obs['DILI'], data.obs['compound_name']
/opt/miniconda3/envs/py311/lib/python3.11/site-packages/anndata/_core/anndata.py:1758: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
[32]:
model_params = dict(
    random_state=random_state,
    n_estimators=100,
    max_depth=2,
    min_samples_split=2,
    min_samples_leaf=1,
)

model = ToxPredictor(model='rf', model_params=model_params)
[33]:
model.cross_validate(X, y, n_splits=30, groups=data.obs['compound_name'])
[34]:
data.obs['DILI_prob_cv'] = model.predict_cv().values
data.obs['DILI_prob_ens'] = model.predict(X).values
[35]:
df_res_probs = (
    data.obs[['DILI', 'DILI_prob_cv', 'DILI_prob_ens']]
    .astype(float)
    .groupby(cmpd_dose)
    .mean()
)
[36]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

thresholds = np.linspace(0.2, 0.9, 15)
sens, spec, aucs = [], [], []

for t in thresholds:
    pred = (df_res_probs['DILI_prob_cv'] > t).astype(int)
    tn, fp, fn, tp = confusion_matrix(df_res_probs['DILI'], pred).ravel()

    sens.append(tp / (tp + fn) if (tp + fn) else np.nan)
    spec.append(tn / (tn + fp) if (tn + fp) else np.nan)
    aucs.append(roc_auc_score(df_res_probs['DILI'], pred))

plt.figure(figsize=(5, 3.5))
plt.plot(thresholds, aucs, '--', label='AUC', color='grey', lw=2)
plt.plot(thresholds, sens, '--', label='Sensitivity', color='red', lw=2)
plt.plot(thresholds, spec, '--', label='Specificity', color='blue', lw=2)
plt.axvline(0.7, color='k', alpha=0.5)
plt.xlabel('DILI probability threshold')
plt.ylabel('Metric value')
plt.title('DILI threshold selection')
plt.legend()
plt.grid(True)
plt.show()
../_images/reproducibility_2.3_Training_ToxPredictor_Model_44_0.png
[37]:
df_res_margins = model.compute_safety_margin(data, y_thresh=0.7)
df_res_margins['DILI_label'] = df_res_margins.index.map(
    groupby(adata, 'compound_name')['DILI_label']
)
[38]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_auc_score

# Binary DILI label: 1 = toxic, 0 = non-toxic
dili_labels = ['DILI (withdrawn)', 'DILI (known)', 'DILI (likely)', 'No DILI']
df = df_res_margins[df_res_margins['DILI_label'].isin(dili_labels)].copy()
df['DILI'] = df['DILI_label'].str.startswith('DILI').astype(int)

thresholds = np.linspace(10, 120, 56)
sens, spec, aucs = [], [], []

for t in thresholds:
    pred = (df['MOS_ToxPredictor'] < t).astype(int)
    tn, fp, fn, tp = confusion_matrix(df['DILI'], pred).ravel()

    sens.append(tp / (tp + fn) if (tp + fn) else np.nan)
    spec.append(tn / (tn + fp) if (tn + fp) else np.nan)
    try:
        aucs.append(roc_auc_score(df['DILI'], pred))
    except ValueError:
        aucs.append(np.nan)  # In case only one class is predicted

# Plot
plt.figure(figsize=(5, 3.5))
plt.plot(thresholds, aucs, '--', label='AUC', color='grey', lw=2)
plt.plot(thresholds, sens, '--', label='Sensitivity', color='red', lw=2)
plt.plot(thresholds, spec, '--', label='Specificity', color='blue', lw=2)
plt.axvline(80, color='k', alpha=0.5)
plt.xlabel('MOS threshold')
plt.ylabel('Metric value')
plt.title('MOS threshold selection')
plt.legend()
plt.grid(True)
plt.show()
../_images/reproducibility_2.3_Training_ToxPredictor_Model_46_0.png

2d. Selection of total Cmax over free Cmax

[39]:
adata_sub = adata[~adata.obs['free_Cmax_uM'].isna()].copy()

mult = adata_sub.obs['Cmax_uM'].mean() / adata_sub.obs['free_Cmax_uM'].mean()
vmax = 300

df_res_margins = model.compute_safety_margin(adata_sub, vmax=vmax)
df_res_margins['DILI_label'] = df_res_margins.index.map(
    groupby(adata_sub, 'compound_name')['DILI_label']
)
df_res_margins['n_DEG_mean'] = df_res_margins.index.map(
    groupby(adata_sub, 'compound_name')['n_DEG_mean']
)

df_res_margins['DILI'] = ~df_res_margins['DILI_label'].str.startswith('No DILI')
df_res_margins['MOS_total_Cmax'] = df_res_margins['MOS_ToxPredictor']

df_res_free = model.compute_safety_margin(adata_sub, cmax_col='free_Cmax_uM', vmax=vmax * mult)
df_res_margins['MOS_free_Cmax'] = df_res_free['MOS_ToxPredictor']
/opt/miniconda3/envs/py311/lib/python3.11/site-packages/anndata/_core/anndata.py:1758: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
[40]:
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import numpy as np

# Drop NaNs and non-positive values (log scale can't handle zero or negative)
df_filtered = adata.obs.dropna(subset=['Cmax_uM', 'free_Cmax_uM']).copy()
df_filtered = df_filtered[df_filtered['free_Cmax_uM'] > 0]

# Compute Pearson correlation on log-transformed values
log_x = np.log10(df_filtered['Cmax_uM'])
log_y = np.log10(df_filtered['free_Cmax_uM'])
r, p_value = pearsonr(log_x, log_y)

# Plot
plt.figure(figsize=(4, 3))
sns.regplot(
    x=log_x, y=log_y, ci=None, color='darkred', scatter_kws={'s': 40, 'color': 'k'}
)

plt.xlabel('Cmax (log10 uM)')
plt.ylabel('free_Cmax (log10 uM)')
plt.title(f'Pearson correlation (log-log): r = {r:.2f}')
plt.tight_layout()
plt.show()
../_images/reproducibility_2.3_Training_ToxPredictor_Model_49_0.png
[41]:
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Drop NaNs and non-positive values (log scale can't handle zero or negative)
df_filtered = df_res_margins.dropna(subset=['MOS_total_Cmax', 'MOS_free_Cmax']).copy()

# Compute Pearson correlation on log-transformed values
r, p_value = pearsonr(df_filtered['MOS_total_Cmax'], df_filtered['MOS_free_Cmax'])

# Plot
plt.figure(figsize=(4, 3))
sns.regplot(
    x=df_filtered['MOS_total_Cmax'],
    y=df_filtered['MOS_free_Cmax'],
    ci=None,
    color='darkred',
    scatter_kws={'s': 40, 'color': 'k'},
)

plt.xlabel('MOS using total Cmax')
plt.ylabel('MOS using free Cmax')
plt.title(f'Pearson correlation: r = {r:.2f}')
plt.tight_layout()
plt.show()
../_images/reproducibility_2.3_Training_ToxPredictor_Model_50_0.png
[42]:
df = df_res_margins[
    df_res_margins['DILI_label'].isin(
        ['DILI (withdrawn)', 'DILI (known)', 'DILI (likely)', 'No DILI']
    )
]
y_true = ~df['DILI_label'].isin(['No DILI'])

plt.figure(figsize=(5, 3))
dmap.pl.roc_curve(
    y_true, df['MOS_total_Cmax'], label='total Cmax', inverse=True, show=False
)
dmap.pl.roc_curve(
    y_true,
    df['MOS_free_Cmax'],
    label='free Cmax',
    color='orange',
    inverse=True,
    show=True,
)
../_images/reproducibility_2.3_Training_ToxPredictor_Model_51_0.png
[43]:
adata_dilimap = dmap.datasets.DILImap_data()
Package: s3://dilimap/public/data. Top hash: 155a2b3b63
Package: s3://dilimap/public/data. Top hash: 155a2b3b63
/opt/miniconda3/envs/py311/lib/python3.11/site-packages/anndata/_core/merge.py:1309: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
  concat_annot = pd.concat(
[44]:
dmap.utils.map_dili_labels_and_cmax(adata_dilimap)
Package: s3://dilimap/public/data. Top hash: 155a2b3b63
Package: s3://dilimap/public/data. Top hash: 155a2b3b63
Package: s3://dilimap/public/data. Top hash: 155a2b3b63
[45]:
from matplotlib_venn import venn2
import matplotlib.pyplot as plt

df = adata_dilimap.obs

# Assume these are sets of compound IDs
compounds_with_total = set(
    df.groupby(['compound_name', 'SPLIT'], observed=True)['Cmax_uM']
    .mean()
    .dropna()
    .index
)
compounds_with_free = set(
    df.groupby(['compound_name', 'SPLIT'], observed=True)['free_Cmax_uM']
    .mean()
    .dropna()
    .index
)

plt.figure(figsize=(4, 4))
venn = venn2(
    [compounds_with_total, compounds_with_free], set_labels=('Total Cmax', 'Free Cmax')
)

if venn.get_label_by_id('01'):
    if venn.get_label_by_id('01').get_text() == '0':
        venn.get_label_by_id('01').set_text('')

# Set font size for all region labels
for label in venn.set_labels:
    label.set_fontsize(16)  # set size for set labels (e.g., "Total Cmax")

for label in venn.subset_labels:
    if label:  # skip None (e.g., if a region is empty)
        label.set_fontsize(16)  # set size for numbers inside the circles

plt.title('Cmax Annotation Coverage')
plt.show()
../_images/reproducibility_2.3_Training_ToxPredictor_Model_54_0.png

3a. Final model training and cross-validation

[46]:
random_state = seed = 42
[47]:
from dilimap.models import ToxPredictor
[48]:
idx = (~adata.obs['DILI'].isna()) & adata.obs['valid_cmpd']
data = adata[idx].copy()

cmpd_dose = (
    data.obs['compound_name'].astype(str) + '_' + data.obs['dose_level'].astype(str)
)
X = data.to_df().set_index(cmpd_dose)
y, groups = data.obs['DILI'], data.obs['compound_name']
/opt/miniconda3/envs/py311/lib/python3.11/site-packages/anndata/_core/anndata.py:1758: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
[49]:
model_params = dict(
    random_state=random_state,
    n_estimators=100,
    max_depth=2,
    min_samples_split=2,
    min_samples_leaf=1,
)

model = ToxPredictor(model='rf', model_params=model_params)
[50]:
model.cross_validate(X, y, n_splits=30, groups=data.obs['compound_name'])
[51]:
data.obs['DILI_prob_cv'] = model.predict_cv().values
data.obs['DILI_prob_ens'] = model.predict(X).values

3b. Cross-validation of DILI probabilities

[52]:
df_res = (
    data.obs[['DILI', 'DILI_prob_cv', 'DILI_prob_ens']]
    .astype(float)
    .groupby(cmpd_dose)
    .mean()
)
[53]:
from sklearn.metrics import confusion_matrix, roc_auc_score

tn, fp, fn, tp = confusion_matrix(df_res['DILI'], df_res['DILI_prob_cv'] > 0.7).ravel()
auc_score = roc_auc_score(df_res['DILI'], df_res['DILI_prob_cv'])

print(
    f'AUC: {np.round(auc_score, 2)} '
    f'\nSensitivity: {tp}/{(tp + fn)} ({int(tp / (tp + fn) * 100)}%) '
    f'\nSpecificity: {tn}/{(tn + fp)} ({int(tn / (tn + fp) * 100)}%)'
)
AUC: 0.73
Sensitivity: 89/176 (50%)
Specificity: 78/91 (85%)
[54]:
tn, fp, fn, tp = confusion_matrix(df_res['DILI'], df_res['DILI_prob_ens'] > 0.7).ravel()
auc_score = roc_auc_score(df_res['DILI'], df_res['DILI_prob_ens'])

print(
    f'AUC: {np.round(auc_score, 2)} '
    f'\nSensitivity: {tp}/{(tp + fn)} ({int(tp / (tp + fn) * 100)}%) '
    f'\nSpecificity: {tn}/{(tn + fp)} ({int(tn / (tn + fp) * 100)}%)'
)
AUC: 0.84
Sensitivity: 99/176 (56%)
Specificity: 82/91 (90%)
[55]:
dmap.pl.roc_curve(df_res['DILI'], df_res['DILI_prob_ens'], label='ensemble', show=False)
dmap.pl.roc_curve(
    df_res['DILI'], df_res['DILI_prob_cv'], color='orange', label='cross-val', show=True
)
../_images/reproducibility_2.3_Training_ToxPredictor_Model_66_0.png

3c. Cross-validation of safety margins

[56]:
idx = (~adata.obs['DILI'].isna()) & adata.obs['valid_cmpd']
adata.obs['DILI_prob_cv'] = np.nan

adata.obs.loc[idx, 'DILI_prob_cv'] = data.obs[
    'DILI_prob_cv'
]  # cross-validation results
adata.obs.loc[~idx, 'DILI_prob_cv'] = model.predict(adata[~idx].to_df())[
    'DILI_probability'
]  # use ensemble model to predict remaining compounds
/var/folders/4g/179wj7wn0dn71d8dl8c91nb40000gp/T/ipykernel_46760/2746853614.py:2: ImplicitModificationWarning: Trying to modify attribute `.obs` of view, initializing view as actual.
  adata.obs['DILI_prob_cv'] = np.nan
/opt/miniconda3/envs/py311/lib/python3.11/site-packages/anndata/_core/anndata.py:1758: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
/opt/miniconda3/envs/py311/lib/python3.11/site-packages/anndata/_core/anndata.py:1758: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
[57]:
from dilimap.utils import crosstab

df_crosstab = crosstab(adata, ['compound_name', 'dose_level', 'DILI_prob_cv'])[
    ['Low', 'Middle', 'Mid-High', 'High']
]
display(
    df_crosstab.head(10)
    .style.format(precision=2)
    .background_gradient('RdBu_r', vmin=-0.2, vmax=1.25)
)
dose_level Low Middle Mid-High High
compound_name        
Abacavir 0.00 0.10 0.79 0.94
Acarbose 0.13 0.69 0.91 0.90
Acebutolol 0.32 0.59 0.85 0.93
Aceclofenac 0.00 0.00 0.12 0.89
Acetaminophen 0.23 0.00 0.09 0.57
Alaproclate 0.56 0.67 0.80 0.92
Albendazole 0.61 0.77 0.82 0.87
Alfuzosin 0.00 0.00 0.11 0.90
Allopurinol 0.00 0.38 0.05 0.89
Alvimopan 0.00 0.00 0.13 0.22
[58]:
df_res_margins = model.compute_safety_margin(
    adata, 'compound_name', 'dose_uM', 'Cmax_uM', 'DILI_prob_cv', y_thresh=0.7
)
df_res_margins['DILI_label'] = df_res_margins.index.map(
    groupby(adata, 'compound_name')['DILI_label']
)
df_res_margins['n_DEG_mean'] = df_res_margins.index.map(
    groupby(adata, 'compound_name')['n_DEG_mean']
)

df_res_margins['DILI'] = ~df_res_margins['DILI_label'].str.startswith('No DILI')
[59]:
fig, ax = plt.subplots(figsize=(4.5, 3), dpi=120)
dmap.pl.boxplot_with_swarm(
    df_res_margins,
    box_width=0.4,
    swarm_size=4,
    x='DILI_label',
    y='MOS_ToxPredictor',
    hue_order=['No DILI', 'DILI (likely)', 'DILI (known)', 'DILI (withdrawn)'],
    palette=['blue', 'orange', 'darkred', 'k'],
    axhline=80,
    ylabel='Predicted Safety Margin',
    xlabel='',
)
../_images/reproducibility_2.3_Training_ToxPredictor_Model_71_0.png
[60]:
df_res_DILI = df_res_margins[
    df_res_margins['DILI_label'].isin(
        ['DILI (withdrawn)', 'DILI (known)', 'DILI (likely)', 'No DILI']
    )
]
y_true = ~df_res_DILI['DILI_label'].isin(['No DILI'])

plt.figure(figsize=(6, 4))
dmap.pl.roc_curve(
    y_true,
    df_res_DILI['MOS_ToxPredictor'],
    threshold_points=[20, 50, 80],
    threshold_name='MOS',
    label='ToxPredictor',
    inverse=True,
    show=True,
)
../_images/reproducibility_2.3_Training_ToxPredictor_Model_72_0.png

Push model to quilt

[61]:
# model.save_model('ToxPredictor_v1.joblib', push_to_s3=True)
[62]:
# dmap.s3.write(data, 'ToxPredictor_v1_training_data.h5ad')