📥 Download all notebooks

2.1. Training: Gene signatures

📘 Overview

This notebook performs quality control on raw count data to filter out low-quality wells (samples). It then uses DESeq2 to compute differential gene expression signatures (compound vs. DMSO) for each compound and dose level, plate by plate.

Inputs
Raw gene expression count matrix
Output
An AnnData file with DESeq2 differential expression results (log₂ fold changes, adjusted p-values)
Note
Please note that raw training data files are proprietary due to IP restrictions and will only be shared upon request, subject to data-sharing agreements and institutional policies.
[1]:
import numpy as np
import pandas as pd
import anndata as ad

import seaborn as sns
import matplotlib.pyplot as plt

import dilimap as dmap
[ ]:
# Re-import edited Python modules automatically before each cell runs
# (useful when developing dilimap alongside this notebook).
%load_ext autoreload
%autoreload 2
[2]:
# Record the dilimap / Python versions and run date for reproducibility.
dmap.logging.print_version()
Running dilimap 1.0.2 (python 3.10.16) on 2025-06-29 15:28.

Pull raw data

[4]:
# Raw training count matrix (AnnData) from the proprietary `proprietary/data` S3 package.
adata = dmap.s3.read('training_data_counts.h5ad', package_name='proprietary/data')
Package: s3://dilimap/proprietary/data. Top hash: ede9d7c164
[5]:
from dilimap.utils import platemap
[6]:
## Example plate design

# Split WELL_ID (e.g. 'A12') into its row letter and column-number string.
adata.obs['WELL_ROW'] = adata.obs['WELL_ID'].str[0]
adata.obs['WELL_COL'] = adata.obs['WELL_ID'].str[1:]

data = adata[adata.obs['PLATE_NAME'] == 'RR1_S2']

# Per-well fill color: DMSO controls 'seashell', Chlorpromazine 'lightsalmon',
# every other compound 'lightblue'.
CMPD_COLOR = (
    data.obs['COMPOUND']
    .map(
        {
            'DMSO': 'seashell',
            'DMSO_replaced': 'seashell',
            'Chlorpromazine': 'lightsalmon',
        }
    )
    .fillna('lightblue')
)
# Per-well opacity by dose level (Low is the most transparent, High/Control opaque).
# Unmapped dose levels get no opacity rule instead of turning the CSS string into NaN.
CMPD_COLOR_OPACITY = (
    '; opacity: '
    + data.obs['DOSE_LEVEL'].astype(str).map(
        {'Control': '1', 'Low': '.4', 'Middle': '.6', 'Mid-High': '.8', 'High': '1'}
    )
).fillna('')

# CMPD_DOSE label -> CSS fragment lookups, built once rather than on every column.
dose_to_color = dict(zip(data.obs['CMPD_DOSE'], CMPD_COLOR))
dose_to_opacity = dict(zip(data.obs['CMPD_DOSE'], CMPD_COLOR_OPACITY))


def highlight_cells(x):
    """Return background CSS for one column of the plate map.

    Parameters
    ----------
    x : pandas.Series
        One column of ``df_cmpd_map`` holding CMPD_DOSE labels (presumably NaN
        for empty wells).

    Returns
    -------
    pandas.Series
        CSS strings. Labels not found in the lookups get a white background and
        no opacity rule, rather than a NaN style that the Styler would ignore.
    """
    return (
        'background-color: '
        + x.map(dose_to_color).fillna('white')
        + x.map(dose_to_opacity).fillna('')
    )


df_cmpd_map = platemap(data, value_key='CMPD_DOSE', batch='PLATE_NAME')

# Thick separators right of columns 4 and 8 and above rows A, D and G
# (the treatment blocks visible in the plate layout).
col_borders = dict(subset=['4', '8'], **{'border-right': '1.5pt solid black'})
row_borders = dict(
    subset=(df_cmpd_map.index.str.endswith(('A', 'D', 'G')), df_cmpd_map.columns),
    **{'border-top': '1.5pt solid black'},
)

display(
    df_cmpd_map.style.apply(highlight_cells)
    .set_properties(**col_borders)
    .set_properties(**row_borders)
)
  1 2 3 4 5 6 7 8 9 10 11 12
