ribesstefano committed
Commit 91692a4
Parent(s): 022416b
Polished hparam code and added ablation studies

notebooks/protac_degradation_predictor.py CHANGED
@@ -1,34 +1,34 @@
 import optuna
 import pandas as pd
 from rdkit import Chem
 from rdkit.Chem import AllChem
 from rdkit import DataStructs
 from collections import defaultdict
-
-import
-import numpy as np
 from tqdm.auto import tqdm
-
-import os
-import urllib.request
-
-from sklearn.preprocessing import StandardScaler
-
-# ## Define Torch Dataset
-
 from imblearn.over_sampling import SMOTE, ADASYN
-from sklearn.preprocessing import LabelEncoder
-
-
-
-
 
-import warnings
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import pytorch_lightning as pl
 from torchmetrics import (
     Accuracy,
     AUROC,
@@ -38,13 +38,11 @@ from torchmetrics import (
 )
 from torchmetrics import MetricCollection
 
-import pickle
 
-
-
-
-)
-from sklearn.preprocessing import OrdinalEncoder
 
 
 protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
@@ -71,8 +69,6 @@ print(f'Number of compounds in test set: {len(unlabeled_df)}')
 # Protein embeddings downloaded from [Uniprot](https://www.uniprot.org/help/embeddings).
 #
 # Please note that running the following cell the first time might take a while.
-
-
 download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
 embeddings_path = "../data/uniprot2embedding.h5"
 if not os.path.exists(embeddings_path):
@@ -82,26 +78,17 @@ if not os.path.exists(embeddings_path):
 
 protein_embeddings = {}
 with h5py.File("../data/uniprot2embedding.h5", "r") as file:
-    print(f"number of entries: {len(file.items()):,}")
     uniprots = protac_df['Uniprot'].unique().tolist()
     uniprots += protac_df['E3 Ligase Uniprot'].unique().tolist()
     for i, sequence_id in tqdm(enumerate(uniprots), desc='Loading protein embeddings'):
         try:
             embedding = file[sequence_id][:]
             protein_embeddings[sequence_id] = np.array(embedding)
-            if i < 10:
-                print(
-                    f"\tid: {sequence_id}, "
-                    f"\tembeddings shape: {embedding.shape}, "
-                    f"\tembeddings mean: {np.array(embedding).mean()}"
-                )
         except KeyError:
             print(f'KeyError for {sequence_id}')
             protein_embeddings[sequence_id] = np.zeros((1024,))
 
-
-
-
 cell2embedding_filepath = '../data/cell2embedding.pkl'
 with open(cell2embedding_filepath, 'rb') as f:
     cell2embedding = pickle.load(f)
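For reference, a minimal sketch of the per-protein lookup the loop above performs, assuming the HDF5 file has already been downloaded to the path used in the diff; 'P04637' is an arbitrary example Uniprot accession, not necessarily present in the file:

import h5py
import numpy as np

with h5py.File('../data/uniprot2embedding.h5', 'r') as f:
    emb = np.array(f['P04637'][:])  # one 1024-dim per-protein embedding
    print(emb.shape)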
@@ -113,8 +100,7 @@ for cell_line in protac_df['Cell Line Identifier'].unique():
     if cell_line not in cell2embedding:
         cell2embedding[cell_line] = np.zeros(emb_shape)
 
-
-
 morgan_fpgen = AllChem.GetMorganGenerator(
     radius=15,
     fpSize=1024,
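A self-contained sketch of how this generator is applied per molecule; the SMILES below is a toy example, not a dataset entry:

from rdkit import Chem
from rdkit.Chem import AllChem

morgan_fpgen = AllChem.GetMorganGenerator(radius=15, fpSize=1024)
mol = Chem.MolFromSmiles('CCO')
fp = morgan_fpgen.GetFingerprint(mol)  # ExplicitBitVect of length 1024
print(fp.GetNumOnBits())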
@@ -142,7 +128,6 @@ print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')
 print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')
 
 # Get the pair-wise Tanimoto similarity between the PROTAC fingerprints
-
 tanimoto_matrix = defaultdict(list)
 for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
     fp1 = smiles2fp[smiles1]
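A hedged sketch of the pair-wise comparison inside this loop, using two toy molecules in place of dataset SMILES:

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

gen = AllChem.GetMorganGenerator(radius=15, fpSize=1024)
fp1 = gen.GetFingerprint(Chem.MolFromSmiles('CCO'))
fp2 = gen.GetFingerprint(Chem.MolFromSmiles('CCN'))
print(DataStructs.TanimotoSimilarity(fp1, fp2))  # value in [0, 1]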
@@ -158,11 +143,6 @@ protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
 
 smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
 
-# ## Set the Column to Predict
-
-active_col = 'Active'
-# active_col = 'Active - OR'
-
 
 class PROTAC_Dataset(Dataset):
     def __init__(
@@ -274,25 +254,24 @@ class PROTAC_Dataset(Dataset):
         }
         return elem
 
-# Ignore UserWarning from PyTorch Lightning
-warnings.filterwarnings("ignore", ".*does not have many workers.*")
 
 class PROTAC_Model(pl.LightningModule):
 
     def __init__(
         self,
-        hidden_dim,
-        smiles_emb_dim=1024,
-        poi_emb_dim=1024,
-        e3_emb_dim=1024,
-        cell_emb_dim=768,
-        batch_size=32,
-        learning_rate=1e-3,
-        dropout=0.2,
-
-
-
-
     ):
         super().__init__()
         self.poi_emb_dim = poi_emb_dim
@@ -302,6 +281,7 @@ class PROTAC_Model(pl.LightningModule):
         self.hidden_dim = hidden_dim
         self.batch_size = batch_size
         self.learning_rate = learning_rate
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
         self.test_dataset = test_dataset
@@ -318,48 +298,18 @@ class PROTAC_Model(pl.LightningModule):
 
         if 'poi' not in self.disabled_embeddings:
             self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
-            # # Set the POI surrogate model as a Sequential model
-            # self.poi_emb = nn.Sequential(
-            #     nn.Linear(poi_emb_dim, hidden_dim),
-            #     nn.GELU(),
-            #     nn.Dropout(p=dropout),
-            #     nn.Linear(hidden_dim, hidden_dim),
-            #     # nn.ReLU(),
-            #     # nn.Dropout(p=dropout),
-            # )
         if 'e3' not in self.disabled_embeddings:
             self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
-            # self.e3_emb = nn.Sequential(
-            #     nn.Linear(e3_emb_dim, hidden_dim),
-            #     # nn.ReLU(),
-            #     nn.Dropout(p=dropout),
-            #     # nn.Linear(hidden_dim, hidden_dim),
-            #     # nn.ReLU(),
-            #     # nn.Dropout(p=dropout),
-            # )
         if 'cell' not in self.disabled_embeddings:
             self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
-            # self.cell_emb = nn.Sequential(
-            #     nn.Linear(cell_emb_dim, hidden_dim),
-            #     # nn.ReLU(),
-            #     nn.Dropout(p=dropout),
-            #     # nn.Linear(hidden_dim, hidden_dim),
-            #     # nn.ReLU(),
-            #     # nn.Dropout(p=dropout),
-            # )
         if 'smiles' not in self.disabled_embeddings:
             self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
-
-
-
-
-
-
-            # # nn.Dropout(p=dropout),
-            # )
-
-        self.fc1 = nn.Linear(
-            hidden_dim * (4 - len(self.disabled_embeddings)), hidden_dim)
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
         self.fc3 = nn.Linear(hidden_dim, 1)
 
@@ -394,9 +344,16 @@
             embeddings.append(self.cell_emb(cell_emb))
         if 'smiles' not in self.disabled_embeddings:
             embeddings.append(self.smiles_emb(smiles_emb))
-
-
-
         x = self.fc3(x)
         return x
 
@@ -468,178 +425,6 @@
         shuffle=False,
     )
 
-# ## Test Sets
-
-# We want a different test set per Cross-Validation (CV) experiment (see further down). We are interested in three scenarios:
-# * Randomly splitting the data into training and test sets. Hence, the test set shall contain unique SMILES and Uniprots
-# * Splitting the data according to their Uniprot. Hence, the test set shall contain unique Uniprots
-# * Splitting the data according to their SMILES, _i.e._, the test set shall contain unique SMILES
-
-test_indeces = {}
-
-# Isolating the unique SMILES and Uniprots:
-
-active_df = protac_df[protac_df[active_col].notna()].copy()
-
-# Get the unique SMILES and Uniprot
-unique_smiles = active_df['Smiles'].value_counts() == 1
-unique_uniprot = active_df['Uniprot'].value_counts() == 1
-print(f'Number of unique SMILES: {unique_smiles.sum()}')
-print(f'Number of unique Uniprot: {unique_uniprot.sum()}')
-# Sample 5% of the len(active_df) from unique SMILES and Uniprot and get the
-# indices for a test set
-n = int(0.05 * len(active_df)) // 2
-unique_smiles = unique_smiles[unique_smiles].sample(n=n, random_state=42)
-# unique_uniprot = unique_uniprot[unique_uniprot].sample(n=, random_state=42)
-unique_indices = active_df[
-    active_df['Smiles'].isin(unique_smiles.index) &
-    active_df['Uniprot'].isin(unique_uniprot.index)
-].index
-print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')
-
-test_indeces['random'] = unique_indices
-
-# # Get the test set
-# test_df = active_df.loc[unique_indices]
-# # Bar plot of the test Active distribution as percentage
-# test_df['Active'].value_counts(normalize=True).plot(kind='bar')
-# plt.title('Test set Active distribution')
-# plt.show()
-# # Bar plot of the test Active - OR distribution as percentage
-# test_df['Active - OR'].value_counts(normalize=True).plot(kind='bar')
-# plt.title('Test set Active - OR distribution')
-# plt.show()
-
-# Isolating the unique Uniprots:
-
-active_df = protac_df[protac_df[active_col].notna()].copy()
-
-unique_uniprot = active_df['Uniprot'].value_counts() == 1
-print(f'Number of unique Uniprot: {unique_uniprot.sum()}')
-
-# NOTE: Since they are very few, all unique Uniprot will be used as test set.
-# Get the indices for a test set
-unique_indices = active_df[active_df['Uniprot'].isin(unique_uniprot.index)].index
-
-
-test_indeces['uniprot'] = unique_indices
-print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')
-
-# DEPRECATED: The following results in a too small test set.
-# Before starting any training, we isolate a small group of test data. Each element in the test set is selected so that all the following conditions are met:
-# * its SMILES is unique
-# * its POI is unique
-# * its (SMILES, POI) pair is unique
-
-active_df = protac_df[protac_df[active_col].notna()]
-
-# Find the samples that:
-# * have their SMILES appearing only once in the dataframe
-# * have their Uniprot appearing only once in the dataframe
-# * have their (Smiles, Uniprot) pair appearing only once in the dataframe
-unique_smiles = active_df['Smiles'].value_counts() == 1
-unique_uniprot = active_df['Uniprot'].value_counts() == 1
-unique_smiles_uniprot = active_df.groupby(['Smiles', 'Uniprot']).size() == 1
-
-# Get the indices of the unique samples
-unique_smiles_idx = active_df['Smiles'].map(unique_smiles)
-unique_uniprot_idx = active_df['Uniprot'].map(unique_uniprot)
-unique_smiles_uniprot_idx = active_df.set_index(['Smiles', 'Uniprot']).index.map(unique_smiles_uniprot)
-
-# Cross the indices to get the unique samples
-# unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx & unique_smiles_uniprot_idx].index
-unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx].index
-test_df = active_df.loc[unique_samples]
-
-warnings.filterwarnings("ignore", ".*FixedLocator*")
-
-# ## Cross-Validation Training
-
-# Cross validation training with 5 splits. The split operation is done in three different ways:
-#
-# * Random split
-# * POI-wise: some POIs never in both splits
-# * Least Tanimoto similarity PROTAC-wise
-
-# ### Plotting CV Folds
-
-
-# NOTE: When set to 60, it will result in 29 groups, with nice distributions of
-# the number of unique groups in the train and validation sets, together with
-# the number of active and inactive PROTACs.
-n_bins_tanimoto = 60 if active_col == 'Active' else 400
-n_splits = 5
-# The train and validation sets will be created from the active PROTACs only,
-# i.e., the ones with 'Active' column not NaN, and that are NOT in the test set
-active_df = protac_df[protac_df[active_col].notna()]
-train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()
-
-# Make three groups for CV:
-# * Random split
-# * Split by Uniprot (POI)
-# * Split by least Tanimoto similarity PROTAC-wise
-groups = [
-    'random',
-    'uniprot',
-    'tanimoto',
-]
-for group_type in groups:
-    if group_type == 'random':
-        kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
-        groups = None
-    elif group_type == 'uniprot':
-        # Split by Uniprot
-        kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
-        encoder = OrdinalEncoder()
-        groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
-        print(f'Number of unique groups: {len(encoder.categories_[0])}')
-    elif group_type == 'tanimoto':
-        # Split by Tanimoto similarity, i.e., group PROTACs with similar Avg Tanimoto
-        kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
-        tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
-        encoder = OrdinalEncoder()
-        groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
-        print(f'Number of unique groups: {len(encoder.categories_[0])}')
-
-
-    X = train_val_df.drop(columns=active_col)
-    y = train_val_df[active_col].tolist()
-
-    # print(f'Group: {group_type}')
-    # fig, ax = plt.subplots(figsize=(6, 3))
-    # plot_cv_indices(kf, X=X, y=y, group=groups, ax=ax, n_splits=n_splits)
-    # plt.tight_layout()
-    # plt.show()
-
-    stats = []
-    for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
-        train_df = train_val_df.iloc[train_index]
-        val_df = train_val_df.iloc[val_index]
-        stat = {
-            'fold': k,
-            'train_len': len(train_df),
-            'val_len': len(val_df),
-            'train_perc': len(train_df) / len(train_val_df),
-            'val_perc': len(val_df) / len(train_val_df),
-            'train_active (%)': train_df[active_col].sum() / len(train_df) * 100,
-            'train_inactive (%)': (len(train_df) - train_df[active_col].sum()) / len(train_df) * 100,
-            'val_active (%)': val_df[active_col].sum() / len(val_df) * 100,
-            'val_inactive (%)': (len(val_df) - val_df[active_col].sum()) / len(val_df) * 100,
-            'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
-            'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
-        }
-        if group_type != 'random':
-            stat['train_unique_groups'] = len(np.unique(groups[train_index]))
-            stat['val_unique_groups'] = len(np.unique(groups[val_index]))
-        stats.append(stat)
-        print('-' * 120)
-
-# ### Run CV
-
-import warnings
-
-# Seed everything in pytorch lightning
-pl.seed_everything(42)
-
 
 def train_model(
     train_df,
@@ -650,8 +435,9 @@ def train_model(
     learning_rate=2e-5,
     max_epochs=50,
     smiles_emb_dim=1024,
     smote_k_neighbors=5,
-    use_ored_activity=
     fast_dev_run=False,
     use_logger=True,
     logger_name='protac',
@@ -669,7 +455,7 @@
         max_epochs (int): The maximum number of epochs.
         smiles_emb_dim (int): The dimension of the SMILES embeddings.
         smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
-        use_ored_activity (bool): Whether to use the ORED activity column.
         fast_dev_run (bool): Whether to run a fast development run.
         disabled_embeddings (list): The list of disabled embeddings.
 
@@ -737,6 +523,7 @@ def train_model(
         cell_emb_dim=768,
         batch_size=batch_size,
         learning_rate=learning_rate,
         train_dataset=train_ds,
         val_dataset=val_ds,
         test_dataset=test_ds if test_df is not None else None,
@@ -763,12 +550,15 @@ def objective(
     max_epochs_options,
     smote_k_neighbors_options,
     fast_dev_run=False,
 ) -> float:
     # Generate the hyperparameters
     hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
     batch_size = trial.suggest_categorical('batch_size', batch_size_options)
     learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
     max_epochs = trial.suggest_categorical('max_epochs', max_epochs_options)
     smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
 
     # Train the model with the current set of hyperparameters
@@ -777,11 +567,14 @@
         val_df,
         hidden_dim=hidden_dim,
         batch_size=batch_size,
         learning_rate=learning_rate,
         max_epochs=max_epochs,
         smote_k_neighbors=smote_k_neighbors,
         use_logger=False,
         fast_dev_run=fast_dev_run,
     )
 
     # Metrics is a dictionary containing at least the validation loss
@@ -800,6 +593,8 @@ def hyperparameter_tuning_and_training(
     fast_dev_run=False,
     n_trials=20,
     logger_name='protac_hparam_search',
 ) -> tuple:
     """ Hyperparameter tuning and training of a PROTAC model.
 
@@ -819,9 +614,13 @@
     max_epochs_options = [10, 20, 50]
     smote_k_neighbors_options = list(range(3, 16))
 
     # Create an Optuna study object
-
-    study.
             trial,
             train_df,
             val_df,
@@ -830,117 +629,186 @@
         learning_rate_options,
         max_epochs_options,
         smote_k_neighbors_options=smote_k_neighbors_options,
-        fast_dev_run=fast_dev_run,
         n_trials=n_trials,
     )
 
-    # Retrieve the best hyperparameters
-    best_params = study.best_params
-    best_hidden_dim = best_params['hidden_dim']
-    best_batch_size = best_params['batch_size']
-    best_learning_rate = best_params['learning_rate']
-    best_max_epochs = best_params['max_epochs']
-    best_smote_k_neighbors = best_params['smote_k_neighbors']
-
     # Retrain the model with the best hyperparameters
     model, trainer, metrics = train_model(
         train_df,
         val_df,
         test_df,
-        hidden_dim=best_hidden_dim,
-        batch_size=best_batch_size,
-        learning_rate=best_learning_rate,
-        max_epochs=best_max_epochs,
         use_logger=True,
         logger_name=logger_name,
         fast_dev_run=fast_dev_run,
     )
 
     # Report the best hyperparameters found
-    metrics
-    metrics['batch_size'] = best_batch_size
-    metrics['learning_rate'] = best_learning_rate
-    metrics['max_epochs'] = best_max_epochs
-    metrics['smote_k_neighbors'] = best_smote_k_neighbors
 
     # Return the best metrics
     return model, trainer, metrics
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-

notebooks/protac_degradation_predictor.py, updated version (context lines shown plain, added lines marked with "+"):

@@ -1,34 +1,34 @@
 import optuna
+from optuna.samplers import TPESampler
+import h5py
+import os
+import pickle
+import warnings
+import logging
 import pandas as pd
+import numpy as np
+import urllib.request
+
 from rdkit import Chem
 from rdkit.Chem import AllChem
 from rdkit import DataStructs
 from collections import defaultdict
+from typing import Literal
+from jsonargparse import CLI
 from tqdm.auto import tqdm
 from imblearn.over_sampling import SMOTE, ADASYN
+from sklearn.preprocessing import OrdinalEncoder, StandardScaler, LabelEncoder
+from sklearn.model_selection import (
+    StratifiedKFold,
+    StratifiedGroupKFold,
+)
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import pytorch_lightning as pl
+from torch.utils.data import Dataset, DataLoader
 from torchmetrics import (
     Accuracy,
     AUROC,
@@ -38,13 +38,11 @@ from torchmetrics import (
 )
 from torchmetrics import MetricCollection
 
 
+# Ignore UserWarning from Matplotlib
+warnings.filterwarnings("ignore", ".*FixedLocator*")
+# Ignore UserWarning from PyTorch Lightning
+warnings.filterwarnings("ignore", ".*does not have many workers.*")
 
 
 protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
@@ -71,8 +69,6 @@ print(f'Number of compounds in test set: {len(unlabeled_df)}')
 # Protein embeddings downloaded from [Uniprot](https://www.uniprot.org/help/embeddings).
 #
 # Please note that running the following cell the first time might take a while.
 download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
 embeddings_path = "../data/uniprot2embedding.h5"
 if not os.path.exists(embeddings_path):
@@ -82,26 +78,17 @@ if not os.path.exists(embeddings_path):
 
 protein_embeddings = {}
 with h5py.File("../data/uniprot2embedding.h5", "r") as file:
     uniprots = protac_df['Uniprot'].unique().tolist()
     uniprots += protac_df['E3 Ligase Uniprot'].unique().tolist()
     for i, sequence_id in tqdm(enumerate(uniprots), desc='Loading protein embeddings'):
         try:
             embedding = file[sequence_id][:]
             protein_embeddings[sequence_id] = np.array(embedding)
         except KeyError:
             print(f'KeyError for {sequence_id}')
             protein_embeddings[sequence_id] = np.zeros((1024,))
 
+## Load Cell Embeddings
 cell2embedding_filepath = '../data/cell2embedding.pkl'
 with open(cell2embedding_filepath, 'rb') as f:
     cell2embedding = pickle.load(f)
@@ -113,8 +100,7 @@ for cell_line in protac_df['Cell Line Identifier'].unique():
     if cell_line not in cell2embedding:
         cell2embedding[cell_line] = np.zeros(emb_shape)
 
+## Precompute Molecular Fingerprints
 morgan_fpgen = AllChem.GetMorganGenerator(
     radius=15,
     fpSize=1024,
@@ -142,7 +128,6 @@ print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')
 print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')
 
 # Get the pair-wise Tanimoto similarity between the PROTAC fingerprints
 tanimoto_matrix = defaultdict(list)
 for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
     fp1 = smiles2fp[smiles1]
@@ -158,11 +143,6 @@ protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
 
 smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
 
 
 class PROTAC_Dataset(Dataset):
     def __init__(
@@ -274,25 +254,24 @@ class PROTAC_Dataset(Dataset):
         }
         return elem
 
 
 class PROTAC_Model(pl.LightningModule):
 
     def __init__(
         self,
+        hidden_dim: int,
+        smiles_emb_dim: int = 1024,
+        poi_emb_dim: int = 1024,
+        e3_emb_dim: int = 1024,
+        cell_emb_dim: int = 768,
+        batch_size: int = 32,
+        learning_rate: float = 1e-3,
+        dropout: float = 0.2,
+        join_embeddings: Literal['concat', 'sum'] = 'concat',
+        train_dataset: PROTAC_Dataset = None,
+        val_dataset: PROTAC_Dataset = None,
+        test_dataset: PROTAC_Dataset = None,
+        disabled_embeddings: list = [],
     ):
         super().__init__()
         self.poi_emb_dim = poi_emb_dim
@@ -302,6 +281,7 @@ class PROTAC_Model(pl.LightningModule):
         self.hidden_dim = hidden_dim
         self.batch_size = batch_size
         self.learning_rate = learning_rate
+        self.join_embeddings = join_embeddings
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
         self.test_dataset = test_dataset
@@ -318,48 +298,18 @@ class PROTAC_Model(pl.LightningModule):
 
         if 'poi' not in self.disabled_embeddings:
             self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
         if 'e3' not in self.disabled_embeddings:
             self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
         if 'cell' not in self.disabled_embeddings:
             self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
         if 'smiles' not in self.disabled_embeddings:
             self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
+
+        if self.join_embeddings == 'concat':
+            joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
+        elif self.join_embeddings == 'sum':
+            joint_dim = hidden_dim
+        self.fc1 = nn.Linear(joint_dim, hidden_dim)
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
         self.fc3 = nn.Linear(hidden_dim, 1)
 
@@ -394,9 +344,16 @@
             embeddings.append(self.cell_emb(cell_emb))
         if 'smiles' not in self.disabled_embeddings:
             embeddings.append(self.smiles_emb(smiles_emb))
+        if self.join_embeddings == 'concat':
+            x = torch.cat(embeddings, dim=1)
+        elif self.join_embeddings == 'sum':
+            if len(embeddings) > 1:
+                embeddings = torch.stack(embeddings, dim=1)
+                x = torch.sum(embeddings, dim=1)
+            else:
+                x = embeddings[0]
+        x = self.dropout(F.relu(self.fc1(x)))
+        x = self.dropout(F.relu(self.fc2(x)))
         x = self.fc3(x)
         return x
 
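A toy illustration of the two join_embeddings modes added above; the batch size and hidden_dim are arbitrary. With 'concat' the branch outputs are stacked along the feature axis (so fc1 needs hidden_dim * 4 inputs when no branch is disabled), while 'sum' adds them element-wise:

import torch

hidden_dim = 8
branches = [torch.randn(4, hidden_dim) for _ in range(4)]  # 4 branches, batch of 4
x_concat = torch.cat(branches, dim=1)              # shape (4, 32)
x_sum = torch.stack(branches, dim=1).sum(dim=1)    # shape (4, 8)
print(x_concat.shape, x_sum.shape)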
@@ -468,178 +425,6 @@
         shuffle=False,
     )
 
 
 def train_model(
     train_df,
@@ -650,8 +435,9 @@ def train_model(
     learning_rate=2e-5,
     max_epochs=50,
     smiles_emb_dim=1024,
+    join_embeddings='concat',
     smote_k_neighbors=5,
+    use_ored_activity=True,
     fast_dev_run=False,
     use_logger=True,
     logger_name='protac',
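The smote_k_neighbors knob exposed here controls the SMOTE oversampling applied to the training data (the resampling itself is not shown in this diff). A self-contained sketch on made-up imbalanced data:

import numpy as np
from imblearn.over_sampling import SMOTE

X = np.random.rand(30, 8)
y = np.array([0] * 24 + [1] * 6)
X_res, y_res = SMOTE(k_neighbors=5, random_state=42).fit_resample(X, y)
print(np.bincount(y_res))  # both classes balanced after resampling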
@@ -669,7 +455,7 @@
         max_epochs (int): The maximum number of epochs.
         smiles_emb_dim (int): The dimension of the SMILES embeddings.
         smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
+        use_ored_activity (bool): Whether to use the ORED activity column, i.e., "Active - OR" column.
         fast_dev_run (bool): Whether to run a fast development run.
         disabled_embeddings (list): The list of disabled embeddings.
 
@@ -737,6 +523,7 @@ def train_model(
         cell_emb_dim=768,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        join_embeddings=join_embeddings,
         train_dataset=train_ds,
         val_dataset=val_ds,
         test_dataset=test_ds if test_df is not None else None,
@@ -763,12 +550,15 @@ def objective(
     max_epochs_options,
     smote_k_neighbors_options,
     fast_dev_run=False,
+    use_ored_activity=True,
+    disabled_embeddings=[],
 ) -> float:
     # Generate the hyperparameters
     hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
     batch_size = trial.suggest_categorical('batch_size', batch_size_options)
     learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
     max_epochs = trial.suggest_categorical('max_epochs', max_epochs_options)
+    join_embeddings = trial.suggest_categorical('join_embeddings', ['concat', 'sum'])
     smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
 
     # Train the model with the current set of hyperparameters
@@ -777,11 +567,14 @@
         val_df,
         hidden_dim=hidden_dim,
         batch_size=batch_size,
+        join_embeddings=join_embeddings,
         learning_rate=learning_rate,
         max_epochs=max_epochs,
         smote_k_neighbors=smote_k_neighbors,
         use_logger=False,
         fast_dev_run=fast_dev_run,
+        use_ored_activity=use_ored_activity,
+        disabled_embeddings=disabled_embeddings,
     )
 
     # Metrics is a dictionary containing at least the validation loss
@@ -800,6 +593,8 @@ def hyperparameter_tuning_and_training(
     fast_dev_run=False,
     n_trials=20,
     logger_name='protac_hparam_search',
+    use_ored_activity=True,
+    disabled_embeddings=[],
 ) -> tuple:
     """ Hyperparameter tuning and training of a PROTAC model.
 
@@ -819,9 +614,13 @@
     max_epochs_options = [10, 20, 50]
     smote_k_neighbors_options = list(range(3, 16))
 
+    # Set the verbosity of Optuna
+    optuna.logging.set_verbosity(optuna.logging.WARNING)
     # Create an Optuna study object
+    sampler = TPESampler(seed=42, multivariate=True)
+    study = optuna.create_study(direction='minimize', sampler=sampler)
+    study.optimize(
+        lambda trial: objective(
             trial,
             train_df,
             val_df,
@@ -830,117 +629,186 @@
             learning_rate_options,
             max_epochs_options,
             smote_k_neighbors_options=smote_k_neighbors_options,
+            fast_dev_run=fast_dev_run,
+            use_ored_activity=use_ored_activity,
+            disabled_embeddings=disabled_embeddings,
+        ),
         n_trials=n_trials,
     )
 
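A self-contained sketch of the TPE-based search pattern used above, minimizing a toy objective in place of the model-training one:

import optuna
from optuna.samplers import TPESampler

def toy_objective(trial):
    x = trial.suggest_float('x', -10, 10)
    return (x - 2) ** 2

sampler = TPESampler(seed=42, multivariate=True)
study = optuna.create_study(direction='minimize', sampler=sampler)
study.optimize(toy_objective, n_trials=20)
print(study.best_params)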
     # Retrain the model with the best hyperparameters
     model, trainer, metrics = train_model(
         train_df,
         val_df,
         test_df,
         use_logger=True,
         logger_name=logger_name,
         fast_dev_run=fast_dev_run,
+        use_ored_activity=use_ored_activity,
+        disabled_embeddings=disabled_embeddings,
+        **study.best_params,
     )
 
     # Report the best hyperparameters found
+    metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
 
     # Return the best metrics
     return model, trainer, metrics
 
+
+def main(
+    use_ored_activity: bool = True,
+    n_trials: int = 50,
+    n_splits: int = 5,
+    fast_dev_run: bool = False,
+):
+    """ Train a PROTAC model using the given datasets and hyperparameters.
+
+    Args:
+        use_ored_activity (bool): Whether to use the 'Active - OR' column.
+        n_trials (int): The number of hyperparameter optimization trials.
+        n_splits (int): The number of cross-validation splits.
+        fast_dev_run (bool): Whether to run a fast development run.
+    """
+    ## Set the Column to Predict
+    active_col = 'Active - OR' if use_ored_activity else 'Active'
+    active_name = active_col.replace(' ', '').lower()
+    active_name = 'active-and' if active_name == 'active' else active_name
+
+    ## Test Sets
+
+    active_df = protac_df[protac_df[active_col].notna()]
+    # Before starting any training, we isolate a small group of test data. Each element in the test set is selected so that all the following conditions are met:
+    # * its SMILES appears only once in the dataframe
+    # * its Uniprot appears only once in the dataframe
+    # * its (Smiles, Uniprot) pair appears only once in the dataframe
+    unique_smiles = active_df['Smiles'].value_counts() == 1
+    unique_uniprot = active_df['Uniprot'].value_counts() == 1
+    unique_smiles_uniprot = active_df.groupby(['Smiles', 'Uniprot']).size() == 1
+
+    # Get the indices of the unique samples
+    unique_smiles_idx = active_df['Smiles'].map(unique_smiles)
+    unique_uniprot_idx = active_df['Uniprot'].map(unique_uniprot)
+    unique_smiles_uniprot_idx = active_df.set_index(['Smiles', 'Uniprot']).index.map(unique_smiles_uniprot)
+
+    # Cross the indices to get the unique samples
+    unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx & unique_smiles_uniprot_idx].index
+    test_df = active_df.loc[unique_samples]
+    train_val_df = active_df[~active_df.index.isin(unique_samples)]
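A toy demonstration of the uniqueness filters above, on made-up values rather than dataset rows:

import pandas as pd

df = pd.DataFrame({'Smiles': ['A', 'A', 'B', 'C'],
                   'Uniprot': ['P1', 'P2', 'P2', 'P3']})
unique_smiles = df['Smiles'].value_counts() == 1     # True for B and C
unique_uniprot = df['Uniprot'].value_counts() == 1   # True for P1 and P3
mask = df['Smiles'].map(unique_smiles) & df['Uniprot'].map(unique_uniprot)
print(df[mask])  # rows whose SMILES and Uniprot each occur exactly once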
+
+    ## Cross-Validation Training
+
+    # Cross validation training with 5 splits. The split operation is done in three different ways:
+    #
+    # * Random split
+    # * POI-wise: some POIs never in both splits
+    # * Least Tanimoto similarity PROTAC-wise
+
+    # NOTE: When set to 60, it will result in 29 groups, with nice distributions of
+    # the number of unique groups in the train and validation sets, together with
+    # the number of active and inactive PROTACs.
+    n_bins_tanimoto = 60 if active_col == 'Active' else 400
+
+    # Make directory ../reports if it does not exist
+    if not os.path.exists('../reports'):
+        os.makedirs('../reports')
+
+    # Seed everything in pytorch lightning
+    pl.seed_everything(42)
+
+    # Loop over the different splits and train the model:
+    report = []
+    for group_type in ['random', 'uniprot', 'tanimoto']:
+        print('-' * 100)
+        print(f'Starting CV for group type: {group_type}')
+        print('-' * 100)
+        # Setup CV iterator and groups
+        if group_type == 'random':
+            kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
+            groups = None
+        elif group_type == 'uniprot':
+            # Split by Uniprot
+            kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
+            encoder = OrdinalEncoder()
+            groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
+        elif group_type == 'tanimoto':
+            # Split by Tanimoto similarity, i.e., group PROTACs with similar Avg Tanimoto
+            kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
+            tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
+            encoder = OrdinalEncoder()
+            groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
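A minimal sketch of the grouped, stratified splitting configured above, on synthetic data: folds are stratified on y while whole groups stay on one side of each split.

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

X = np.zeros((12, 1))
y = np.array([0, 1] * 6)
groups = np.arange(12) // 2  # 6 groups of two samples each
kf = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=42)
for train_idx, val_idx in kf.split(X, y, groups):
    # no group ever appears on both sides of a split
    assert set(groups[train_idx]).isdisjoint(groups[val_idx])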
+        # Start the CV over the folds
+        X = train_val_df.drop(columns=active_col)
+        y = train_val_df[active_col].tolist()
+        for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
+            print('-' * 100)
+            print(f'Starting CV for group type: {group_type}, fold: {k}')
+            print('-' * 100)
+            train_df = train_val_df.iloc[train_index]
+            val_df = train_val_df.iloc[val_index]
+            stats = {
+                'fold': k,
+                'group_type': group_type,
+                'train_len': len(train_df),
+                'val_len': len(val_df),
+                'train_perc': len(train_df) / len(train_val_df),
+                'val_perc': len(val_df) / len(train_val_df),
+                'train_active_perc': train_df[active_col].sum() / len(train_df),
+                'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
+                'val_active_perc': val_df[active_col].sum() / len(val_df),
+                'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
+                'test_active_perc': test_df[active_col].sum() / len(test_df),
+                'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
+                'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
+                'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
+                'disabled_embeddings': np.nan,
+            }
+            if group_type != 'random':
+                stats['train_unique_groups'] = len(np.unique(groups[train_index]))
+                stats['val_unique_groups'] = len(np.unique(groups[val_index]))
+            # Train and evaluate the model
+            model, trainer, metrics = hyperparameter_tuning_and_training(
+                train_df,
+                val_df,
+                test_df,
+                fast_dev_run=fast_dev_run,
+                n_trials=n_trials,
+                logger_name=f'protac_{active_name}_{group_type}_fold_{k}',
+                use_ored_activity=use_ored_activity,
+            )
+            stats.update(metrics)
+            # The returned metrics store the best hyperparameters with an
+            # 'hparam_' prefix; collect them so the ablation runs reuse them.
+            hparams = {p.removeprefix('hparam_'): v for p, v in stats.items() if p.startswith('hparam_')}
+            report.append(stats.copy())
+            del model
+            del trainer
+
+            # Ablation study: disable one set of embeddings at a time
+            for disabled_embeddings in [['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
+                print('-' * 100)
+                print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
+                print('-' * 100)
+                stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
+                model, trainer, metrics = train_model(
+                    train_df,
+                    val_df,
+                    test_df,
+                    fast_dev_run=fast_dev_run,
+                    logger_name=f'protac_{active_name}_{group_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
+                    use_ored_activity=use_ored_activity,
+                    disabled_embeddings=disabled_embeddings,
+                    **hparams,
+                )
+                stats.update(metrics)
+                report.append(stats.copy())
+                del model
+                del trainer
+
+    report = pd.DataFrame(report)
+    report.to_csv(
+        f'../reports/cv_report_hparam_search_{n_splits}-splits_{active_name}.csv',
+        index=False,
+    )
+
+
+if __name__ == '__main__':
+    cli = CLI(main)
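With jsonargparse's CLI wrapper, main()'s keyword arguments become command-line flags; a hypothetical invocation (flag spelling assumed from jsonargparse defaults, script run from the notebooks/ directory):

# python protac_degradation_predictor.py --n_trials 20 --use_ored_activity false --fast_dev_run true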