ribesstefano committed on
Commit d36ec1d
1 Parent(s): 87b14e7

Polished scripted version of Hparam-CV training

notebooks/protac_degradation_predictor.py CHANGED
@@ -1,47 +1,75 @@
- # %% [markdown]
- # # PROTAC-Degradation-Predictor
+ import optuna
+ import pandas as pd
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+
+ import h5py
+ import numpy as np
+ from tqdm.auto import tqdm
+
+ import os
+ import urllib.request
+
+ from sklearn.preprocessing import StandardScaler
+
+ # ## Define Torch Dataset

- # %%
+ from imblearn.over_sampling import SMOTE, ADASYN
+ from sklearn.preprocessing import LabelEncoder
  import pandas as pd
+ import numpy as np
+
+ from torch.utils.data import Dataset, DataLoader
+
+ import warnings
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ import pytorch_lightning as pl
+ from torchmetrics import (
+     Accuracy,
+     AUROC,
+     Precision,
+     Recall,
+     F1Score,
+ )
+ from torchmetrics import MetricCollection
+
+ import pickle
+
+ from sklearn.model_selection import (
+     StratifiedKFold,
+     StratifiedGroupKFold,
+ )
+ from sklearn.preprocessing import OrdinalEncoder
+

  protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
  protac_df.head()

- # %%
  # Get the unique Article IDs of the entries with NaN values in the Active column
  nan_active = protac_df[protac_df['Active'].isna()]['Article DOI'].unique()
  nan_active

- # %%
  # Map E3 Ligase Iap to IAP
  protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')

- # %%
- protac_df.columns
-
- # %%
  cells = sorted(protac_df['Cell Type'].dropna().unique().tolist())
  print(f'Number of non-cleaned cell lines: {len(cells)}')

- # %%
  cells = sorted(protac_df['Cell Line Identifier'].dropna().unique().tolist())
  print(f'Number of cleaned cell lines: {len(cells)}')

- # %%
  unlabeled_df = protac_df[protac_df['Active'].isna()]
  print(f'Number of compounds in test set: {len(unlabeled_df)}')

- # %% [markdown]
  # ## Load Protein Embeddings

- # %% [markdown]
  # Protein embeddings downloaded from [Uniprot](https://www.uniprot.org/help/embeddings).
  #
  # Please note that running the following cell the first time might take a while.

- # %%
- import os
- import urllib.request

  download_link = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5"
  embeddings_path = "../data/uniprot2embedding.h5"
@@ -50,11 +78,6 @@ if not os.path.exists(embeddings_path):
      print(f'Downloading embeddings from {download_link}')
      urllib.request.urlretrieve(download_link, embeddings_path)

- # %%
- import h5py
- import numpy as np
- from tqdm.auto import tqdm
-
  protein_embeddings = {}
  with h5py.File("../data/uniprot2embedding.h5", "r") as file:
      print(f"number of entries: {len(file.items()):,}")
@@ -74,37 +97,27 @@ with h5py.File("../data/uniprot2embedding.h5", "r") as file:
          print(f'KeyError for {sequence_id}')
          protein_embeddings[sequence_id] = np.zeros((1024,))

- # %% [markdown]
  # ## Load Cell Embeddings

- # %%
- import pickle

  cell2embedding_filepath = '../data/cell2embedding.pkl'
  with open(cell2embedding_filepath, 'rb') as f:
      cell2embedding = pickle.load(f)
  print(f'Loaded {len(cell2embedding)} cell lines')

- # %%
  emb_shape = cell2embedding[list(cell2embedding.keys())[0]].shape
  # Assign all-zero vectors to cell lines that are not in the embedding file
  for cell_line in protac_df['Cell Line Identifier'].unique():
      if cell_line not in cell2embedding:
          cell2embedding[cell_line] = np.zeros(emb_shape)

- # %% [markdown]
  # ## Precompute Molecular Fingerprints
-
- # %%
- from rdkit import Chem
- from rdkit.Chem import AllChem
- from rdkit.Chem import Draw
-
+
  morgan_radius = 15
  n_bits = 1024

  # fpgen = AllChem.GetAtomPairGenerator()
- rdkit_fpgen = AllChem.GetRDKitFPGenerator(maxPath=5, fpSize=512)
+ # rdkit_fpgen = AllChem.GetRDKitFPGenerator(maxPath=5, fpSize=512)
  morgan_fpgen = AllChem.GetMorganGenerator(radius=morgan_radius, fpSize=n_bits)

  smiles2fp = {}
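The loop that fills smiles2fp is elided in this diff; with the generator defined above, a single fingerprint computation presumably looks like this sketch (toy SMILES, not a PROTAC from the dataset):

mol = Chem.MolFromSmiles('CCO')  # toy molecule for illustration
fp = morgan_fpgen.GetFingerprint(mol)  # ExplicitBitVect of length n_bits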
@@ -129,7 +142,6 @@ for smiles, fp in smiles2fp.items():
  print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')
  print(f'Number of overlapping SMILES in protac_df: {len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])}')

- # %%
  # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
  from rdkit import DataStructs
  from collections import defaultdict
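The pairwise loop itself is elided here; based on the variable names above and the loop header shown in the next hunk, it presumably accumulates per-SMILES similarity lists along these lines (a sketch of an assumed shape, not the committed code):

tanimoto_matrix = defaultdict(list)
unique_smiles = protac_df['Smiles'].unique()
for i, smiles1 in enumerate(tqdm(unique_smiles)):
    fp1 = smiles2fp[smiles1]
    for smiles2 in unique_smiles[i + 1:]:
        sim = DataStructs.TanimotoSimilarity(fp1, smiles2fp[smiles2])
        tanimoto_matrix[smiles1].append(sim)  # similarity is symmetric:
        tanimoto_matrix[smiles2].append(sim)  # store it for both molecules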
@@ -147,44 +159,14 @@ for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing T
  avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
  protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)

- # %%
- # # Plot the distribution of the average Tanimoto similarity
- # import seaborn as sns
- # import matplotlib.pyplot as plt
-
- # sns.histplot(protac_df['Avg Tanimoto'], bins=50)
- # plt.xlabel('Average Tanimoto similarity')
- # plt.ylabel('Count')
- # plt.title('Distribution of average Tanimoto similarity')
- # plt.grid(axis='y', alpha=0.5)
- # plt.show()
-
- # %%
  smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}

- # %% [markdown]
  # ## Set the Column to Predict

- # %%
  # active_col = 'Active'
  active_col = 'Active - OR'


- from sklearn.preprocessing import StandardScaler
-
- # %% [markdown]
- # ## Define Torch Dataset
-
- # %%
- from imblearn.over_sampling import SMOTE, ADASYN
- from sklearn.preprocessing import LabelEncoder
- import pandas as pd
- import numpy as np
-
- # %%
- from torch.utils.data import Dataset, DataLoader
-
-
  class PROTAC_Dataset(Dataset):
      def __init__(
          self,
@@ -295,22 +277,6 @@ class PROTAC_Dataset(Dataset):
          }
          return elem

- # %%
- import warnings
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- import pytorch_lightning as pl
- from torchmetrics import (
-     Accuracy,
-     AUROC,
-     Precision,
-     Recall,
-     F1Score,
- )
- from torchmetrics import MetricCollection
-
  # Ignore UserWarning from PyTorch Lightning
  warnings.filterwarnings("ignore", ".*does not have many workers.*")

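The body of the dataset class is elided in this diff; purely as a usage sketch, with constructor arguments that are assumptions rather than the committed signature:

# Hypothetical wiring of PROTAC_Dataset into a DataLoader (argument list assumed)
train_ds = PROTAC_Dataset(train_df, protein_embeddings, cell2embedding, smiles2fp)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)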
@@ -505,22 +471,17 @@ class PROTAC_Model(pl.LightningModule):
              shuffle=False,
          )

- # %% [markdown]
  # ## Test Sets

- # %% [markdown]
  # We want a different test set per Cross-Validation (CV) experiment (see further down). We are interested in three scenarios:
  # * Randomly splitting the data into training and test sets. Hence, the test set shall contain unique SMILES and Uniprots
  # * Splitting the data according to their Uniprot. Hence, the test set shall contain unique Uniprots
  # * Splitting the data according to their SMILES, _i.e._, the test set shall contain unique SMILES

- # %%
  test_indeces = {}

- # %% [markdown]
  # Isolating the unique SMILES and Uniprots:

- # %%
  active_df = protac_df[protac_df[active_col].notna()].copy()

  # Get the unique SMILES and Uniprot
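The selection itself is elided; mirroring the value_counts pattern used for the Uniprot split below, the 'random' test set can be sketched as the rows whose SMILES and Uniprot each occur exactly once, ending at the assignment shown in the next hunk's context (an assumed reconstruction, not the committed code):

# Keep rows whose SMILES and Uniprot each appear exactly once in active_df
unique_smiles = active_df['Smiles'].value_counts() == 1
unique_uniprot = active_df['Uniprot'].value_counts() == 1
unique_indices = active_df[
    active_df['Smiles'].isin(unique_smiles[unique_smiles].index)
    & active_df['Uniprot'].isin(unique_uniprot[unique_uniprot].index)
].index
test_indeces['random'] = unique_indices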
@@ -552,10 +513,8 @@ test_indeces['random'] = unique_indices
  # plt.title('Test set Active - OR distribution')
  # plt.show()

- # %% [markdown]
  # Isolating the unique Uniprots:

- # %%
  active_df = protac_df[protac_df[active_col].notna()].copy()

  unique_uniprot = active_df['Uniprot'].value_counts() == 1
@@ -569,13 +528,11 @@ unique_indices = active_df[active_df['Uniprot'].isin(unique_uniprot.index)].inde
  test_indeces['uniprot'] = unique_indices
  print(f'Number of unique indices: {len(unique_indices)} ({len(unique_indices) / len(active_df):.1%})')

- # %% [markdown]
  # DEPRECATED: The following results in too small a test set. Before starting any training, we isolate a small group of test data. Each element in the test set is selected so that all the following conditions are met:
  # * its SMILES is unique
  # * its POI is unique
  # * its (SMILES, POI) pair is unique

- # %%
  active_df = protac_df[protac_df[active_col].notna()]

  # Find the samples that:
@@ -598,25 +555,16 @@ test_df = active_df.loc[unique_samples]

  warnings.filterwarnings("ignore", ".*FixedLocator*")

- # %% [markdown]
  # ## Cross-Validation Training

- # %% [markdown]
  # Cross-validation training with 5 splits. The split operation is done in three different ways:
  #
  # * Random split
  # * POI-wise: no POI ever appears in both splits
  # * Least Tanimoto similarity PROTAC-wise

- # %% [markdown]
  # ### Plotting CV Folds

- # %%
- from sklearn.model_selection import (
-     StratifiedKFold,
-     StratifiedGroupKFold,
- )
- from sklearn.preprocessing import OrdinalEncoder

  # NOTE: When set to 60, it will result in 29 groups, with nice distributions of
  # the number of unique groups in the train and validation sets, together with
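The fold construction is elided in this diff; given the imports and the NOTE above, a grouped, stratified split driven by binned average Tanimoto similarity can be sketched as follows (the bin count follows the NOTE, but the exact wiring and column usage are assumptions):

# Bin the average Tanimoto similarity into groups; per the NOTE, 60 bins
# yields 29 non-empty groups on this data
tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=60).cat.codes
kf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, val_idx in kf.split(active_df, active_df[active_col], groups=tanimoto_groups):
    # each group of similar PROTACs lands entirely in either train or validation
    pass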
@@ -688,10 +636,8 @@ for group_type in groups:
      stats.append(stat)
      print('-' * 120)

- # %% [markdown]
  # ### Run CV

- # %%
  import warnings

  # Seed everything in pytorch lightning
@@ -805,14 +751,8 @@ def train_model(
      metrics.update(test_metrics)
      return model, trainer, metrics

- # %% [markdown]
  # Setup hyperparameter optimization:

- # %%
- import optuna
- import pandas as pd
-
-
  def objective(
      trial,
      train_df,
@@ -926,10 +866,8 @@ def hyperparameter_tuning_and_training(
  # train_df, val_df, test_df = load_your_data() # You need to load your datasets here
  # model, trainer, best_metrics = hyperparameter_tuning_and_training(train_df, val_df, test_df)

- # %% [markdown]
  # Loop over the different splits and train the model:

- # %%
  n_splits = 5
  report = []
  active_df = protac_df[protac_df[active_col].notna()]
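For orientation, the Optuna pieces above are typically wired together along these lines (the number of trials and the arguments forwarded to objective are assumptions, not read from this commit):

# Sketch: maximize the validation metric returned by objective over a study
study = optuna.create_study(direction='maximize')
study.optimize(lambda trial: objective(trial, train_df, val_df), n_trials=50)
print(study.best_params)  # best hyperparameters found across trials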
@@ -998,4 +936,3 @@ report = pd.DataFrame(report)
  report.to_csv(
      f'../reports/cv_report_hparam_search_{n_splits}-splits.csv', index=False,
  )
-
 