RR1_S2_A Leflunomide_High Leflunomide_Mid-High Leflunomide_Middle Leflunomide_Low DMSO_replaced_Control DMSO_replaced_Control DMSO_replaced_Control DMSO_replaced_Control Citalopram_High Citalopram_Mid-High Citalopram_Middle Citalopram_Low
RR1_S2_B Leflunomide_High Leflunomide_Mid-High Leflunomide_Middle Leflunomide_Low DMSO_replaced_Control DMSO_replaced_Control DMSO_replaced_Control DMSO_replaced_Control Citalopram_High Citalopram_Mid-High Citalopram_Middle Citalopram_Low
RR1_S2_C Leflunomide_High Leflunomide_Mid-High Leflunomide_Middle Leflunomide_Low DMSO_replaced_Control DMSO_replaced_Control DMSO_replaced_Control DMSO_replaced_Control Citalopram_High Citalopram_Mid-High Citalopram_Middle Citalopram_Low
RR1_S2_D Nitazoxanide_High Nitazoxanide_Mid-High Nitazoxanide_Middle Nitazoxanide_Low Maraviroc_High Maraviroc_Mid-High Maraviroc_Middle Maraviroc_Low Mifepristone_High Mifepristone_Mid-High Mifepristone_Middle Mifepristone_Low
RR1_S2_E Nitazoxanide_High Nitazoxanide_Mid-High Nitazoxanide_Middle Nitazoxanide_Low Maraviroc_High Maraviroc_Mid-High Maraviroc_Middle Maraviroc_Low Mifepristone_High Mifepristone_Mid-High Mifepristone_Middle Mifepristone_Low
RR1_S2_F Nitazoxanide_High Nitazoxanide_Mid-High Nitazoxanide_Middle Nitazoxanide_Low Maraviroc_High Maraviroc_Mid-High Maraviroc_Middle Maraviroc_Low Mifepristone_High Mifepristone_Mid-High Mifepristone_Middle Mifepristone_Low
RR1_S2_G DMSO_Control DMSO_Control DMSO_Control DMSO_Control DMSO_Control DMSO_Control DMSO_Control DMSO_Control Chlorpromazine_High Chlorpromazine_High Chlorpromazine_High Chlorpromazine_High

QC (total counts, mtRNA, rRNA)

[7]:
# Tally LDH-assay QC annotations (an empty annotation is treated as passing below).
adata.obs.loc[:, 'LDH_QC'].value_counts()
[7]:
LDH_QC
              3978
tech error     181
cell death      16
Name: count, dtype: int64
[8]:
# A well passes the LDH check when its LDH_QC annotation is the empty string.
adata.obs['ldh_qc_pass'] = adata.obs['LDH_QC'].eq('')
[9]:
## QC of total counts and mtRNA

# Adds log_totalRNA, pct_mtRNA and pct_rRNA to adata.obs (see printed message).
dmap.pp.qc_metrics(adata)

# A well passes RNA QC when its log total RNA lies above (median - 2.5 SD) and its
# mitochondrial percentage is below 9.
log_total = adata.obs['log_totalRNA']
thresh_counts = np.median(log_total) - 2.5 * log_total.std()
adata.obs['rna_qc_pass'] = (log_total > thresh_counts) & (adata.obs['pct_mtRNA'] < 9)

# Report the cutoff on the linear count scale (log_totalRNA presumed to be log10).
print(f'Total RNA cutoff = {int(10**thresh_counts)}')
Added the following to `adata.obs`: ['log_totalRNA', 'pct_mtRNA', 'pct_rRNA']
Total RNA cutoff = 732678
[10]:
# Distributions of the three per-well QC metrics with their cutoff lines.
fig, axs = plt.subplots(1, 3, figsize=(12, 3), gridspec_kw={'wspace': 0.4})

panels = [
    ('log_totalRNA', 50, thresh_counts),
    ('pct_mtRNA', 200, 9),
    # NOTE(review): pct_rRNA is not used by `rna_qc_pass`; this line is only a visual guide.
    ('pct_rRNA', 50, 9),
]
for ax, (metric, n_bins, cutoff) in zip(axs, panels):
    sns.histplot(adata.obs[metric], bins=n_bins, ax=ax)
    ax.axvline(cutoff, c='k')

# Zoom the mtRNA panel onto the bulk of the distribution.
axs[1].set_xlim(0, 20)

