PROTAC-Degradation-Predictor / notebooks /protac_degradation_predictor.py
ribesstefano's picture
Started working on packaging the repository
5e01175
raw
history blame
47.7 kB
import os
import pickle
import warnings
import logging
from collections import defaultdict
from typing import Literal, List, Tuple, Optional
import urllib.request
import joblib
import optuna
from optuna.samplers import TPESampler
import h5py
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from jsonargparse import CLI
from tqdm.auto import tqdm
from imblearn.over_sampling import SMOTE, ADASYN
from sklearn.preprocessing import OrdinalEncoder, StandardScaler, LabelEncoder
from sklearn.model_selection import (
StratifiedKFold,
StratifiedGroupKFold,
)
from sklearn.base import ClassifierMixin
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torchmetrics import (
Accuracy,
AUROC,
Precision,
Recall,
F1Score,
MetricCollection,
)
# Ignore UserWarning from Matplotlib
warnings.filterwarnings("ignore", ".*FixedLocator*")
# Ignore UserWarning from PyTorch Lightning
warnings.filterwarnings("ignore", ".*does not have many workers.*")
protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
# Map E3 Ligase Iap to IAP
protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
def is_active(DC50: float, Dmax: float, oring=False, pDC50_threshold=7.0, Dmax_threshold=0.8) -> bool:
""" Check if a PROTAC is active based on DC50 and Dmax.
Args:
DC50(float): DC50 in nM
Dmax(float): Dmax in %
Returns:
bool: True if active, False if inactive, np.nan if either DC50 or Dmax is NaN
"""
pDC50 = -np.log10(DC50 * 1e-9) if pd.notnull(DC50) else np.nan
Dmax = Dmax / 100
if pd.notnull(pDC50):
if pDC50 < pDC50_threshold:
return False
if pd.notnull(Dmax):
if Dmax < Dmax_threshold:
return False
if oring:
if pd.notnull(pDC50):
return True if pDC50 >= pDC50_threshold else False
elif pd.notnull(Dmax):
return True if Dmax >= Dmax_threshold else False
else:
return np.nan
else:
if pd.notnull(pDC50) and pd.notnull(Dmax):
return True if pDC50 >= pDC50_threshold and Dmax >= Dmax_threshold else False
else:
return np.nan
# ## Load Protein Embeddings
# Protein embeddings downloaded from [Uniprot](https://www.uniprot.org/help/embeddings).
#
# Please note that running the following cell the first time might take a while.
download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
embeddings_path = "../data/uniprot2embedding.h5"
if not os.path.exists(embeddings_path):
# Download the file
print(f'Downloading embeddings from {download_link}')
urllib.request.urlretrieve(download_link, embeddings_path)
protein_embeddings = {}
with h5py.File("../data/uniprot2embedding.h5", "r") as file:
uniprots = protac_df['Uniprot'].unique().tolist()
uniprots += protac_df['E3 Ligase Uniprot'].unique().tolist()
for i, sequence_id in tqdm(enumerate(uniprots), desc='Loading protein embeddings'):
try:
embedding = file[sequence_id][:]
protein_embeddings[sequence_id] = np.array(embedding)
except KeyError:
print(f'KeyError for {sequence_id}')
protein_embeddings[sequence_id] = np.zeros((1024,))
## Load Cell Embeddings
cell2embedding_filepath = '../data/cell2embedding.pkl'
with open(cell2embedding_filepath, 'rb') as f:
cell2embedding = pickle.load(f)
print(f'Loaded {len(cell2embedding)} cell lines')
emb_shape = cell2embedding[list(cell2embedding.keys())[0]].shape
# Assign all-zero vectors to cell lines that are not in the embedding file
for cell_line in protac_df['Cell Line Identifier'].unique():
if cell_line not in cell2embedding:
cell2embedding[cell_line] = np.zeros(emb_shape)
## Precompute Molecular Fingerprints
fingerprint_size = 224
morgan_fpgen = AllChem.GetMorganGenerator(
radius=15,
fpSize=fingerprint_size,
includeChirality=True,
)
smiles2fp = {}
for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
# Get the fingerprint as a bit vector
morgan_fp = morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
smiles2fp[smiles] = morgan_fp
# Count the number of unique SMILES and the number of unique Morgan fingerprints
print(f'Number of unique SMILES: {len(smiles2fp)}')
print(f'Number of unique fingerprints: {len(set([tuple(fp) for fp in smiles2fp.values()]))}')
# Get the list of SMILES with overlapping fingerprints
overlapping_smiles = []
unique_fps = set()
for smiles, fp in smiles2fp.items():
if tuple(fp) in unique_fps:
overlapping_smiles.append(smiles)
else:
unique_fps.add(tuple(fp))
print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')
print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')
# Get the pair-wise tanimoto similarity between the PROTAC fingerprints
tanimoto_matrix = defaultdict(list)
for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
fp1 = smiles2fp[smiles1]
# TODO: Use BulkTanimotoSimilarity for better performance
for j, smiles2 in enumerate(protac_df['Smiles'].unique()):
if j < i:
continue
fp2 = smiles2fp[smiles2]
tanimoto_dist = DataStructs.TanimotoSimilarity(fp1, fp2)
tanimoto_matrix[smiles1].append(tanimoto_dist)
avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
class PROTAC_Dataset(Dataset):
def __init__(
self,
protac_df,
protein_embeddings=protein_embeddings,
cell2embedding=cell2embedding,
smiles2fp=smiles2fp,
use_smote=False,
oversampler=None,
active_label='Active',
include_mol_graphs=False,
):
""" Initialize the PROTAC dataset
Args:
protac_df (pd.DataFrame): The PROTAC dataframe
protein_embeddings (dict): Dictionary of protein embeddings
cell2embedding (dict): Dictionary of cell line embeddings
smiles2fp (dict): Dictionary of SMILES to fingerprint
use_smote (bool): Whether to use SMOTE for oversampling
use_ored_activity (bool): Whether to use the 'Active - OR' column
"""
# Filter out examples with NaN in active_col column
self.data = protac_df # [~protac_df[active_col].isna()]
self.protein_embeddings = protein_embeddings
self.cell2embedding = cell2embedding
self.smiles2fp = smiles2fp
self.active_label = active_label
self.include_mol_graphs = include_mol_graphs
self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
self.protein_emb_dim = protein_embeddings[list(
protein_embeddings.keys())[0]].shape[0]
self.cell_emb_dim = cell2embedding[list(
cell2embedding.keys())[0]].shape[0]
# Look up the embeddings
self.data = pd.DataFrame({
'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp[x].astype(np.float32)).tolist(),
'Uniprot': self.data['Uniprot'].apply(lambda x: protein_embeddings[x].astype(np.float32)).tolist(),
'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein_embeddings[x].astype(np.float32)).tolist(),
'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding[x].astype(np.float32)).tolist(),
self.active_label: self.data[self.active_label].astype(np.float32).tolist(),
})
# Apply SMOTE
self.use_smote = use_smote
self.oversampler = oversampler
if self.use_smote:
self.apply_smote()
def apply_smote(self):
# Prepare the dataset for SMOTE
features = []
labels = []
for _, row in self.data.iterrows():
features.append(np.hstack([
row['Smiles'],
row['Uniprot'],
row['E3 Ligase Uniprot'],
row['Cell Line Identifier'],
]))
labels.append(row[self.active_label])
# Convert to numpy array
features = np.array(features).astype(np.float32)
labels = np.array(labels).astype(np.float32)
# Initialize SMOTE and fit
if self.oversampler is None:
oversampler = SMOTE(random_state=42)
else:
oversampler = self.oversampler
features_smote, labels_smote = oversampler.fit_resample(features, labels)
# Separate the features back into their respective embeddings
smiles_embs = features_smote[:, :self.smiles_emb_dim]
poi_embs = features_smote[:,
self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]
e3_embs = features_smote[:, self.smiles_emb_dim +
self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]
cell_embs = features_smote[:, -self.cell_emb_dim:]
# Reconstruct the dataframe with oversampled data
df_smote = pd.DataFrame({
'Smiles': list(smiles_embs),
'Uniprot': list(poi_embs),
'E3 Ligase Uniprot': list(e3_embs),
'Cell Line Identifier': list(cell_embs),
self.active_label: labels_smote
})
self.data = df_smote
def fit_scaling(self, use_single_scaler=False, **scaler_kwargs) -> dict:
""" Fit the scalers for the data.
Returns:
dict: The fitted scalers.
"""
if use_single_scaler:
scaler = StandardScaler(**scaler_kwargs)
embeddings = np.hstack([
np.array(self.data['Smiles'].tolist()),
np.array(self.data['Uniprot'].tolist()),
np.array(self.data['E3 Ligase Uniprot'].tolist()),
np.array(self.data['Cell Line Identifier'].tolist()),
])
scaler.fit(embeddings)
return scaler
else:
scalers = {}
scalers['Smiles'] = StandardScaler(**scaler_kwargs)
scalers['Uniprot'] = StandardScaler(**scaler_kwargs)
scalers['E3 Ligase Uniprot'] = StandardScaler(**scaler_kwargs)
scalers['Cell Line Identifier'] = StandardScaler(**scaler_kwargs)
scalers['Smiles'].fit(np.stack(self.data['Smiles'].to_numpy()))
scalers['Uniprot'].fit(np.stack(self.data['Uniprot'].to_numpy()))
scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
return scalers
def apply_scaling(self, scalers: dict, use_single_scaler=False):
""" Apply scaling to the data.
Args:
scalers (dict): The scalers for each feature.
"""
if use_single_scaler:
embeddings = np.hstack([
np.array(self.data['Smiles'].tolist()),
np.array(self.data['Uniprot'].tolist()),
np.array(self.data['E3 Ligase Uniprot'].tolist()),
np.array(self.data['Cell Line Identifier'].tolist()),
])
scaled_embeddings = scalers.transform(embeddings)
self.data = pd.DataFrame({
'Smiles': list(scaled_embeddings[:, :self.smiles_emb_dim]),
'Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]),
'E3 Ligase Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim+self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]),
'Cell Line Identifier': list(scaled_embeddings[:, -self.cell_emb_dim:]),
self.active_label: self.data[self.active_label]
})
else:
self.data['Smiles'] = self.data['Smiles'].apply(lambda x: scalers['Smiles'].transform(x[np.newaxis, :])[0])
self.data['Uniprot'] = self.data['Uniprot'].apply(lambda x: scalers['Uniprot'].transform(x[np.newaxis, :])[0])
self.data['E3 Ligase Uniprot'] = self.data['E3 Ligase Uniprot'].apply(lambda x: scalers['E3 Ligase Uniprot'].transform(x[np.newaxis, :])[0])
self.data['Cell Line Identifier'] = self.data['Cell Line Identifier'].apply(lambda x: scalers['Cell Line Identifier'].transform(x[np.newaxis, :])[0])
def get_numpy_arrays(self):
X = np.hstack([
np.array(self.data['Smiles'].tolist()),
np.array(self.data['Uniprot'].tolist()),
np.array(self.data['E3 Ligase Uniprot'].tolist()),
np.array(self.data['Cell Line Identifier'].tolist()),
]).copy()
y = self.data[self.active_label].values.copy()
return X, y
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
elem = {
'smiles_emb': self.data['Smiles'].iloc[idx],
'poi_emb': self.data['Uniprot'].iloc[idx],
'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
'active': self.data[self.active_label].iloc[idx],
}
return elem
def train_sklearn_model(
clf: ClassifierMixin,
train_df: pd.DataFrame,
val_df: pd.DataFrame,
test_df: Optional[pd.DataFrame] = None,
active_label: str = 'Active',
use_single_scaler: bool = True,
) -> Tuple[ClassifierMixin, nn.ModuleDict]:
""" Train a classifier model on train and val sets and evaluate it on a test set.
Args:
clf: The classifier model to train and evaluate.
train_df (pd.DataFrame): The training set.
val_df (pd.DataFrame): The validation set.
test_df (Optional[pd.DataFrame]): The test set.
Returns:
Tuple[ClassifierMixin, nn.ModuleDict]: The trained model and the metrics.
"""
# Initialize the datasets
train_ds = PROTAC_Dataset(
train_df,
protein_embeddings,
cell2embedding,
smiles2fp,
active_label=active_label,
use_smote=False,
)
scaler = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
train_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
val_ds = PROTAC_Dataset(
val_df,
protein_embeddings,
cell2embedding,
smiles2fp,
active_label=active_label,
use_smote=False,
)
val_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
if test_df is not None:
test_ds = PROTAC_Dataset(
test_df,
protein_embeddings,
cell2embedding,
smiles2fp,
active_label=active_label,
use_smote=False,
)
test_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
# Get the numpy arrays
X_train, y_train = train_ds.get_numpy_arrays()
X_val, y_val = val_ds.get_numpy_arrays()
if test_df is not None:
X_test, y_test = test_ds.get_numpy_arrays()
# Train the model
clf.fit(X_train, y_train)
# Define the metrics as a module dict
stages = ['train_metrics', 'val_metrics', 'test_metrics']
metrics = nn.ModuleDict({s: MetricCollection({
'acc': Accuracy(task='binary'),
'roc_auc': AUROC(task='binary'),
'precision': Precision(task='binary'),
'recall': Recall(task='binary'),
'f1_score': F1Score(task='binary'),
'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
'hp_metric': Accuracy(task='binary'),
}, prefix=s.replace('metrics', '')) for s in stages})
# Get the predictions
metrics_out = {}
y_pred = torch.tensor(clf.predict_proba(X_train)[:, 1])
y_true = torch.tensor(y_train)
metrics['train_metrics'].update(y_pred, y_true)
metrics_out.update(metrics['train_metrics'].compute())
y_pred = torch.tensor(clf.predict_proba(X_val)[:, 1])
y_true = torch.tensor(y_val)
metrics['val_metrics'].update(y_pred, y_true)
metrics_out.update(metrics['val_metrics'].compute())
if test_df is not None:
y_pred = torch.tensor(clf.predict_proba(X_test)[:, 1])
y_true = torch.tensor(y_test)
metrics['test_metrics'].update(y_pred, y_true)
metrics_out.update(metrics['test_metrics'].compute())
return clf, metrics_out
class PROTAC_Model(pl.LightningModule):
def __init__(
self,
hidden_dim: int,
smiles_emb_dim: int = fingerprint_size,
poi_emb_dim: int = 1024,
e3_emb_dim: int = 1024,
cell_emb_dim: int = 768,
batch_size: int = 32,
learning_rate: float = 1e-3,
dropout: float = 0.2,
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
train_dataset: PROTAC_Dataset = None,
val_dataset: PROTAC_Dataset = None,
test_dataset: PROTAC_Dataset = None,
disabled_embeddings: list = [],
apply_scaling: bool = False,
):
super().__init__()
self.poi_emb_dim = poi_emb_dim
self.e3_emb_dim = e3_emb_dim
self.cell_emb_dim = cell_emb_dim
self.smiles_emb_dim = smiles_emb_dim
self.hidden_dim = hidden_dim
self.batch_size = batch_size
self.learning_rate = learning_rate
self.join_embeddings = join_embeddings
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.test_dataset = test_dataset
self.disabled_embeddings = disabled_embeddings
self.apply_scaling = apply_scaling
# Set our init args as class attributes
self.__dict__.update(locals()) # Add arguments as attributes
# Save the arguments passed to init
ignore_args_as_hyperparams = [
'train_dataset',
'test_dataset',
'val_dataset',
]
self.save_hyperparameters(ignore=ignore_args_as_hyperparams)
# Define "surrogate models" branches
if self.join_embeddings != 'beginning':
if 'poi' not in self.disabled_embeddings:
self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
if 'e3' not in self.disabled_embeddings:
self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
if 'cell' not in self.disabled_embeddings:
self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
if 'smiles' not in self.disabled_embeddings:
self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
# Define hidden dimension for joining layer
if self.join_embeddings == 'beginning':
joint_dim = smiles_emb_dim if 'smiles' not in self.disabled_embeddings else 0
joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
elif self.join_embeddings == 'concat':
joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
elif self.join_embeddings == 'sum':
joint_dim = hidden_dim
self.fc0 = nn.Linear(joint_dim, joint_dim)
self.fc1 = nn.Linear(joint_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, 1)
self.dropout = nn.Dropout(p=dropout)
stages = ['train_metrics', 'val_metrics', 'test_metrics']
self.metrics = nn.ModuleDict({s: MetricCollection({
'acc': Accuracy(task='binary'),
'roc_auc': AUROC(task='binary'),
'precision': Precision(task='binary'),
'recall': Recall(task='binary'),
'f1_score': F1Score(task='binary'),
'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
'hp_metric': Accuracy(task='binary'),
}, prefix=s.replace('metrics', '')) for s in stages})
# Misc settings
self.missing_dataset_error = \
'''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually:
model = {1}.load_from_checkpoint('checkpoint.ckpt')
model.{0} = my_{0}
'''
# Apply scaling in datasets
if self.apply_scaling:
use_single_scaler = True if self.join_embeddings == 'beginning' else False
self.scalers = self.train_dataset.fit_scaling(use_single_scaler)
self.train_dataset.apply_scaling(self.scalers, use_single_scaler)
self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
if self.test_dataset:
self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
embeddings = []
if self.join_embeddings == 'beginning':
if 'poi' not in self.disabled_embeddings:
embeddings.append(poi_emb)
if 'e3' not in self.disabled_embeddings:
embeddings.append(e3_emb)
if 'cell' not in self.disabled_embeddings:
embeddings.append(cell_emb)
if 'smiles' not in self.disabled_embeddings:
embeddings.append(smiles_emb)
x = torch.cat(embeddings, dim=1)
x = self.dropout(F.relu(self.fc0(x)))
else:
if 'poi' not in self.disabled_embeddings:
embeddings.append(self.poi_emb(poi_emb))
if 'e3' not in self.disabled_embeddings:
embeddings.append(self.e3_emb(e3_emb))
if 'cell' not in self.disabled_embeddings:
embeddings.append(self.cell_emb(cell_emb))
if 'smiles' not in self.disabled_embeddings:
embeddings.append(self.smiles_emb(smiles_emb))
if self.join_embeddings == 'concat':
x = torch.cat(embeddings, dim=1)
elif self.join_embeddings == 'sum':
if len(embeddings) > 1:
embeddings = torch.stack(embeddings, dim=1)
x = torch.sum(embeddings, dim=1)
else:
x = embeddings[0]
x = self.dropout(F.relu(self.fc1(x)))
x = self.dropout(F.relu(self.fc2(x)))
x = self.fc3(x)
return x
def step(self, batch, batch_idx, stage):
poi_emb = batch['poi_emb']
e3_emb = batch['e3_emb']
cell_emb = batch['cell_emb']
smiles_emb = batch['smiles_emb']
y = batch['active'].float().unsqueeze(1)
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
loss = F.binary_cross_entropy_with_logits(y_hat, y)
self.metrics[f'{stage}_metrics'].update(y_hat, y)
self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True)
return loss
def training_step(self, batch, batch_idx):
return self.step(batch, batch_idx, 'train')
def validation_step(self, batch, batch_idx):
return self.step(batch, batch_idx, 'val')
def test_step(self, batch, batch_idx):
return self.step(batch, batch_idx, 'test')
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=self.learning_rate)
def predict_step(self, batch, batch_idx):
poi_emb = batch['poi_emb']
e3_emb = batch['e3_emb']
cell_emb = batch['cell_emb']
smiles_emb = batch['smiles_emb']
if self.apply_scaling:
if self.join_embeddings == 'beginning':
embeddings = np.hstack([
np.array(smiles_emb.tolist()),
np.array(poi_emb.tolist()),
np.array(e3_emb.tolist()),
np.array(cell_emb.tolist()),
])
embeddings = self.scalers.transform(embeddings)
smiles_emb = embeddings[:, :self.smiles_emb_dim]
poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.poi_emb_dim]
e3_emb = embeddings[:, self.smiles_emb_dim+self.poi_emb_dim:self.smiles_emb_dim+2*self.poi_emb_dim]
cell_emb = embeddings[:, -self.cell_emb_dim:]
else:
poi_emb = self.scalers['Uniprot'].transform(poi_emb)
e3_emb = self.scalers['E3 Ligase Uniprot'].transform(e3_emb)
cell_emb = self.scalers['Cell Line Identifier'].transform(cell_emb)
smiles_emb = self.scalers['Smiles'].transform(smiles_emb)
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
return torch.sigmoid(y_hat)
def train_dataloader(self):
if self.train_dataset is None:
format = 'train_dataset', self.__class__.__name__
raise ValueError(self.missing_dataset_error.format(*format))
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
# drop_last=True,
)
def val_dataloader(self):
if self.val_dataset is None:
format = 'val_dataset', self.__class__.__name__
raise ValueError(self.missing_dataset_error.format(*format))
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
)
def test_dataloader(self):
if self.test_dataset is None:
format = 'test_dataset', self.__class__.__name__
raise ValueError(self.missing_dataset_error.format(*format))
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
)
def train_model(
train_df: pd.DataFrame,
val_df: pd.DataFrame,
test_df: Optional[pd.DataFrame] = None,
hidden_dim: int = 768,
batch_size: int = 8,
learning_rate: float = 2e-5,
dropout: float = 0.2,
max_epochs: int = 50,
smiles_emb_dim: int = fingerprint_size,
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
smote_k_neighbors:int = 5,
use_smote: bool = True,
apply_scaling: bool = False,
active_label:str = 'Active',
fast_dev_run: bool = False,
use_logger: bool = True,
logger_name: str = 'protac',
disabled_embeddings: List[str] = [],
) -> tuple:
""" Train a PROTAC model using the given datasets and hyperparameters.
Args:
train_df (pd.DataFrame): The training set.
val_df (pd.DataFrame): The validation set.
test_df (pd.DataFrame): The test set. If provided, the returned metrics will include test performance.
hidden_dim (int): The hidden dimension of the model.
batch_size (int): The batch size.
learning_rate (float): The learning rate.
max_epochs (int): Th e maximum number of epochs.
smiles_emb_dim (int): The dimension of the SMILES embeddings.
smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
fast_dev_run (bool): Whether to run a fast development run.
disabled_embeddings (list): The list of disabled embeddings.
Returns:
tuple: The trained model, the trainer, and the metrics.
"""
oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
train_ds = PROTAC_Dataset(
train_df,
protein_embeddings,
cell2embedding,
smiles2fp,
use_smote=use_smote,
oversampler=oversampler if use_smote else None,
active_label=active_label,
)
val_ds = PROTAC_Dataset(
val_df,
protein_embeddings,
cell2embedding,
smiles2fp,
active_label=active_label,
)
if test_df is not None:
test_ds = PROTAC_Dataset(
test_df,
protein_embeddings,
cell2embedding,
smiles2fp,
active_label=active_label,
)
logger = pl.loggers.TensorBoardLogger(
save_dir='../logs',
name=logger_name,
)
callbacks = [
pl.callbacks.EarlyStopping(
monitor='train_loss',
patience=10,
mode='min',
verbose=False,
),
pl.callbacks.EarlyStopping(
monitor='val_loss',
patience=5,
mode='min',
verbose=False,
),
pl.callbacks.EarlyStopping(
monitor='val_acc',
patience=10,
mode='max',
verbose=False,
),
# pl.callbacks.ModelCheckpoint(
# monitor='val_acc',
# mode='max',
# verbose=True,
# filename='{epoch}-{val_metrics_opt_score:.4f}',
# ),
]
# Define Trainer
trainer = pl.Trainer(
logger=logger if use_logger else False,
callbacks=callbacks,
max_epochs=max_epochs,
fast_dev_run=fast_dev_run,
enable_model_summary=False,
enable_checkpointing=False,
enable_progress_bar=False,
devices=1,
num_nodes=1,
)
model = PROTAC_Model(
hidden_dim=hidden_dim,
smiles_emb_dim=smiles_emb_dim,
poi_emb_dim=1024,
e3_emb_dim=1024,
cell_emb_dim=768,
batch_size=batch_size,
join_embeddings=join_embeddings,
dropout=dropout,
learning_rate=learning_rate,
apply_scaling=apply_scaling,
train_dataset=train_ds,
val_dataset=val_ds,
test_dataset=test_ds if test_df is not None else None,
disabled_embeddings=disabled_embeddings,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
trainer.fit(model)
metrics = trainer.validate(model, verbose=False)[0]
if test_df is not None:
test_metrics = trainer.test(model, verbose=False)[0]
metrics.update(test_metrics)
return model, trainer, metrics
# Setup hyperparameter optimization:
def objective(
trial: optuna.Trial,
train_df: pd.DataFrame,
val_df: pd.DataFrame,
hidden_dim_options: List[int] = [256, 512, 768],
batch_size_options: List[int] = [8, 16, 32],
learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
smote_k_neighbors_options: List[int] = list(range(3, 16)),
dropout_options: Tuple[float, float] = (0.1, 0.5),
fast_dev_run: bool = False,
active_label: str = 'Active',
disabled_embeddings: List[str] = [],
) -> float:
""" Objective function for hyperparameter optimization.
Args:
trial (optuna.Trial): The Optuna trial object.
train_df (pd.DataFrame): The training set.
val_df (pd.DataFrame): The validation set.
hidden_dim_options (List[int]): The hidden dimension options.
batch_size_options (List[int]): The batch size options.
learning_rate_options (Tuple[float, float]): The learning rate options.
smote_k_neighbors_options (List[int]): The SMOTE k neighbors options.
dropout_options (Tuple[float, float]): The dropout options.
fast_dev_run (bool): Whether to run a fast development run.
active_label (str): The active label column.
disabled_embeddings (List[str]): The list of disabled embeddings.
"""
# Generate the hyperparameters
hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
batch_size = trial.suggest_categorical('batch_size', batch_size_options)
learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
join_embeddings = trial.suggest_categorical('join_embeddings', ['beginning', 'concat', 'sum'])
smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
use_smote = trial.suggest_categorical('use_smote', [True, False])
apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
dropout = trial.suggest_float('dropout', *dropout_options)
# Train the model with the current set of hyperparameters
_, _, metrics = train_model(
train_df,
val_df,
hidden_dim=hidden_dim,
batch_size=batch_size,
join_embeddings=join_embeddings,
learning_rate=learning_rate,
dropout=dropout,
max_epochs=100,
smote_k_neighbors=smote_k_neighbors,
apply_scaling=apply_scaling,
use_smote=use_smote,
use_logger=False,
fast_dev_run=fast_dev_run,
active_label=active_label,
disabled_embeddings=disabled_embeddings,
)
# Metrics is a dictionary containing at least the validation loss
val_loss = metrics['val_loss']
val_acc = metrics['val_acc']
val_roc_auc = metrics['val_roc_auc']
# Optuna aims to minimize the objective
return val_loss - val_acc - val_roc_auc
def hyperparameter_tuning_and_training(
train_df: pd.DataFrame,
val_df: pd.DataFrame,
test_df: pd.DataFrame,
fast_dev_run: bool = False,
n_trials: int = 50,
logger_name: str = 'protac_hparam_search',
active_label: str = 'Active',
disabled_embeddings: List[str] = [],
study_filename: Optional[str] = None,
) -> tuple:
""" Hyperparameter tuning and training of a PROTAC model.
Args:
train_df (pd.DataFrame): The training set.
val_df (pd.DataFrame): The validation set.
test_df (pd.DataFrame): The test set.
fast_dev_run (bool): Whether to run a fast development run.
n_trials (int): The number of hyperparameter optimization trials.
logger_name (str): The name of the logger.
active_label (str): The active label column.
disabled_embeddings (List[str]): The list of disabled embeddings.
Returns:
tuple: The trained model, the trainer, and the best metrics.
"""
# Define the search space
hidden_dim_options = [256, 512, 768]
batch_size_options = [8, 16, 32]
learning_rate_options = (1e-5, 1e-3) # min and max values for loguniform distribution
smote_k_neighbors_options = list(range(3, 16))
# Set the verbosity of Optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)
# Create an Optuna study object
sampler = TPESampler(seed=42, multivariate=True)
study = optuna.create_study(direction='minimize', sampler=sampler)
study_loaded = False
if study_filename:
if os.path.exists(study_filename):
study = joblib.load(study_filename)
study_loaded = True
print(f'Loaded study from {study_filename}')
if not study_loaded:
study.optimize(
lambda trial: objective(
trial,
train_df,
val_df,
hidden_dim_options=hidden_dim_options,
batch_size_options=batch_size_options,
learning_rate_options=learning_rate_options,
smote_k_neighbors_options=smote_k_neighbors_options,
fast_dev_run=fast_dev_run,
active_label=active_label,
disabled_embeddings=disabled_embeddings,
),
n_trials=n_trials,
)
if study_filename:
joblib.dump(study, study_filename)
# Retrain the model with the best hyperparameters
model, trainer, metrics = train_model(
train_df,
val_df,
test_df,
use_logger=True,
logger_name=logger_name,
fast_dev_run=fast_dev_run,
active_label=active_label,
disabled_embeddings=disabled_embeddings,
**study.best_params,
)
# Report the best hyperparameters found
metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
# Return the best metrics
return model, trainer, metrics
def main(
active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
n_trials: int = 50,
fast_dev_run: bool = False,
test_split: float = 0.2,
cv_n_splits: int = 5,
):
""" Train a PROTAC model using the given datasets and hyperparameters.
Args:
use_ored_activity (bool): Whether to use the 'Active - OR' column.
n_trials (int): The number of hyperparameter optimization trials.
n_splits (int): The number of cross-validation splits.
fast_dev_run (bool): Whether to run a fast development run.
"""
## Set the Column to Predict
active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
# Get Dmax_threshold from the active_col
Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
protac_df[active_col] = protac_df.apply(
lambda x: is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
)
## Test Sets
test_indeces = {}
### Random Split
# Randomly select 20% of the active PROTACs as the test set
active_df = protac_df[protac_df[active_col].notna()].copy()
test_df = active_df.sample(frac=test_split, random_state=42)
test_indeces['random'] = test_df.index
### E3-based Split
encoder = OrdinalEncoder()
protac_df['E3 Group'] = encoder.fit_transform(protac_df[['E3 Ligase']]).astype(int)
active_df = protac_df[protac_df[active_col].notna()].copy()
test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
test_indeces['e3_ligase'] = test_df.index
### Tanimoto-based Split
n_bins_tanimoto = 200
tanimoto_groups = pd.cut(protac_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
encoder = OrdinalEncoder()
protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
active_df = protac_df[protac_df[active_col].notna()].copy()
# Sort the groups so that samples with the highest tanimoto similarity,
# i.e., the "less similar" ones, are placed in the test set first
tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
test_df = []
# For each group, get the number of active and inactive entries. Then, add those
# entries to the test_df if: 1) the test_df lenght + the group entries is less
# 20% of the active_df lenght, and 2) the percentage of True and False entries
# in the active_col in test_df is roughly 50%.
for group in tanimoto_groups:
group_df = active_df[active_df['Tanimoto Group'] == group]
if test_df == []:
test_df.append(group_df)
continue
num_entries = len(group_df)
num_active_group = group_df[active_col].sum()
num_inactive_group = num_entries - num_active_group
tmp_test_df = pd.concat(test_df)
num_entries_test = len(tmp_test_df)
num_active_test = tmp_test_df[active_col].sum()
num_inactive_test = num_entries_test - num_active_test
# Check if the group entries can be added to the test_df
if num_entries_test + num_entries < test_split * len(active_df):
# Add anything at the beggining
if num_entries_test + num_entries < test_split / 2 * len(active_df):
test_df.append(group_df)
continue
# Be more selective and make sure that the percentage of active and
# inactive is balanced
if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
test_df.append(group_df)
test_df = pd.concat(test_df)
# Save to global dictionary of test indeces
test_indeces['tanimoto'] = test_df.index
### Target-based Split
encoder = OrdinalEncoder()
protac_df['Uniprot Group'] = encoder.fit_transform(protac_df[['Uniprot']]).astype(int)
active_df = protac_df[protac_df[active_col].notna()].copy()
test_df = []
# For each group, get the number of active and inactive entries. Then, add those
# entries to the test_df if: 1) the test_df lenght + the group entries is less
# 20% of the active_df lenght, and 2) the percentage of True and False entries
# in the active_col in test_df is roughly 50%.
# Start the loop from the groups containing the smallest number of entries.
for group in reversed(active_df['Uniprot'].value_counts().index):
group_df = active_df[active_df['Uniprot'] == group]
if test_df == []:
test_df.append(group_df)
continue
num_entries = len(group_df)
num_active_group = group_df[active_col].sum()
num_inactive_group = num_entries - num_active_group
tmp_test_df = pd.concat(test_df)
num_entries_test = len(tmp_test_df)
num_active_test = tmp_test_df[active_col].sum()
num_inactive_test = num_entries_test - num_active_test
# Check if the group entries can be added to the test_df
if num_entries_test + num_entries < test_split * len(active_df):
# Add anything at the beggining
if num_entries_test + num_entries < test_split / 2 * len(active_df):
test_df.append(group_df)
continue
# Be more selective and make sure that the percentage of active and
# inactive is balanced
if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
test_df.append(group_df)
test_df = pd.concat(test_df)
# Save to global dictionary of test indeces
test_indeces['uniprot'] = test_df.index
## Cross-Validation Training
# Make directory ../reports if it does not exist
if not os.path.exists('../reports'):
os.makedirs('../reports')
report = []
for split_type, indeces in test_indeces.items():
if split_type != 'tanimoto':
continue
active_df = protac_df[protac_df[active_col].notna()].copy()
test_df = active_df.loc[indeces]
train_val_df = active_df[~active_df.index.isin(test_df.index)]
if split_type == 'random':
kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
group = None
elif split_type == 'e3_ligase':
kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
group = train_val_df['E3 Group'].to_numpy()
elif split_type == 'tanimoto':
kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
group = train_val_df['Tanimoto Group'].to_numpy()
elif split_type == 'uniprot':
kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
group = train_val_df['Uniprot Group'].to_numpy()
# Start the CV over the folds
X = train_val_df.drop(columns=active_col)
y = train_val_df[active_col].tolist()
for k, (train_index, val_index) in enumerate(kf.split(X, y, group)):
print('-' * 100)
print(f'Starting CV for group type: {split_type}, fold: {k}')
print('-' * 100)
train_df = train_val_df.iloc[train_index]
val_df = train_val_df.iloc[val_index]
leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
stats = {
'fold': k,
'split_type': split_type,
'train_len': len(train_df),
'val_len': len(val_df),
'train_perc': len(train_df) / len(train_val_df),
'val_perc': len(val_df) / len(train_val_df),
'train_active_perc': train_df[active_col].sum() / len(train_df),
'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
'val_active_perc': val_df[active_col].sum() / len(val_df),
'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
'test_active_perc': test_df[active_col].sum() / len(test_df),
'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
'num_leaking_uniprot': len(leaking_uniprot),
'num_leaking_smiles': len(leaking_smiles),
'train_leaking_uniprot_perc': len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df),
'train_leaking_smiles_perc': len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df),
}
if split_type != 'random':
stats['train_unique_groups'] = len(np.unique(group[train_index]))
stats['val_unique_groups'] = len(np.unique(group[val_index]))
# Train and evaluate the model
model, trainer, metrics = hyperparameter_tuning_and_training(
train_df,
val_df,
test_df,
fast_dev_run=fast_dev_run,
n_trials=n_trials,
logger_name=f'protac_{active_name}_{split_type}_fold_{k}_test_split_{test_split}',
active_label=active_col,
study_filename=f'../reports/study_{active_name}_{split_type}_fold_{k}_test_split_{test_split}.pkl',
)
hparams = {p.strip('hparam_'): v for p, v in stats.items() if p.startswith('hparam_')}
stats.update(metrics)
report.append(stats.copy())
del model
del trainer
# Ablation study: disable embeddings at a time
for disabled_embeddings in [['e3'], ['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
print('-' * 100)
print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
print('-' * 100)
stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
model, trainer, metrics = train_model(
train_df,
val_df,
test_df,
fast_dev_run=fast_dev_run,
logger_name=f'protac_{active_name}_{split_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
active_label=active_col,
disabled_embeddings=disabled_embeddings,
**hparams,
)
stats.update(metrics)
report.append(stats.copy())
del model
del trainer
report_df = pd.DataFrame(report)
report_df.to_csv(
f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}_tanimoto.csv',
index=False,
)
if __name__ == '__main__':
cli = CLI(main)