Commit 165d38a (1 parent: 74be897), committed by ribesstefano

refactored package code and fine running experimental code
protac_degradation_predictor/__init__.py CHANGED
@@ -1,14 +1,18 @@
-# from .protac_degradation_predictor.config import config
-# from .protac_degradation_predictor.pytorch_models import train_model
-# from .protac_degradation_predictor.pytorch_models import
-# from .protac_degradation_predictor.pytorch_models import
-from . import (
-    config,
-    pytorch_models,
-    sklearn_models,
-    protac_dataset,
-    data_utils,
-    optuna_utils,
+from .data_utils import (
+    load_protein2embedding,
+    load_cell2embedding,
+    get_fingerprint,
+    is_active,
+)
+from .pytorch_models import (
+    train_model,
+)
+from .sklearn_models import (
+    train_sklearn_model,
+)
+from .optuna_utils import (
+    hyperparameter_tuning_and_training,
+    hyperparameter_tuning_and_training_sklearn,
 )
 
 __version__ = "0.0.1"
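With this refactor the package exposes its public helpers at the top level. A minimal usage sketch (function names and signatures taken from this commit; the SMILES string and threshold values are illustrative):

    import protac_degradation_predictor as pdp

    # Morgan fingerprint for a SMILES string (example molecule)
    fp = pdp.get_fingerprint('CC(=O)Oc1ccccc1C(=O)O')

    # Binary activity label from DC50 (nM) and Dmax (%), thresholds as in run_experiments.py
    label = pdp.is_active(50.0, 80.0, pDC50_threshold=6.0, Dmax_threshold=0.6)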
protac_degradation_predictor/protac_degradation_predictor.py CHANGED
@@ -14,9 +14,6 @@ import torch
 from torch import sigmoid
 
 
-package_name = 'protac_degradation_predictor'
-
-
 def get_protac_active_proba(
     protac_smiles: str,
     e3_ligase: str,
setup.py CHANGED
@@ -9,7 +9,7 @@ setuptools.setup(
     description="A package to predict PROTAC-induced protein degradation.",
     long_description=open("README.md").read(),
     packages=setuptools.find_packages(),
-    install_requires=["torch", "pytorch_lightning", "scikit-learn", "imblearn", "rdkit-pypi", "pandas", "joblib", "h5py", "optuna", "torchmetrics"],
+    install_requires=["torch", "pytorch_lightning", "scikit-learn", "imbalanced-learn", "rdkit-pypi", "pandas", "joblib", "h5py", "optuna", "torchmetrics"],
     classifiers=[
         "Programming Language :: Python :: 3",
         "Programming Language :: Python :: 3.6",
src/{main.py → run_experiments.py} RENAMED
@@ -1,19 +1,12 @@
 import os
+import sys
 from collections import defaultdict
 import warnings
 
-from protac_degradation_predictor.config import config
-from protac_degradation_predictor.data_utils import (
-    load_protein2embedding,
-    load_cell2embedding,
-    is_active,
-)
-from protac_degradation_predictor.pytorch_models import (
-    train_model,
-)
-from protac_degradation_predictor.optuna_utils import (
-    hyperparameter_tuning_and_training,
-)
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+import protac_degradation_predictor as pdp
 
 from rdkit import Chem
 from rdkit.Chem import AllChem
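The `sys.path.append` above lets the script import the in-repo package when run directly from `src/`, without installing it first; an editable install (`pip install -e .`) would achieve the same effect.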
@@ -28,78 +21,53 @@ from sklearn.model_selection import (
     StratifiedGroupKFold,
 )
 
-
 # Ignore UserWarning from Matplotlib
 warnings.filterwarnings("ignore", ".*FixedLocator*")
 # Ignore UserWarning from PyTorch Lightning
 warnings.filterwarnings("ignore", ".*does not have many workers.*")
 
 
-def main(
-    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
-    n_trials: int = 50,
-    fast_dev_run: bool = False,
-    test_split: float = 0.2,
-    cv_n_splits: int = 5,
-):
-    """ Train a PROTAC model using the given datasets and hyperparameters.
+def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
+    """ Get the indices of the test set using a random split.
 
     Args:
-        use_ored_activity (bool): Whether to use the 'Active - OR' column.
-        n_trials (int): The number of hyperparameter optimization trials.
-        n_splits (int): The number of cross-validation splits.
-        fast_dev_run (bool): Whether to run a fast development run.
+        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
+        test_split (float): The percentage of the active PROTACs to use as the test set.
+
+    Returns:
+        pd.Index: The indices of the test set.
     """
-    ## Set the Column to Predict
-    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
-
-    # Get Dmax_threshold from the active_col
-    Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
-    pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
-
-    ## Load the Data
-    protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
-
-    # Map E3 Ligase Iap to IAP
-    protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
-
-    protac_df[active_col] = protac_df.apply(
-        lambda x: is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
-    )
-
-    ## Test Sets
-
-    test_indeces = {}
-
-    ### Random Split
-
-    # Randomly select 20% of the active PROTACs as the test set
-    active_df = protac_df[protac_df[active_col].notna()].copy()
     test_df = active_df.sample(frac=test_split, random_state=42)
-    test_indeces['random'] = test_df.index
+    return test_df.index
 
-    ### E3-based Split
 
+def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
+    """ Get the indices of the test set using the E3 ligase split.
+
+    Args:
+        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
+
+    Returns:
+        pd.Index: The indices of the test set.
+    """
     encoder = OrdinalEncoder()
-    protac_df['E3 Group'] = encoder.fit_transform(protac_df[['E3 Ligase']]).astype(int)
-    active_df = protac_df[protac_df[active_col].notna()].copy()
+    active_df['E3 Group'] = encoder.fit_transform(active_df[['E3 Ligase']]).astype(int)
     test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
-    test_indeces['e3_ligase'] = test_df.index
+    return test_df.index
 
-    ### Tanimoto-based Split
-
-    #### Precompute fingerprints
-    morgan_fpgen = AllChem.GetMorganGenerator(
-        radius=config.morgan_radius,
-        fpSize=config.fingerprint_size,
-        includeChirality=True,
-    )
 
+def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
+    """ Get the SMILES to fingerprint dictionary and the average Tanimoto similarity.
+
+    Args:
+        protac_df (pd.DataFrame): The DataFrame containing the PROTACs.
+
+    Returns:
+        tuple: The SMILES to fingerprint dictionary and the average Tanimoto similarity.
+    """
     smiles2fp = {}
     for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
-        # Get the fingerprint as a bit vector
-        morgan_fp = morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
-        smiles2fp[smiles] = morgan_fp
+        smiles2fp[smiles] = pdp.get_fingerprint(smiles)
 
     # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
     tanimoto_matrix = defaultdict(list)
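The loop body that fills `tanimoto_matrix` sits between hunks and is not shown here. As a hedged sketch (not the repo's exact code), the per-compound average pairwise Tanimoto similarity can be computed with RDKit's bulk helper, using the same generator call as the code removed above (radius/fpSize values are illustrative):

    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem

    gen = AllChem.GetMorganGenerator(radius=2, fpSize=2048, includeChirality=True)
    smiles_list = ['CCO', 'CCN', 'CCCC']
    fps = [gen.GetFingerprint(Chem.MolFromSmiles(s)) for s in smiles_list]
    for i, fp in enumerate(fps):
        others = fps[:i] + fps[i + 1:]
        sims = DataStructs.BulkTanimotoSimilarity(fp, others)
        print(smiles_list[i], sum(sims) / len(sims))  # 'Avg Tanimoto' per compound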
@@ -117,12 +85,27 @@ def main(
 
     smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
 
-    # Make the grouping of the PROTACs based on the Tanimoto similarity
-    n_bins_tanimoto = 200
-    tanimoto_groups = pd.cut(protac_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
+    return smiles2fp, protac_df
+
+
+def get_tanimoto_split_indices(
+        active_df: pd.DataFrame,
+        active_col: str,
+        test_split: float,
+        n_bins_tanimoto: int = 200,
+) -> pd.Index:
+    """ Get the indices of the test set using the Tanimoto-based split.
+
+    Args:
+        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
+        n_bins_tanimoto (int): The number of bins to use for the Tanimoto similarity.
+
+    Returns:
+        pd.Index: The indices of the test set.
+    """
+    tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
     encoder = OrdinalEncoder()
-    protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
-    active_df = protac_df[protac_df[active_col].notna()].copy()
+    active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
     # Sort the groups so that samples with the highest tanimoto similarity,
     # i.e., the "less similar" ones, are placed in the test set first
     tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
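A small illustration of the binning step in `get_tanimoto_split_indices`: `pd.cut` buckets the continuous similarity scores, and the bucket becomes the group label (`cat.codes` is shown as a roughly equivalent shortcut for the OrdinalEncoder step; scores are toy values):

    import pandas as pd

    scores = pd.Series([0.12, 0.35, 0.36, 0.80], name='Avg Tanimoto')
    groups = pd.cut(scores, bins=3).cat.codes  # integer group id per sample
    print(groups.tolist())  # [0, 1, 1, 2]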
@@ -159,14 +142,22 @@ def main(
             if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
                 test_df.append(group_df)
     test_df = pd.concat(test_df)
-    # Save to global dictionary of test indeces
-    test_indeces['tanimoto'] = test_df.index
+    return test_df.index
+
 
-    ### Target-based Split
+def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_split: float) -> pd.Index:
+    """ Get the indices of the test set using the target-based split.
 
+    Args:
+        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
+        active_col (str): The column containing the active/inactive information.
+        test_split (float): The percentage of the active PROTACs to use as the test set.
+
+    Returns:
+        pd.Index: The indices of the test set.
+    """
     encoder = OrdinalEncoder()
-    protac_df['Uniprot Group'] = encoder.fit_transform(protac_df[['Uniprot']]).astype(int)
-    active_df = protac_df[protac_df[active_col].notna()].copy()
+    active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)
 
     test_df = []
     # For each group, get the number of active and inactive entries. Then, add those
@@ -201,25 +192,64 @@ def main(
             if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
                 test_df.append(group_df)
     test_df = pd.concat(test_df)
-    # Save to global dictionary of test indeces
-    test_indeces['uniprot'] = test_df.index
+    return test_df.index
+
+
+def main(
+    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
+    n_trials: int = 50,
+    fast_dev_run: bool = False,
+    test_split: float = 0.2,
+    cv_n_splits: int = 5,
+):
+    """ Train a PROTAC model using the given datasets and hyperparameters.
+
+    Args:
+        use_ored_activity (bool): Whether to use the 'Active - OR' column.
+        n_trials (int): The number of hyperparameter optimization trials.
+        n_splits (int): The number of cross-validation splits.
+        fast_dev_run (bool): Whether to run a fast development run.
+    """
+    # Set the Column to Predict
+    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
 
-    ## Cross-Validation Training
+    # Get Dmax_threshold from the active_col
+    Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
+    pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
+
+    # Load the PROTAC dataset
+    protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
+    # Map E3 Ligase Iap to IAP
+    protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
+    protac_df[active_col] = protac_df.apply(
+        lambda x: pdp.is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
+    )
+    smiles2fp, protac_df = get_smiles2fp_and_avg_tanimoto(protac_df)
+
+    ## Get the test sets
+    test_indeces = {}
+    active_df = protac_df[protac_df[active_col].notna()].copy()
+    test_indeces['random'] = get_random_split_indices(active_df, test_split)
+    test_indeces['e3_ligase'] = get_e3_ligase_split_indices(active_df)
+    test_indeces['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)
+    test_indeces['uniprot'] = get_target_split_indices(active_df, active_col, test_split)
 
     # Make directory ../reports if it does not exist
     if not os.path.exists('../reports'):
         os.makedirs('../reports')
 
     # Load embedding dictionaries
-    protein2embedding = load_protein2embedding('../data/uniprot2embedding.h5')
-    cell2embedding = load_cell2embedding('../data/cell2embedding.pkl')
+    protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
+    cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
 
+    # Cross-Validation Training
     report = []
     for split_type, indeces in test_indeces.items():
-        active_df = protac_df[protac_df[active_col].notna()].copy()
+        # active_df = protac_df[protac_df[active_col].notna()].copy()
         test_df = active_df.loc[indeces]
        train_val_df = active_df[~active_df.index.isin(test_df.index)]
 
+        # Get the CV object
         if split_type == 'random':
             kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
             group = None
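A worked example of the threshold parsing at the top of `main()`, using the default column name:

    active_col = 'Active (Dmax 0.6, pDC50 6.0)'
    Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
    pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
    print(Dmax_threshold, pDC50_threshold)  # 0.6 6.0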
@@ -232,6 +262,7 @@ def main(
         elif split_type == 'uniprot':
             kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
             group = train_val_df['Uniprot Group'].to_numpy()
+
         # Start the CV over the folds
         X = train_val_df.drop(columns=active_col)
         y = train_val_df[active_col].tolist()
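A hedged sketch (toy data) of why the grouped splits use `StratifiedGroupKFold`: every sample sharing a group id (e.g., the same Uniprot target) lands on the same side of each fold, while class ratios stay approximately stratified:

    import numpy as np
    from sklearn.model_selection import StratifiedGroupKFold

    X = np.zeros((8, 1))
    y = np.array([0, 0, 1, 1, 0, 1, 0, 1])
    groups = np.array([0, 0, 1, 1, 2, 2, 3, 3])
    kf = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=42)
    for train_idx, val_idx in kf.split(X, y, groups):
        print(sorted(set(groups[train_idx])), sorted(set(groups[val_idx])))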
@@ -269,7 +300,7 @@ def main(
 
         print(stats)
         # # Train and evaluate the model
-        # model, trainer, metrics = hyperparameter_tuning_and_training(
+        # model, trainer, metrics = pdp.hyperparameter_tuning_and_training(
         #     protein2embedding,
         #     cell2embedding,
         #     smiles2fp,
@@ -294,7 +325,7 @@ def main(
         # print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
         # print('-' * 100)
         # stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
-        # model, trainer, metrics = train_model(
+        # model, trainer, metrics = pdp.train_model(
         #     protein2embedding,
         #     cell2embedding,
         #     smiles2fp,