plt.show()
../_images/reproducibility_2.1_Training_Gene_Signatures_12_0.png
[11]:
## Example plate with QC metrics

# Restrict to the example plate RR1_S2 for the per-well QC tables below.
plate_mask = adata.obs['PLATE_NAME'].eq('RR1_S2')
data = adata[plate_mask]


def color_boolean(val):
    """Return CSS that renders a flagged (truthy) plate-map cell in bold light-red text.

    Parameters
    ----------
    val : bool
        True when the well failed QC and should be highlighted.

    Returns
    -------
    str
        A CSS declaration string for ``Styler.apply``. It is empty for falsy ``val``,
        so unflagged cells keep their default styling.
    """
    # 'lightred' is not a CSS named color (browsers ignored it); 'lightcoral' is the
    # closest valid light red. Falsy values previously emitted malformed CSS
    # ('color: ; font-weight: ').
    return 'color: lightcoral; font-weight: bold' if val else ''


for val in ['log_totalRNA', 'pct_mtRNA']:
    # Bold header for the metric; reset the ANSI attribute so later output is not bold.
    print(f'\033[1m{val}\033[0m')

    df = platemap(data, val, batch='PLATE_NAME')
    # True where the well FAILED RNA QC (these cells get highlighted text).
    df_bool = np.invert(platemap(data, 'rna_qc_pass', batch='PLATE_NAME').astype(bool))

    # Color scale centered on the plate median, spanning +/- 3 SD.
    mu, std = np.median(data.obs[val]), np.std(data.obs[val])

    # Styler.map replaces the deprecated Styler.applymap (pandas >= 2.1).
    # `flags=df_bool` binds this iteration's mask into the lambda.
    display(
        df.style.format(precision=2)
        .background_gradient(vmin=mu - 3 * std, vmax=mu + 3 * std, cmap='RdBu_r')
        .apply(lambda c, flags=df_bool: flags[c.name].apply(color_boolean))
        .map(lambda x: 'background-color: white' if pd.isnull(x) else '')
        .set_properties(**col_borders)
        .set_properties(**row_borders)
    )
log_totalRNA
/var/folders/lz/prv79nmj5msg8h6nzqn0w7cw0000gn/T/ipykernel_22385/3827332189.py:22: FutureWarning: Styler.applymap has been deprecated. Use Styler.map instead.
  .applymap(lambda x: 'background-color: white' if pd.isnull(x) else '')
  1 2 3 4 5 6 7 8 9 10 11 12
RR1_S2_A 6.31 6.63 6.56 6.56 6.45 6.53 6.68 6.59 6.64 6.61 6.52 6.51
RR1_S2_B 6.47 6.50 6.66 6.43 6.57 6.40 6.74 6.62 6.62 6.60 6.58 6.43
RR1_S2_C 6.56 6.40 6.13 6.28 6.18 6.61 6.62 6.14 6.58 6.66 6.63 6.63
RR1_S2_D 6.27 6.50 6.43 6.56 6.42 6.39 6.41 6.34 6.13 6.43 6.43 6.23
RR1_S2_E 6.69 6.62 6.67 6.57 6.65 3.82 3.73 6.61 6.56 6.63 6.74 3.87
RR1_S2_F 6.60 6.54 6.35 6.30 6.63 6.57 6.71 6.30 6.72 6.58 6.70 6.57
RR1_S2_G 6.50 6.68 6.59 6.60 6.45 6.57 6.29 6.45 6.60 6.53 6.52 6.48
pct_mtRNA
/var/folders/lz/prv79nmj5msg8h6nzqn0w7cw0000gn/T/ipykernel_22385/3827332189.py:22: FutureWarning: Styler.applymap has been deprecated. Use Styler.map instead.
  .applymap(lambda x: 'background-color: white' if pd.isnull(x) else '')
  1 2 3 4 5 6 7 8 9 10 11 12
