ribesstefano committed on
Commit
91692a4
1 Parent(s): 022416b

Polished hparam code and added ablation studies

notebooks/protac_degradation_predictor.py CHANGED
@@ -1,34 +1,34 @@
1
  import optuna
2
  import pandas as pd
3
  from rdkit import Chem
4
  from rdkit.Chem import AllChem
5
  from rdkit import DataStructs
6
  from collections import defaultdict
7
-
8
- import h5py
9
- import numpy as np
10
  from tqdm.auto import tqdm
11
-
12
- import os
13
- import urllib.request
14
-
15
- from sklearn.preprocessing import StandardScaler
16
-
17
- # ## Define Torch Dataset
18
-
19
  from imblearn.over_sampling import SMOTE, ADASYN
20
- from sklearn.preprocessing import LabelEncoder
21
- import pandas as pd
22
- import numpy as np
23
-
24
- from torch.utils.data import Dataset, DataLoader
25
 
26
- import warnings
27
  import torch
28
  import torch.nn as nn
29
  import torch.nn.functional as F
30
  import torch.optim as optim
31
  import pytorch_lightning as pl
 
32
  from torchmetrics import (
33
  Accuracy,
34
  AUROC,
@@ -38,13 +38,11 @@ from torchmetrics import (
38
  )
39
  from torchmetrics import MetricCollection
40
 
41
- import pickle
42
 
43
- from sklearn.model_selection import (
44
- StratifiedKFold,
45
- StratifiedGroupKFold,
46
- )
47
- from sklearn.preprocessing import OrdinalEncoder
48
 
49
 
50
  protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
@@ -71,8 +69,6 @@ print(f'Number of compounds in test set: {len(unlabeled_df)}')
71
  # Protein embeddings downloaded from [Uniprot](https://www.uniprot.org/help/embeddings).
72
  #
73
  # Please note that running the following cell the first time might take a while.
74
-
75
-
76
  download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
77
  embeddings_path = "../data/uniprot2embedding.h5"
78
  if not os.path.exists(embeddings_path):
@@ -82,26 +78,17 @@ if not os.path.exists(embeddings_path):
82
 
83
  protein_embeddings = {}
84
  with h5py.File("../data/uniprot2embedding.h5", "r") as file:
85
- print(f"number of entries: {len(file.items()):,}")
86
  uniprots = protac_df['Uniprot'].unique().tolist()
87
  uniprots += protac_df['E3 Ligase Uniprot'].unique().tolist()
88
  for i, sequence_id in tqdm(enumerate(uniprots), desc='Loading protein embeddings'):
89
  try:
90
  embedding = file[sequence_id][:]
91
  protein_embeddings[sequence_id] = np.array(embedding)
92
- if i < 10:
93
- print(
94
- f"\tid: {sequence_id}, "
95
- f"\tembeddings shape: {embedding.shape}, "
96
- f"\tembeddings mean: {np.array(embedding).mean()}"
97
- )
98
  except KeyError:
99
  print(f'KeyError for {sequence_id}')
100
  protein_embeddings[sequence_id] = np.zeros((1024,))
101
 
102
- # ## Load Cell Embeddings
103
-
104
-
105
  cell2embedding_filepath = '../data/cell2embedding.pkl'
106
  with open(cell2embedding_filepath, 'rb') as f:
107
  cell2embedding = pickle.load(f)
@@ -113,8 +100,7 @@ for cell_line in protac_df['Cell Line Identifier'].unique():
113
  if cell_line not in cell2embedding:
114
  cell2embedding[cell_line] = np.zeros(emb_shape)
115
 
116
- # ## Precompute Molecular Fingerprints
117
-
118
  morgan_fpgen = AllChem.GetMorganGenerator(
119
  radius=15,
120
  fpSize=1024,
@@ -142,7 +128,6 @@ print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)
142
  print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')
143
 
144
  # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
145
-
146
  tanimoto_matrix = defaultdict(list)
147
  for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
148
  fp1 = smiles2fp[smiles1]
@@ -158,11 +143,6 @@ protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
158
 
159
  smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
160
 
161
- # ## Set the Column to Predict
162
-
163
- active_col = 'Active'
164
- # active_col = 'Active - OR'
165
-
166
 
167
  class PROTAC_Dataset(Dataset):
168
  def __init__(
@@ -274,25 +254,24 @@ class PROTAC_Dataset(Dataset):
274
  }
275
  return elem
276
 
277
- # Ignore UserWarning from PyTorch Lightning
278
- warnings.filterwarnings("ignore", ".*does not have many workers.*")
279
 
280
  class PROTAC_Model(pl.LightningModule):
281
 
282
  def __init__(
283
  self,
284
- hidden_dim,
285
- smiles_emb_dim=1024,
286
- poi_emb_dim=1024,
287
- e3_emb_dim=1024,
288
- cell_emb_dim=768,
289
- batch_size=32,
290
- learning_rate=1e-3,
291
- dropout=0.2,
292
- train_dataset=None,
293
- val_dataset=None,
294
- test_dataset=None,
295
- disabled_embeddings=[],
 
296
  ):
297
  super().__init__()
298
  self.poi_emb_dim = poi_emb_dim
@@ -302,6 +281,7 @@ class PROTAC_Model(pl.LightningModule):
302
  self.hidden_dim = hidden_dim
303
  self.batch_size = batch_size
304
  self.learning_rate = learning_rate
 
305
  self.train_dataset = train_dataset
306
  self.val_dataset = val_dataset
307
  self.test_dataset = test_dataset
@@ -318,48 +298,18 @@ class PROTAC_Model(pl.LightningModule):
318
 
319
  if 'poi' not in self.disabled_embeddings:
320
  self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
321
- # # Set the POI surrogate model as a Sequential model
322
- # self.poi_emb = nn.Sequential(
323
- # nn.Linear(poi_emb_dim, hidden_dim),
324
- # nn.GELU(),
325
- # nn.Dropout(p=dropout),
326
- # nn.Linear(hidden_dim, hidden_dim),
327
- # # nn.ReLU(),
328
- # # nn.Dropout(p=dropout),
329
- # )
330
  if 'e3' not in self.disabled_embeddings:
331
  self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
332
- # self.e3_emb = nn.Sequential(
333
- # nn.Linear(e3_emb_dim, hidden_dim),
334
- # # nn.ReLU(),
335
- # nn.Dropout(p=dropout),
336
- # # nn.Linear(hidden_dim, hidden_dim),
337
- # # nn.ReLU(),
338
- # # nn.Dropout(p=dropout),
339
- # )
340
  if 'cell' not in self.disabled_embeddings:
341
  self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
342
- # self.cell_emb = nn.Sequential(
343
- # nn.Linear(cell_emb_dim, hidden_dim),
344
- # # nn.ReLU(),
345
- # nn.Dropout(p=dropout),
346
- # # nn.Linear(hidden_dim, hidden_dim),
347
- # # nn.ReLU(),
348
- # # nn.Dropout(p=dropout),
349
- # )
350
  if 'smiles' not in self.disabled_embeddings:
351
  self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
352
- # self.smiles_emb = nn.Sequential(
353
- # nn.Linear(smiles_emb_dim, hidden_dim),
354
- # # nn.ReLU(),
355
- # nn.Dropout(p=dropout),
356
- # # nn.Linear(hidden_dim, hidden_dim),
357
- # # nn.ReLU(),
358
- # # nn.Dropout(p=dropout),
359
- # )
360
-
361
- self.fc1 = nn.Linear(
362
- hidden_dim * (4 - len(self.disabled_embeddings)), hidden_dim)
363
  self.fc2 = nn.Linear(hidden_dim, hidden_dim)
364
  self.fc3 = nn.Linear(hidden_dim, 1)
365
 
@@ -394,9 +344,16 @@ class PROTAC_Model(pl.LightningModule):
394
  embeddings.append(self.cell_emb(cell_emb))
395
  if 'smiles' not in self.disabled_embeddings:
396
  embeddings.append(self.smiles_emb(smiles_emb))
397
- x = torch.cat(embeddings, dim=1)
398
- x = self.dropout(F.gelu(self.fc1(x)))
399
- x = self.dropout(F.gelu(self.fc2(x)))
400
  x = self.fc3(x)
401
  return x
402
 
@@ -468,178 +425,6 @@ class PROTAC_Model(pl.LightningModule):
468
  shuffle=False,
469
  )
470
 
471
- # ## Test Sets
472
-
473
- # We want a different test set per Cross-Validation (CV) experiment (see further down). We are interested in three scenarios:
474
- # * Randomly splitting the data into training and test sets. Hence, the test set shall contain unique SMILES and Uniprots
475
- # * Splitting the data according to their Uniprot. Hence, the test set shall contain unique Uniprots
476
- # * Splitting the data according to their SMILES, _i.e._, the test set shall contain unique SMILES
477
-
478
- test_indeces = {}
479
-
480
- # Isolating the unique SMILES and Uniprots:
481
-
482
- active_df = protac_df[protac_df[active_col].notna()].copy()
483
-
484
- # Get the unique SMILES and Uniprot
485
- unique_smiles = active_df['Smiles'].value_counts() == 1
486
- unique_uniprot = active_df['Uniprot'].value_counts() == 1
487
- print(f'Number of unique SMILES: {unique_smiles.sum()}')
488
- print(f'Number of unique Uniprot: {unique_uniprot.sum()}')
489
- # Sample 1% of the len(active_df) from unique SMILES and Uniprot and get the
490
- # indices for a test set
491
- n = int(0.05 * len(active_df)) // 2
492
- unique_smiles = unique_smiles[unique_smiles].sample(n=n, random_state=42)
493
- # unique_uniprot = unique_uniprot[unique_uniprot].sample(n=, random_state=42)
494
- unique_indices = active_df[
495
- active_df['Smiles'].isin(unique_smiles.index) &
496
- active_df['Uniprot'].isin(unique_uniprot.index)
497
- ].index
498
- print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')
499
-
500
- test_indeces['random'] = unique_indices
501
-
502
- # # Get the test set
503
- # test_df = active_df.loc[unique_indices]
504
- # # Bar plot of the test Active distribution as percentage
505
- # test_df['Active'].value_counts(normalize=True).plot(kind='bar')
506
- # plt.title('Test set Active distribution')
507
- # plt.show()
508
- # # Bar plot of the test Active - OR distribution as percentage
509
- # test_df['Active - OR'].value_counts(normalize=True).plot(kind='bar')
510
- # plt.title('Test set Active - OR distribution')
511
- # plt.show()
512
-
513
- # Isolating the unique Uniprots:
514
-
515
- active_df = protac_df[protac_df[active_col].notna()].copy()
516
-
517
- unique_uniprot = active_df['Uniprot'].value_counts() == 1
518
- print(f'Number of unique Uniprot: {unique_uniprot.sum()}')
519
-
520
- # NOTE: Since they are very few, all unique Uniprot will be used as test set.
521
- # Get the indices for a test set
522
- unique_indices = active_df[active_df['Uniprot'].isin(unique_uniprot.index)].index
523
-
524
-
525
- test_indeces['uniprot'] = unique_indices
526
- print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')
527
-
528
- # DEPRECATED: The following results in a test set that is too small. Before starting any training, we isolate a small group of test data. Each element in the test set is selected so that all the following conditions are met:
529
- # * its SMILES is unique
530
- # * its POI is unique
531
- # * its (SMILES, POI) pair is unique
532
-
533
- active_df = protac_df[protac_df[active_col].notna()]
534
-
535
- # Find the samples that:
536
- # * have their SMILES appearing only once in the dataframe
537
- # * have their Uniprot appearing only once in the dataframe
538
- # * have their (Smiles, Uniprot) pair appearing only once in the dataframe
539
- unique_smiles = active_df['Smiles'].value_counts() == 1
540
- unique_uniprot = active_df['Uniprot'].value_counts() == 1
541
- unique_smiles_uniprot = active_df.groupby(['Smiles', 'Uniprot']).size() == 1
542
-
543
- # Get the indices of the unique samples
544
- unique_smiles_idx = active_df['Smiles'].map(unique_smiles)
545
- unique_uniprot_idx = active_df['Uniprot'].map(unique_uniprot)
546
- unique_smiles_uniprot_idx = active_df.set_index(['Smiles', 'Uniprot']).index.map(unique_smiles_uniprot)
547
-
548
- # Cross the indices to get the unique samples
549
- # unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx & unique_smiles_uniprot_idx].index
550
- unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx].index
551
- test_df = active_df.loc[unique_samples]
552
-
553
- warnings.filterwarnings("ignore", ".*FixedLocator*")
554
-
555
- # ## Cross-Validation Training
556
-
557
- # Cross validation training with 5 splits. The split operation is done in three different ways:
558
- #
559
- # * Random split
560
- # * POI-wise: some POIs never in both splits
561
- # * Least Tanimoto similarity PROTAC-wise
562
-
563
- # ### Plotting CV Folds
564
-
565
-
566
- # NOTE: When set to 60, it will result in 29 groups, with nice distributions of
567
- # the number of unique groups in the train and validation sets, together with
568
- # the number of active and inactive PROTACs.
569
- n_bins_tanimoto = 60 if active_col == 'Active' else 400
570
- n_splits = 5
571
- # The train and validation sets will be created from the active PROTACs only,
572
- # i.e., the ones with 'Active' column not NaN, and that are NOT in the test set
573
- active_df = protac_df[protac_df[active_col].notna()]
574
- train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()
575
-
576
- # Make three groups for CV:
577
- # * Random split
578
- # * Split by Uniprot (POI)
579
- # * Split by least tanimoto similarity PROTAC-wise
580
- groups = [
581
- 'random',
582
- 'uniprot',
583
- 'tanimoto',
584
- ]
585
- for group_type in groups:
586
- if group_type == 'random':
587
- kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
588
- groups = None
589
- elif group_type == 'uniprot':
590
- # Split by Uniprot
591
- kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
592
- encoder = OrdinalEncoder()
593
- groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
594
- print(f'Number of unique groups: {len(encoder.categories_[0])}')
595
- elif group_type == 'tanimoto':
596
- # Split by Tanimoto similarity, i.e., group PROTACs with similar Avg Tanimoto
597
- kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
598
- tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
599
- encoder = OrdinalEncoder()
600
- groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
601
- print(f'Number of unique groups: {len(encoder.categories_[0])}')
602
-
603
-
604
- X = train_val_df.drop(columns=active_col)
605
- y = train_val_df[active_col].tolist()
606
-
607
- # print(f'Group: {group_type}')
608
- # fig, ax = plt.subplots(figsize=(6, 3))
609
- # plot_cv_indices(kf, X=X, y=y, group=groups, ax=ax, n_splits=n_splits)
610
- # plt.tight_layout()
611
- # plt.show()
612
-
613
- stats = []
614
- for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
615
- train_df = train_val_df.iloc[train_index]
616
- val_df = train_val_df.iloc[val_index]
617
- stat = {
618
- 'fold': k,
619
- 'train_len': len(train_df),
620
- 'val_len': len(val_df),
621
- 'train_perc': len(train_df) / len(train_val_df),
622
- 'val_perc': len(val_df) / len(train_val_df),
623
- 'train_active (%)': train_df[active_col].sum() / len(train_df) * 100,
624
- 'train_inactive (%)': (len(train_df) - train_df[active_col].sum()) / len(train_df) * 100,
625
- 'val_active (%)': val_df[active_col].sum() / len(val_df) * 100,
626
- 'val_inactive (%)': (len(val_df) - val_df[active_col].sum()) / len(val_df) * 100,
627
- 'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
628
- 'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
629
- }
630
- if group_type != 'random':
631
- stat['train_unique_groups'] = len(np.unique(groups[train_index]))
632
- stat['val_unique_groups'] = len(np.unique(groups[val_index]))
633
- stats.append(stat)
634
- print('-' * 120)
635
-
636
- # ### Run CV
637
-
638
- import warnings
639
-
640
- # Seed everything in pytorch lightning
641
- pl.seed_everything(42)
642
-
643
 
644
  def train_model(
645
  train_df,
@@ -650,8 +435,9 @@ def train_model(
650
  learning_rate=2e-5,
651
  max_epochs=50,
652
  smiles_emb_dim=1024,
 
653
  smote_k_neighbors=5,
654
- use_ored_activity=False if active_col == 'Active' else True,
655
  fast_dev_run=False,
656
  use_logger=True,
657
  logger_name='protac',
@@ -669,7 +455,7 @@ def train_model(
669
  max_epochs (int): The maximum number of epochs.
670
  smiles_emb_dim (int): The dimension of the SMILES embeddings.
671
  smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
672
- use_ored_activity (bool): Whether to use the ORED activity column.
673
  fast_dev_run (bool): Whether to run a fast development run.
674
  disabled_embeddings (list): The list of disabled embeddings.
675
 
@@ -737,6 +523,7 @@ def train_model(
737
  cell_emb_dim=768,
738
  batch_size=batch_size,
739
  learning_rate=learning_rate,
 
740
  train_dataset=train_ds,
741
  val_dataset=val_ds,
742
  test_dataset=test_ds if test_df is not None else None,
@@ -763,12 +550,15 @@ def objective(
763
  max_epochs_options,
764
  smote_k_neighbors_options,
765
  fast_dev_run=False,
 
 
766
  ) -> float:
767
  # Generate the hyperparameters
768
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
769
  batch_size = trial.suggest_categorical('batch_size', batch_size_options)
770
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
771
  max_epochs = trial.suggest_categorical('max_epochs', max_epochs_options)
 
772
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
773
 
774
  # Train the model with the current set of hyperparameters
@@ -777,11 +567,14 @@ def objective(
777
  val_df,
778
  hidden_dim=hidden_dim,
779
  batch_size=batch_size,
 
780
  learning_rate=learning_rate,
781
  max_epochs=max_epochs,
782
  smote_k_neighbors=smote_k_neighbors,
783
  use_logger=False,
784
  fast_dev_run=fast_dev_run,
 
 
785
  )
786
 
787
  # Metrics is a dictionary containing at least the validation loss
@@ -800,6 +593,8 @@ def hyperparameter_tuning_and_training(
800
  fast_dev_run=False,
801
  n_trials=20,
802
  logger_name='protac_hparam_search',
 
 
803
  ) -> tuple:
804
  """ Hyperparameter tuning and training of a PROTAC model.
805
 
@@ -819,9 +614,13 @@ def hyperparameter_tuning_and_training(
819
  max_epochs_options = [10, 20, 50]
820
  smote_k_neighbors_options = list(range(3, 16))
821
 
 
 
822
  # Create an Optuna study object
823
- study = optuna.create_study(direction='minimize')
824
- study.optimize(lambda trial: objective(
 
 
825
  trial,
826
  train_df,
827
  val_df,
@@ -830,117 +629,186 @@ def hyperparameter_tuning_and_training(
830
  learning_rate_options,
831
  max_epochs_options,
832
  smote_k_neighbors_options=smote_k_neighbors_options,
833
- fast_dev_run=fast_dev_run,),
834
  n_trials=n_trials,
835
  )
836
 
837
- # Retrieve the best hyperparameters
838
- best_params = study.best_params
839
- best_hidden_dim = best_params['hidden_dim']
840
- best_batch_size = best_params['batch_size']
841
- best_learning_rate = best_params['learning_rate']
842
- best_max_epochs = best_params['max_epochs']
843
- best_smote_k_neighbors = best_params['smote_k_neighbors']
844
-
845
  # Retrain the model with the best hyperparameters
846
  model, trainer, metrics = train_model(
847
  train_df,
848
  val_df,
849
  test_df,
850
- hidden_dim=best_hidden_dim,
851
- batch_size=best_batch_size,
852
- learning_rate=best_learning_rate,
853
- max_epochs=best_max_epochs,
854
  use_logger=True,
855
  logger_name=logger_name,
856
  fast_dev_run=fast_dev_run,
857
  )
858
 
859
  # Report the best hyperparameters found
860
- metrics['hidden_dim'] = best_hidden_dim
861
- metrics['batch_size'] = best_batch_size
862
- metrics['learning_rate'] = best_learning_rate
863
- metrics['max_epochs'] = best_max_epochs
864
- metrics['smote_k_neighbors'] = best_smote_k_neighbors
865
 
866
  # Return the best metrics
867
  return model, trainer, metrics
868
 
869
- # Example usage
870
- # train_df, val_df, test_df = load_your_data() # You need to load your datasets here
871
- # model, trainer, best_metrics = hyperparameter_tuning_and_training(train_df, val_df, test_df)
872
-
873
- # Loop over the different splits and train the model:
874
- active_name = active_col.replace(' ', '').lower()
875
- active_name = 'active-and' if active_name == 'active' else active_name
876
- n_splits = 5
877
-
878
- report = []
879
- active_df = protac_df[protac_df[active_col].notna()]
880
- train_val_df = active_df[~active_df.index.isin(unique_samples)]
881
-
882
- # Make directory ../reports if it does not exist
883
- if not os.path.exists('../reports'):
884
- os.makedirs('../reports')
885
-
886
- for group_type in ['random', 'uniprot', 'tanimoto']:
887
- print(f'Starting CV for group type: {group_type}')
888
- # Setup CV iterator and groups
889
- if group_type == 'random':
890
- kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
891
- groups = None
892
- elif group_type == 'uniprot':
893
- # Split by Uniprot
894
- kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
895
- encoder = OrdinalEncoder()
896
- groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
897
- elif group_type == 'tanimoto':
898
- # Split by tanimoto similarity, i.e., group_type PROTACs with similar Avg Tanimoto
899
- kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
900
- tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
901
- encoder = OrdinalEncoder()
902
- groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
903
- # Start the CV over the folds
904
- X = train_val_df.drop(columns=active_col)
905
- y = train_val_df[active_col].tolist()
906
- for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
907
- train_df = train_val_df.iloc[train_index]
908
- val_df = train_val_df.iloc[val_index]
909
- stats = {
910
- 'fold': k,
911
- 'group_type': group_type,
912
- 'train_len': len(train_df),
913
- 'val_len': len(val_df),
914
- 'train_perc': len(train_df) / len(train_val_df),
915
- 'val_perc': len(val_df) / len(train_val_df),
916
- 'train_active_perc': train_df[active_col].sum() / len(train_df),
917
- 'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
918
- 'val_active_perc': val_df[active_col].sum() / len(val_df),
919
- 'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
920
- 'test_active_perc': test_df[active_col].sum() / len(test_df),
921
- 'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
922
- 'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
923
- 'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
924
- }
925
- if group_type != 'random':
926
- stats['train_unique_groups'] = len(np.unique(groups[train_index]))
927
- stats['val_unique_groups'] = len(np.unique(groups[val_index]))
928
- # Train and evaluate the model
929
- # model, trainer, metrics = train_model(train_df, val_df, test_df)
930
- model, trainer, metrics = hyperparameter_tuning_and_training(
931
- train_df,
932
- val_df,
933
- test_df,
934
- fast_dev_run=False,
935
- n_trials=50,
936
- logger_name=f'protac_{active_name}_{group_type}_fold_{k}',
937
- )
938
- stats.update(metrics)
939
- del model
940
- del trainer
941
- report.append(stats)
942
- report = pd.DataFrame(report)
943
- report.to_csv(
944
- f'../reports/cv_report_hparam_search_{n_splits}-splits_{active_name}.csv',
945
- index=False,
946
- )
1
  import optuna
2
+ from optuna.samplers import TPESampler
3
+ import h5py
4
+ import os
5
+ import pickle
6
+ import warnings
7
+ import logging
8
  import pandas as pd
9
+ import numpy as np
10
+ import urllib.request
11
+
12
  from rdkit import Chem
13
  from rdkit.Chem import AllChem
14
  from rdkit import DataStructs
15
  from collections import defaultdict
16
+ from typing import Literal
17
+ from jsonargparse import CLI
 
18
  from tqdm.auto import tqdm
19
  from imblearn.over_sampling import SMOTE, ADASYN
20
+ from sklearn.preprocessing import OrdinalEncoder, StandardScaler, LabelEncoder
21
+ from sklearn.model_selection import (
22
+ StratifiedKFold,
23
+ StratifiedGroupKFold,
24
+ )
25
 
 
26
  import torch
27
  import torch.nn as nn
28
  import torch.nn.functional as F
29
  import torch.optim as optim
30
  import pytorch_lightning as pl
31
+ from torch.utils.data import Dataset, DataLoader
32
  from torchmetrics import (
33
  Accuracy,
34
  AUROC,
 
38
  )
39
  from torchmetrics import MetricCollection
40
 
 
41
 
42
+ # Ignore UserWarning from Matplotlib
43
+ warnings.filterwarnings("ignore", ".*FixedLocator*")
44
+ # Ignore UserWarning from PyTorch Lightning
45
+ warnings.filterwarnings("ignore", ".*does not have many workers.*")
 
46
 
47
 
48
  protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
 
69
  # Protein embeddings downloaded from [Uniprot](https://www.uniprot.org/help/embeddings).
70
  #
71
  # Please note that running the following cell the first time might take a while.
 
 
72
  download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
73
  embeddings_path = "../data/uniprot2embedding.h5"
74
  if not os.path.exists(embeddings_path):
 
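# NOTE (sketch, not part of this commit): the body of the download branch is
# elided by the diff. With urllib.request imported above, it plausibly reduces
# to something like the following (hypothetical reconstruction):
#
#   if not os.path.exists(embeddings_path):
#       print(f'Downloading protein embeddings from {download_link}')
#       urllib.request.urlretrieve(download_link, embeddings_path)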
78
 
79
  protein_embeddings = {}
80
  with h5py.File("../data/uniprot2embedding.h5", "r") as file:
 
81
  uniprots = protac_df['Uniprot'].unique().tolist()
82
  uniprots += protac_df['E3 Ligase Uniprot'].unique().tolist()
83
  for i, sequence_id in tqdm(enumerate(uniprots), desc='Loading protein embeddings'):
84
  try:
85
  embedding = file[sequence_id][:]
86
  protein_embeddings[sequence_id] = np.array(embedding)
87
  except KeyError:
88
  print(f'KeyError for {sequence_id}')
89
  protein_embeddings[sequence_id] = np.zeros((1024,))
90
 
91
+ ## Load Cell Embeddings
 
 
92
  cell2embedding_filepath = '../data/cell2embedding.pkl'
93
  with open(cell2embedding_filepath, 'rb') as f:
94
  cell2embedding = pickle.load(f)
 
100
  if cell_line not in cell2embedding:
101
  cell2embedding[cell_line] = np.zeros(emb_shape)
102
 
103
+ ## Precompute Molecular Fingerprints
 
104
  morgan_fpgen = AllChem.GetMorganGenerator(
105
  radius=15,
106
  fpSize=1024,
 
128
  print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')
129
 
130
  # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
 
131
  tanimoto_matrix = defaultdict(list)
132
  for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
133
  fp1 = smiles2fp[smiles1]
 
143
 
144
  smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
145
 
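# NOTE (sketch, not part of this commit): the diff elides the body of the
# Tanimoto loop above (new-file lines 134-142). Assuming that, at the point of
# the loop, smiles2fp still maps SMILES to RDKit bit vectors (the np.array
# conversion happens only afterwards), the pairwise similarities and the
# 'Avg Tanimoto' column used later for grouping could be filled in like this:
#
#   unique_smiles_list = protac_df['Smiles'].unique().tolist()
#   for i, smiles1 in enumerate(unique_smiles_list):
#       fp1 = smiles2fp[smiles1]
#       for smiles2 in unique_smiles_list[i + 1:]:
#           sim = DataStructs.TanimotoSimilarity(fp1, smiles2fp[smiles2])
#           tanimoto_matrix[smiles1].append(sim)
#           tanimoto_matrix[smiles2].append(sim)
#   avg_tanimoto = {s: np.mean(sims) for s, sims in tanimoto_matrix.items()}
#   protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)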
146
 
147
  class PROTAC_Dataset(Dataset):
148
  def __init__(
 
254
  }
255
  return elem
256
 
 
 
257
 
258
  class PROTAC_Model(pl.LightningModule):
259
 
260
  def __init__(
261
  self,
262
+ hidden_dim: int,
263
+ smiles_emb_dim: int = 1024,
264
+ poi_emb_dim: int = 1024,
265
+ e3_emb_dim: int = 1024,
266
+ cell_emb_dim: int = 768,
267
+ batch_size: int = 32,
268
+ learning_rate: float = 1e-3,
269
+ dropout: float = 0.2,
270
+ join_embeddings: Literal['concat', 'sum'] = 'concat',
271
+ train_dataset: PROTAC_Dataset = None,
272
+ val_dataset: PROTAC_Dataset = None,
273
+ test_dataset: PROTAC_Dataset = None,
274
+ disabled_embeddings: list = [],
275
  ):
276
  super().__init__()
277
  self.poi_emb_dim = poi_emb_dim
 
281
  self.hidden_dim = hidden_dim
282
  self.batch_size = batch_size
283
  self.learning_rate = learning_rate
284
+ self.join_embeddings = join_embeddings
285
  self.train_dataset = train_dataset
286
  self.val_dataset = val_dataset
287
  self.test_dataset = test_dataset
 
298
 
299
  if 'poi' not in self.disabled_embeddings:
300
  self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
301
  if 'e3' not in self.disabled_embeddings:
302
  self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
303
  if 'cell' not in self.disabled_embeddings:
304
  self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
305
  if 'smiles' not in self.disabled_embeddings:
306
  self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
307
+
308
+ if self.join_embeddings == 'concat':
309
+ joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
310
+ elif self.join_embeddings == 'sum':
311
+ joint_dim = hidden_dim
312
+ self.fc1 = nn.Linear(joint_dim, hidden_dim)
313
  self.fc2 = nn.Linear(hidden_dim, hidden_dim)
314
  self.fc3 = nn.Linear(hidden_dim, 1)
315
 
 
344
  embeddings.append(self.cell_emb(cell_emb))
345
  if 'smiles' not in self.disabled_embeddings:
346
  embeddings.append(self.smiles_emb(smiles_emb))
347
+ if self.join_embeddings == 'concat':
348
+ x = torch.cat(embeddings, dim=1)
349
+ elif self.join_embeddings == 'sum':
350
+ if len(embeddings) > 1:
351
+ embeddings = torch.stack(embeddings, dim=1)
352
+ x = torch.sum(embeddings, dim=1)
353
+ else:
354
+ x = embeddings[0]
355
+ x = self.dropout(F.relu(self.fc1(x)))
356
+ x = self.dropout(F.relu(self.fc2(x)))
357
  x = self.fc3(x)
358
  return x
359
 
 
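# NOTE (illustration, not part of this commit): a quick, standalone shape check
# of the two join_embeddings modes with hypothetical sizes (batch=2,
# hidden_dim=8, all four embedding branches enabled), matching the joint_dim
# logic in __init__:
import torch
_embs = [torch.randn(2, 8) for _ in range(4)]
_concat = torch.cat(_embs, dim=1)               # shape (2, 32): fc1 in_features = hidden_dim * 4
_summed = torch.stack(_embs, dim=1).sum(dim=1)  # shape (2, 8):  fc1 in_features = hidden_dim
assert _concat.shape == (2, 32) and _summed.shape == (2, 8)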
425
  shuffle=False,
426
  )
427
 
428
 
429
  def train_model(
430
  train_df,
 
435
  learning_rate=2e-5,
436
  max_epochs=50,
437
  smiles_emb_dim=1024,
438
+ join_embeddings='concat',
439
  smote_k_neighbors=5,
440
+ use_ored_activity=True,
441
  fast_dev_run=False,
442
  use_logger=True,
443
  logger_name='protac',
 
455
  max_epochs (int): The maximum number of epochs.
456
  smiles_emb_dim (int): The dimension of the SMILES embeddings.
457
  smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
458
+ use_ored_activity (bool): Whether to use the ORED activity, i.e., the "Active - OR" column.
459
  fast_dev_run (bool): Whether to run a fast development run.
460
  disabled_embeddings (list): The list of disabled embeddings.
461
 
 
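# NOTE (illustration, not part of this commit): the hunks above only show the
# smote_k_neighbors argument, not where it is applied. As a standalone reminder
# of what the parameter controls, a minimal SMOTE example on a hypothetical
# feature matrix (the committed code may wire this up differently):
import numpy as np
from imblearn.over_sampling import SMOTE
X_demo = np.random.rand(100, 16)                # 100 hypothetical samples, 16 features
y_demo = np.array([0] * 80 + [1] * 20)          # imbalanced binary labels
X_res, y_res = SMOTE(k_neighbors=5, random_state=42).fit_resample(X_demo, y_demo)
print(X_res.shape, np.bincount(y_res))          # minority class oversampled to 80/80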
523
  cell_emb_dim=768,
524
  batch_size=batch_size,
525
  learning_rate=learning_rate,
526
+ join_embeddings=join_embeddings,
527
  train_dataset=train_ds,
528
  val_dataset=val_ds,
529
  test_dataset=test_ds if test_df is not None else None,
 
550
  max_epochs_options,
551
  smote_k_neighbors_options,
552
  fast_dev_run=False,
553
+ use_ored_activity=True,
554
+ disabled_embeddings=[],
555
  ) -> float:
556
  # Generate the hyperparameters
557
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
558
  batch_size = trial.suggest_categorical('batch_size', batch_size_options)
559
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
560
  max_epochs = trial.suggest_categorical('max_epochs', max_epochs_options)
561
+ join_embeddings = trial.suggest_categorical('join_embeddings', ['concat', 'sum'])
562
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
563
 
564
  # Train the model with the current set of hyperparameters
 
567
  val_df,
568
  hidden_dim=hidden_dim,
569
  batch_size=batch_size,
570
+ join_embeddings=join_embeddings,
571
  learning_rate=learning_rate,
572
  max_epochs=max_epochs,
573
  smote_k_neighbors=smote_k_neighbors,
574
  use_logger=False,
575
  fast_dev_run=fast_dev_run,
576
+ use_ored_activity=use_ored_activity,
577
+ disabled_embeddings=disabled_embeddings,
578
  )
579
 
580
  # Metrics is a dictionary containing at least the validation loss
 
593
  fast_dev_run=False,
594
  n_trials=20,
595
  logger_name='protac_hparam_search',
596
+ use_ored_activity=True,
597
+ disabled_embeddings=[],
598
  ) -> tuple:
599
  """ Hyperparameter tuning and training of a PROTAC model.
600
 
 
614
  max_epochs_options = [10, 20, 50]
615
  smote_k_neighbors_options = list(range(3, 16))
616
 
617
+ # Set the verbosity of Optuna
618
+ optuna.logging.set_verbosity(optuna.logging.WARNING)
619
  # Create an Optuna study object
620
+ sampler = TPESampler(seed=42, multivariate=True)
621
+ study = optuna.create_study(direction='minimize', sampler=sampler)
622
+ study.optimize(
623
+ lambda trial: objective(
624
  trial,
625
  train_df,
626
  val_df,
 
629
  learning_rate_options,
630
  max_epochs_options,
631
  smote_k_neighbors_options=smote_k_neighbors_options,
632
+ fast_dev_run=fast_dev_run,
633
+ use_ored_activity=use_ored_activity,
634
+ disabled_embeddings=disabled_embeddings,
635
+ ),
636
  n_trials=n_trials,
637
  )
638
 
639
  # Retrain the model with the best hyperparameters
640
  model, trainer, metrics = train_model(
641
  train_df,
642
  val_df,
643
  test_df,
644
  use_logger=True,
645
  logger_name=logger_name,
646
  fast_dev_run=fast_dev_run,
647
+ use_ored_activity=use_ored_activity,
648
+ disabled_embeddings=disabled_embeddings,
649
+ **study.best_params,
650
  )
651
 
652
  # Report the best hyperparameters found
653
+ metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
654
 
655
  # Return the best metrics
656
  return model, trainer, metrics
657
 
658
+
659
+ def main(
660
+ use_ored_activity: bool = True,
661
+ n_trials: int = 50,
662
+ n_splits: int = 5,
663
+ fast_dev_run: bool = False,
664
+ ):
665
+ """ Train a PROTAC model using the given datasets and hyperparameters.
666
+
667
+ Args:
668
+ use_ored_activity (bool): Whether to use the 'Active - OR' column.
669
+ n_trials (int): The number of hyperparameter optimization trials.
670
+ n_splits (int): The number of cross-validation splits.
671
+ fast_dev_run (bool): Whether to run a fast development run.
672
+ """
673
+ ## Set the Column to Predict
674
+ active_col = 'Active - OR' if use_ored_activity else 'Active'
675
+ active_name = active_col.replace(' ', '').lower()
676
+ active_name = 'active-and' if active_name == 'active' else active_name
677
+
678
+ ## Test Sets
679
+
680
+ active_df = protac_df[protac_df[active_col].notna()]
681
+ # Before starting any training, we isolate a small group of test data. Each element in the test set is selected so that all the following conditions are met:
682
+ # * its SMILES appears only once in the dataframe
683
+ # * its Uniprot appears only once in the dataframe
684
+ # * its (Smiles, Uniprot) pair appears only once in the dataframe
685
+ unique_smiles = active_df['Smiles'].value_counts() == 1
686
+ unique_uniprot = active_df['Uniprot'].value_counts() == 1
687
+ unique_smiles_uniprot = active_df.groupby(['Smiles', 'Uniprot']).size() == 1
688
+
689
+ # Get the indices of the unique samples
690
+ unique_smiles_idx = active_df['Smiles'].map(unique_smiles)
691
+ unique_uniprot_idx = active_df['Uniprot'].map(unique_uniprot)
692
+ unique_smiles_uniprot_idx = active_df.set_index(['Smiles', 'Uniprot']).index.map(unique_smiles_uniprot)
693
+
694
+ # Cross the indices to get the unique samples
695
+ unique_samples = active_df[unique_smiles_idx & unique_uniprot_idx & unique_smiles_uniprot_idx].index
696
+ test_df = active_df.loc[unique_samples]
697
+ train_val_df = active_df[~active_df.index.isin(unique_samples)]
698
+
699
+ ## Cross-Validation Training
700
+
701
+ # Cross-validation training with n_splits folds (5 by default). The split operation is done in three different ways:
702
+ #
703
+ # * Random split
704
+ # * POI-wise: a given POI never appears in both the train and validation splits
705
+ # * Least Tanimoto similarity PROTAC-wise
706
+
707
+ # NOTE: When n_bins_tanimoto is set to 60, it results in 29 groups, with nice distributions of
708
+ # the number of unique groups in the train and validation sets, together with
709
+ # the number of active and inactive PROTACs.
710
+ n_bins_tanimoto = 60 if active_col == 'Active' else 400
711
+
712
+ # Make directory ../reports if it does not exist
713
+ if not os.path.exists('../reports'):
714
+ os.makedirs('../reports')
715
+
716
+ # Seed everything in pytorch lightning
717
+ pl.seed_everything(42)
718
+
719
+ # Loop over the different splits and train the model:
720
+ report = []
721
+ for group_type in ['random', 'uniprot', 'tanimoto']:
722
+ print('-' * 100)
723
+ print(f'Starting CV for group type: {group_type}')
724
+ print('-' * 100)
725
+ # Setup CV iterator and groups
726
+ if group_type == 'random':
727
+ kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
728
+ groups = None
729
+ elif group_type == 'uniprot':
730
+ # Split by Uniprot
731
+ kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
732
+ encoder = OrdinalEncoder()
733
+ groups = encoder.fit_transform(train_val_df['Uniprot'].values.reshape(-1, 1))
734
+ elif group_type == 'tanimoto':
735
+ # Split by tanimoto similarity, i.e., group_type PROTACs with similar Avg Tanimoto
736
+ kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)
737
+ tanimoto_groups = pd.cut(train_val_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
738
+ encoder = OrdinalEncoder()
739
+ groups = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1))
740
+ # Start the CV over the folds
741
+ X = train_val_df.drop(columns=active_col)
742
+ y = train_val_df[active_col].tolist()
743
+ for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
744
+ print('-' * 100)
745
+ print(f'Starting CV for group type: {group_type}, fold: {k}')
746
+ print('-' * 100)
747
+ train_df = train_val_df.iloc[train_index]
748
+ val_df = train_val_df.iloc[val_index]
749
+ stats = {
750
+ 'fold': k,
751
+ 'group_type': group_type,
752
+ 'train_len': len(train_df),
753
+ 'val_len': len(val_df),
754
+ 'train_perc': len(train_df) / len(train_val_df),
755
+ 'val_perc': len(val_df) / len(train_val_df),
756
+ 'train_active_perc': train_df[active_col].sum() / len(train_df),
757
+ 'train_inactive_perc': (len(train_df) - train_df[active_col].sum()) / len(train_df),
758
+ 'val_active_perc': val_df[active_col].sum() / len(val_df),
759
+ 'val_inactive_perc': (len(val_df) - val_df[active_col].sum()) / len(val_df),
760
+ 'test_active_perc': test_df[active_col].sum() / len(test_df),
761
+ 'test_inactive_perc': (len(test_df) - test_df[active_col].sum()) / len(test_df),
762
+ 'num_leaking_uniprot': len(set(train_df['Uniprot']).intersection(set(val_df['Uniprot']))),
763
+ 'num_leaking_smiles': len(set(train_df['Smiles']).intersection(set(val_df['Smiles']))),
764
+ 'disabled_embeddings': np.nan,
765
+ }
766
+ if group_type != 'random':
767
+ stats['train_unique_groups'] = len(np.unique(groups[train_index]))
768
+ stats['val_unique_groups'] = len(np.unique(groups[val_index]))
769
+ # Train and evaluate the model
770
+ model, trainer, metrics = hyperparameter_tuning_and_training(
771
+ train_df,
772
+ val_df,
773
+ test_df,
774
+ fast_dev_run=fast_dev_run,
775
+ n_trials=n_trials,
776
+ logger_name=f'protac_{active_name}_{group_type}_fold_{k}',
777
+ use_ored_activity=use_ored_activity,
778
+ )
779
+ hparams = {p.replace('hparam_', '', 1): v for p, v in metrics.items() if p.startswith('hparam_')}
780
+ stats.update(metrics)
781
+ report.append(stats.copy())
782
+ del model
783
+ del trainer
784
+
785
+ # Ablation study: disable embeddings at a time
786
+ for disabled_embeddings in [['poi'], ['cell'], ['smiles'], ['e3', 'cell'], ['poi', 'e3', 'cell']]:
787
+ print('-' * 100)
788
+ print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
789
+ print('-' * 100)
790
+ stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
791
+ model, trainer, metrics = train_model(
792
+ train_df,
793
+ val_df,
794
+ test_df,
795
+ fast_dev_run=fast_dev_run,
796
+ logger_name=f'protac_{active_name}_{group_type}_fold_{k}_disabled-{"-".join(disabled_embeddings)}',
797
+ use_ored_activity=use_ored_activity,
798
+ disabled_embeddings=disabled_embeddings,
799
+ **hparams,
800
+ )
801
+ stats.update(metrics)
802
+ report.append(stats.copy())
803
+ del model
804
+ del trainer
805
+
806
+ report = pd.DataFrame(report)
807
+ report.to_csv(
808
+ f'../reports/cv_report_hparam_search_{n_splits}-splits_{active_name}.csv',
809
+ index=False,
810
+ )
811
+
812
+
813
+ if __name__ == '__main__':
814
+ cli = CLI(main)
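# NOTE (usage, not part of this commit): main() can also be called directly
# from Python, e.g.:
#
#   main(use_ored_activity=False, n_trials=20, n_splits=5, fast_dev_run=True)
#
# and, assuming jsonargparse's usual behavior of turning the signature of
# main() into command-line flags, a shell invocation would look roughly like:
#
#   python protac_degradation_predictor.py --n_trials 20 --n_splits 5 --fast_dev_run true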