RR1_S2_A 3.30 3.68 3.90 3.51 3.30 3.08 3.16 3.65 3.47 3.00 2.90 2.49
RR1_S2_B 3.09 3.34 3.55 3.83 3.19 3.26 2.41 2.97 2.89 2.84 2.62 2.13
RR1_S2_C 2.62 2.17 2.33 2.41 2.51 2.58 2.12 2.43 2.42 2.17 2.45 2.40
RR1_S2_D 3.31 3.46 3.65 3.71 3.66 3.30 3.45 3.30 2.72 2.91 2.54 2.61
RR1_S2_E 3.82 3.99 3.87 4.24 4.14 3.45 3.54 3.75 3.24 3.15 2.88 3.42
RR1_S2_F 3.35 2.96 3.27 3.42 3.62 3.13 3.19 2.86 2.56 2.49 2.39 2.30
RR1_S2_G 3.19 2.71 3.17 3.08 3.38 3.26 3.11 2.67 2.30 2.38 2.15 1.89

QC (cross-replicate correlation)

[12]:
# Per-well correlation with its replicate wells (grouped by CMPD_DOSE, presumably within
# each PLATE_NAME). Adds `cross_rep_correlation` and the flag `rep_corr_qc_pass` to adata.obs.
dmap.pp.qc_cross_rep_correlation(adata, group_key='CMPD_DOSE', plate_key='PLATE_NAME')
Added the following to `adata.obs`: ['cross_rep_correlation', 'rep_corr_qc_pass']
[13]:
## Example plate with cross-replicate correlations

data = adata[adata.obs['PLATE_NAME'] == 'RR1_S2']

val = 'cross_rep_correlation'

df = platemap(data, val, batch='PLATE_NAME')
# True where the well FAILED the replicate-correlation QC. Cells without a value are
# filled with False before inverting, so they are styled as flagged.
df_bool = np.invert(
    platemap(data, 'rep_corr_qc_pass', batch='PLATE_NAME').fillna(False).astype(bool)
)
mu, std = np.nanmean(data.obs[val]), np.nanstd(data.obs[val])

col_borders = dict(subset=['4', '8'], **{'border-right': '1.5pt solid black'})
row_borders = dict(
    subset=(df.index.str.endswith(('A', 'D', 'G')), df.columns),
    **{'border-top': '1.5pt solid black'},
)

print(val, f'CI=[{mu - 2 * std:0.3f},{1}]')
# Styler.map replaces the deprecated Styler.applymap (pandas >= 2.1).
display(
    df.style.format(precision=3)
    .background_gradient(vmin=mu - 5 * std, vmax=1, cmap='Reds_r')
    .apply(lambda c: df_bool[c.name].apply(color_boolean))
    .map(lambda x: 'background-color: white' if pd.isnull(x) else '')
    .set_properties(**col_borders)
    .set_properties(**row_borders)
)
cross_rep_correlation CI=[0.990,1]
/var/folders/lz/prv79nmj5msg8h6nzqn0w7cw0000gn/T/ipykernel_22385/2606076238.py:18: FutureWarning: Styler.applymap has been deprecated. Use Styler.map instead.
  .applymap(lambda x: 'background-color: white' if pd.isnull(x) else '')
  1 2 3 4 5 6 7 8 9 10 11 12
RR1_S2_A 0.997 1.000 0.999 0.994 0.999 0.999 0.999 0.998 0.997 0.999 0.998 0.994
RR1_S2_B 0.999 1.000 0.999 0.994 0.999 0.999 0.998 0.999 0.998 0.999 0.999 0.996
RR1_S2_C 0.999 0.995 0.997 0.994 0.998 0.998 0.998 0.997 0.998 0.998 0.999 0.996
RR1_S2_D 0.998 0.997 0.999 0.995 0.999 0.997 0.998 0.996 0.999 0.998 0.995 0.998
RR1_S2_E 0.999 0.997 0.997 0.996 0.997 0.984 0.979 0.995 0.999 0.998 0.995 0.978
RR1_S2_F 0.999 0.997 0.999 0.996 0.999 0.997 0.998 0.996 0.999 0.997 0.995 0.998
RR1_S2_G 0.999 0.998 0.999 0.998 0.999 0.999 0.999 0.997 0.999 0.999 0.999 0.999
[14]:
from natsort import index_natsorted

sns.set(font_scale=0.8)

# Wells on this plate that failed the cross-replicate correlation QC.
flagged_obs = data.obs[~data.obs['rep_corr_qc_pass']]

# Number of flagged wells per compound. If COMPOUND is categorical, value_counts()
# also lists compounds with 0 flagged wells, so keep only positive counts.
df_cmpd_flags = flagged_obs['COMPOUND'].value_counts()
df_cmpd_flags = df_cmpd_flags[
    (df_cmpd_flags > 0) & ~df_cmpd_flags.index.astype(str).str.contains('DMSO')
]
cmpd_flags = df_cmpd_flags.index[:8]

if len(cmpd_flags) == 0:
    print('No non-DMSO compound on this plate has wells failing the replicate-correlation QC.')
else:
    # squeeze=False keeps `axs` 2-D, so axs[0, i] also works for a single compound.
    fig, axs = plt.subplots(
        1,
        len(cmpd_flags),
        figsize=(5 * len(cmpd_flags), 3),
        gridspec_kw={'wspace': 0.5},
        squeeze=False,
    )
    for i, cmpd_name in enumerate(cmpd_flags):
        data_cmpd = data[data.obs['COMPOUND'] == cmpd_name]
        df_sub = data_cmpd.to_df()
        df_sub.index = data_cmpd.obs['WELL_ID']

        # Well-by-well Pearson correlation, ordered column-major (A1, B1, ..., A2, ...).
        df_sub_corr = df_sub.T.corr()
        idx_sort = index_natsorted(df_sub_corr.index.str[1:] + df_sub_corr.index.str[0])
        df_sub_corr = df_sub_corr.iloc[idx_sort, idx_sort]

        ax = axs[0, i]
        g = sns.heatmap(df_sub_corr, vmax=1, vmin=0.9, yticklabels=True, ax=ax)

        # Highlight the wells of this compound that failed QC.
        failed = ~data_cmpd.obs['rep_corr_qc_pass'].to_numpy(dtype=bool)
        flagged = set(data_cmpd.obs['WELL_ID'].astype(str).to_numpy()[failed])
        for label in g.get_yticklabels():
            if label.get_text() in flagged:
                label.set_color('red')

        ax.set_title(cmpd_name)
../_images/reproducibility_2.1_Training_Gene_Signatures_17_0.png
[15]:
cmpd_flags = ['DMSO', 'Chlorpromazine']

# squeeze=False keeps `axs` 2-D, so axs[0, i] works for any number of compounds.
fig, axs = plt.subplots(
    1,
    len(cmpd_flags),
    figsize=(5 * len(cmpd_flags), 3),
    gridspec_kw={'wspace': 0.5},
    squeeze=False,
)
for i, cmpd_name in enumerate(cmpd_flags):
    data_cmpd = data[data.obs['COMPOUND'] == cmpd_name]
    df_sub = data_cmpd.to_df()
    # Label rows by WELL_ID. `data` holds a single plate, so PLATE_NAME labels were all
    # identical and could never match the WELL_IDs used below to highlight failed wells.
    df_sub.index = data_cmpd.obs['WELL_ID'].astype(str)

    # Well-by-well Pearson correlation, ordered column-major (A1, B1, ..., A2, ...).
    df_sub_corr = df_sub.T.corr()
    idx_sort = index_natsorted(df_sub_corr.index.str[1:] + df_sub_corr.index.str[0])
    df_sub_corr = df_sub_corr.iloc[idx_sort, idx_sort]

    ax = axs[0, i]
    # Show every well label so highlighted (failed) wells are always visible.
    g = sns.heatmap(df_sub_corr, vmax=1, vmin=0.9, yticklabels=True, ax=ax)

    failed = ~data_cmpd.obs['rep_corr_qc_pass'].to_numpy(dtype=bool)
    flagged = set(data_cmpd.obs['WELL_ID'].astype(str).to_numpy()[failed])
    for label in g.get_yticklabels():
        if label.get_text() in flagged:
            label.set_color('red')

    ax.set_title(cmpd_name)
../_images/reproducibility_2.1_Training_Gene_Signatures_18_0.png

QC (hepatocyte fidelity sanity check)

[33]:
# Library-size normalize each well to the median total count, then take log10(x + 1).
# Similar to sc.pp.normalize_total followed by a log transform, except that log10 is
# used here (sc.pp.log1p uses the natural log).

X = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()
cell_sums = X.sum(axis=1, keepdims=True)
# Wells with zero total counts would give 0/0 = NaN. Divide those rows by 1 instead;
# raw counts are assumed non-negative, so such rows are all zeros and stay 0.
safe_sums = np.where(cell_sums > 0, cell_sums, 1)
X_norm = (X / safe_sums) * np.median(cell_sums)
X_norm = np.log10(X_norm + 1)

adata_norm = ad.AnnData(X=X_norm, obs=adata.obs, var=adata.var)
/opt/anaconda3/envs/py310/lib/python3.10/site-packages/anndata/_core/anndata.py:1756: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
[79]:
# Celltype markers curated from https://panglaodb.se/markers.html and https://pmc.ncbi.nlm.nih.gov/articles/PMC6687507

marker_celltypes = {
    # Hepatocytes, cluster 14 in liver atlas
    'Hepatocytes': [
        'ALB',
        'CYP3A4',
        'APOA1',
        'APOA2',
        'APOC3',
        'TF',
        'APOH',
        'ALDOB',
        'TTR',
        'GSTA2',
    ],
    # Kupffer, from panglaodb
    'Kupffer': ['C1QB', 'C1QA', 'CD163', 'MAFB', 'CSF1R', 'CLEC4F', 'FCNA', 'ADGRE1'],
    # Liver Sinusoidal Endothelial Cells (LSEC), cluster 9 in liver atlas
    'LSEC': ['EGFL7', 'FLT1', 'CLEC4M', 'CLEC4G', 'FCN3', 'STAB1', 'DNASE1L3'],
    # Hepatic Stellate Cells (HSC), from panglaodb
    'HSC': ['IGFBP6', 'COL1A2', 'SPARC', 'TPM2', 'COL3A1', 'MYL9', 'DCN'],
    # Cholangiocyte, cluster 4+7 in liver atlas
    'Cholangiocyte': ['AGR2', 'ELF3', 'CLDN4', 'KRT7', 'AQP1'],  # "KRT19",
    # Dedifferentiation, low in mature, high in dedifferentiated
    'Dedifferentiated': [
        'KRT19',  # Biliary/progenitor marker
        'EPCAM',  # Also indicates epithelial progenitors
        'LAMB1',  # ECM remodeling and stress-induced dedifferentiation
        'PROM1',  # Progenitor/stemness marker
        'LGR5',  # Liver progenitor marker
        'AFP',  # Strongly fetal/hepatoblast marker
    ],
}

# Flatten the markers in cell-type order (first occurrence wins for duplicates) and keep
# only those measured in this dataset. Report the rest instead of dropping them silently.
candidate_markers = list(
    dict.fromkeys(m for markers in marker_celltypes.values() for m in markers)
)
all_markers = [m for m in candidate_markers if m in adata_norm.var_names]
missing_markers = [m for m in candidate_markers if m not in adata_norm.var_names]
if missing_markers:
    print(f'Markers not found in adata_norm.var_names (skipped): {missing_markers}')

# Wells passing all three QC filters.
idx = (
    adata_norm.obs['ldh_qc_pass']
    & adata_norm.obs['rna_qc_pass']
    & adata_norm.obs['rep_corr_qc_pass']
)
# Log-normalized marker expression in QC-passing DMSO control wells.
marker_expr = adata_norm[
    idx & (adata_norm.obs['COMPOUND'] == 'DMSO'), all_markers
].to_df()
[80]:
import matplotlib.pyplot as plt
import numpy as np

# Color per cell type (used for gene tick labels and the group labels under the axis)
celltype_palette = {
    'Hepatocytes': '#1f77b4',
    'Kupffer': '#d62728',
    'LSEC': '#2ca02c',
    'HSC': '#9467bd',
    'Cholangiocyte': '#ff7f0e',
    'Dedifferentiated': '#7f7f7f',
}

# Flatten mapping: gene → celltype; genes without a known cell type map to 'Unknown'
marker_to_celltype = {g: ct for ct, genes in marker_celltypes.items() for g in genes}
genes = list(marker_expr.columns)
celltypes = [marker_to_celltype.get(g, 'Unknown') for g in genes]

# Plot: one box per marker gene across the wells in `marker_expr`
fig, ax = plt.subplots(figsize=(8, 3.5))
flierprops = dict(
    marker='o', markersize=1.5, linestyle='none', markerfacecolor='black', alpha=0.4
)
marker_expr.boxplot(
    rot=90,
    patch_artist=True,
    flierprops=flierprops,
    boxprops=dict(color='black', linewidth=1.2),
    ax=ax,
)

# Main x-axis: gene names at box positions 1..n, colored by cell type
xticks = np.arange(1, len(genes) + 1)
ax.set_xticks(xticks)
ax.set_xticklabels(genes, rotation=90, fontsize=8)
for label, celltype in zip(ax.get_xticklabels(), celltypes):
    label.set_color(celltype_palette.get(celltype, 'black'))

# Y-axis: values are log10(normalized counts + 1); label ticks on the count scale
ax.set_yticks([0, 1, 2, 3, 4, 5])
ax.set_yticklabels(['1', '10', '100', '1k', '10k', '100k'])

# Horizontal guide at ~1k normalized counts
ax.axhline(3, ls='--', lw=1, color='gray')

# Midpoint (in 1-based box positions) of each contiguous run of same-cell-type genes.
# The run is closed at i == len(genes); an empty gene list yields no blocks.
unique_blocks = []
start_idx = 0
for i in range(1, len(genes) + 1):
    if i == len(genes) or celltypes[i] != celltypes[i - 1]:
        unique_blocks.append(((start_idx + i - 1) / 2 + 1, celltypes[i - 1]))
        start_idx = i

# Add cell type labels below the plot. x is in data coordinates and y in axes
# coordinates (get_xaxis_transform), so y_pos < 0 sits under the axis.
y_pos = -0.5  # adjust as needed
for x, label_text in unique_blocks:
    ax.text(
        x,
        y_pos,
        label_text,
        ha='center',
        va='top',
        fontsize=9,
        fontweight='bold',
        color=celltype_palette.get(label_text, 'black'),
        transform=ax.get_xaxis_transform(),
    )

# Final touches
ax.set_title('Expression of Liver Cell-Type Markers', fontsize=12)
ax.set_ylabel('Counts (normalized)', fontsize=10)
plt.tight_layout()
plt.show()
../_images/reproducibility_2.1_Training_Gene_Signatures_22_0.png

DESeq2

[ ]:
%%capture

from dask.distributed import LocalCluster, Client

cluster = LocalCluster(n_workers=8, threads_per_worker=1, memory_limit='auto')
client = Client(cluster)
[ ]:
# Wells passing all three QC filters (LDH, RNA totals/mtRNA, replicate correlation).
qc_pass = (
    adata.obs['ldh_qc_pass']
    & adata.obs['rna_qc_pass']
    & adata.obs['rep_corr_qc_pass']
)

# DESeq2 results per plate, keyed by PLATE_NAME.
adatas = {}

for plate in np.unique(adata.obs['PLATE_NAME']):
    # Apply the plate and QC masks together so each plate is copied only once.
    adata_batch = adata[(adata.obs['PLATE_NAME'] == plate) & qc_pass].copy()
    if adata_batch.n_obs == 0:
        # Skip rather than handing an empty AnnData to DESeq2.
        print(f'Skipping plate {plate}: no wells pass QC.')
        continue

    # obs names are not unique in this dataset (see the AnnData warning above).
    adata_batch.obs_names_make_unique()

    # Compound (+ dose level) signatures; the control/reference level is chosen
    # inside dmap.pp.deseq2 (presumably DMSO — see its documentation).
    adatas[plate] = dmap.pp.deseq2(
        adata_batch,
        pert_name_col='COMPOUND',
        other_pert_cols=['DOSE_LEVEL'],
        dask_client=client,
    )

    adatas[plate].obs['PLATE_NAME'] = plate
[ ]:
# Stack the per-plate DESeq2 results. join='outer' keeps the union of variables across
# plates; entries missing on a plate are filled by anndata (presumably NaN for dense X).
adata_deseq = ad.concat(adatas.values(), join='outer')
[ ]:
# Drop result rows whose X is NaN in every column (no usable statistics).
has_any_value = ~np.isnan(adata_deseq.X).all(axis=1)
adata_deseq = adata_deseq[has_any_value].copy()

Push file to S3

[ ]:
# Disabled by default: writes the DESeq2 results to the proprietary `proprietary/data`
# S3 package. Uncomment to upload.
# dmap.s3.write(adata_deseq, 'training_data_deseq2.h5ad', package_name='proprietary/data')