ribesstefano committed on
Commit
bda3015
·
1 Parent(s): fda7af7

Added softmax in model + Fixed and updated some hparams + Fixed bug in tanimoto distance (it was treated as similarity)

Browse files
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -55,14 +55,17 @@ def get_dataframe_stats(
55
  stats['train_len'] = len(train_df)
56
  stats['train_active_perc'] = train_df[active_label].sum() / len(train_df)
57
  stats['train_inactive_perc'] = (len(train_df) - train_df[active_label].sum()) / len(train_df)
 
58
  if val_df is not None:
59
  stats['val_len'] = len(val_df)
60
  stats['val_active_perc'] = val_df[active_label].sum() / len(val_df)
61
  stats['val_inactive_perc'] = (len(val_df) - val_df[active_label].sum()) / len(val_df)
 
62
  if test_df is not None:
63
  stats['test_len'] = len(test_df)
64
  stats['test_active_perc'] = test_df[active_label].sum() / len(test_df)
65
  stats['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
 
66
  if train_df is not None and val_df is not None:
67
  leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
68
  leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
@@ -98,6 +101,10 @@ def pytorch_model_objective(
98
  active_label: str = 'Active',
99
  disabled_embeddings: List[str] = [],
100
  max_epochs: int = 100,
 
 
 
 
101
  ) -> float:
102
  """ Objective function for hyperparameter optimization.
103
 
@@ -116,11 +123,11 @@ def pytorch_model_objective(
116
  """
117
  # Suggest hyperparameters to be used across the CV folds
118
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
119
- batch_size = trial.suggest_categorical('batch_size', batch_size_options)
120
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
121
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
122
  use_smote = trial.suggest_categorical('use_smote', [True, False])
123
- apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
124
  dropout = trial.suggest_float('dropout', *dropout_options)
125
 
126
  # Start the CV over the folds
@@ -166,11 +173,14 @@ def pytorch_model_objective(
166
  smote_k_neighbors=smote_k_neighbors,
167
  apply_scaling=apply_scaling,
168
  use_smote=use_smote,
169
- use_logger=False,
170
  fast_dev_run=fast_dev_run,
171
  active_label=active_label,
172
  return_predictions=True,
173
  disabled_embeddings=disabled_embeddings,
 
 
 
 
174
  )
175
  if test_df is not None:
176
  _, _, metrics, val_pred, test_pred = ret
@@ -246,11 +256,13 @@ def hyperparameter_tuning_and_training(
246
  pl.seed_everything(42)
247
 
248
  # Define the search space
249
- hidden_dim_options = [32, 64, 128, 256, 512, 768]
250
- batch_size_options = [4, 8, 16, 32, 64, 128]
251
- learning_rate_options = (1e-5, 1e-3) # min and max values for loguniform distribution
252
  smote_k_neighbors_options = list(range(3, 16))
253
- dropout_options = (0.2, 0.9)
 
 
254
 
255
  # Set the verbosity of Optuna
256
  optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -293,6 +305,31 @@ def hyperparameter_tuning_and_training(
293
  cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
294
  hparam_report = pd.DataFrame([study.best_params])
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  # Retrain N models with the best hyperparameters (measure model uncertainty)
297
  test_report = []
298
  test_preds = []
@@ -315,6 +352,8 @@ def hyperparameter_tuning_and_training(
315
  enable_checkpointing=True,
316
  checkpoint_model_name=f'best_model_n{i}_{split_type}',
317
  return_predictions=True,
 
 
318
  **study.best_params,
319
  )
320
  # Rename the keys in the metrics dictionary
@@ -371,6 +410,8 @@ def hyperparameter_tuning_and_training(
371
  logger_save_dir=logger_save_dir,
372
  logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
373
  disabled_embeddings=disabled_embeddings,
 
 
374
  **study.best_params,
375
  )
376
  # Rename the keys in the metrics dictionary
 
55
  stats['train_len'] = len(train_df)
56
  stats['train_active_perc'] = train_df[active_label].sum() / len(train_df)
57
  stats['train_inactive_perc'] = (len(train_df) - train_df[active_label].sum()) / len(train_df)
58
+ stats['train_avg_tanimoto_dist'] = train_df['Avg Tanimoto'].mean()
59
  if val_df is not None:
60
  stats['val_len'] = len(val_df)
61
  stats['val_active_perc'] = val_df[active_label].sum() / len(val_df)
62
  stats['val_inactive_perc'] = (len(val_df) - val_df[active_label].sum()) / len(val_df)
63
+ stats['val_avg_tanimoto_dist'] = val_df['Avg Tanimoto'].mean()
64
  if test_df is not None:
65
  stats['test_len'] = len(test_df)
66
  stats['test_active_perc'] = test_df[active_label].sum() / len(test_df)
67
  stats['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
68
+ stats['test_avg_tanimoto_dist'] = test_df['Avg Tanimoto'].mean()
69
  if train_df is not None and val_df is not None:
70
  leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
71
  leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
 
101
  active_label: str = 'Active',
102
  disabled_embeddings: List[str] = [],
103
  max_epochs: int = 100,
104
+ use_logger: bool = False,
105
+ logger_save_dir: str = 'logs',
106
+ logger_name: str = 'cv_model',
107
+ enable_checkpointing: bool = False,
108
  ) -> float:
109
  """ Objective function for hyperparameter optimization.
110
 
 
123
  """
124
  # Suggest hyperparameters to be used across the CV folds
125
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
126
+ batch_size = 128 # trial.suggest_categorical('batch_size', batch_size_options)
127
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
128
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
129
  use_smote = trial.suggest_categorical('use_smote', [True, False])
130
+ apply_scaling = True # trial.suggest_categorical('apply_scaling', [True, False])
131
  dropout = trial.suggest_float('dropout', *dropout_options)
132
 
133
  # Start the CV over the folds
 
173
  smote_k_neighbors=smote_k_neighbors,
174
  apply_scaling=apply_scaling,
175
  use_smote=use_smote,
 
176
  fast_dev_run=fast_dev_run,
177
  active_label=active_label,
178
  return_predictions=True,
179
  disabled_embeddings=disabled_embeddings,
180
+ use_logger=use_logger,
181
+ logger_save_dir=logger_save_dir,
182
+ logger_name=f'{logger_name}_fold{k}',
183
+ enable_checkpointing=enable_checkpointing,
184
  )
185
  if test_df is not None:
186
  _, _, metrics, val_pred, test_pred = ret
 
256
  pl.seed_everything(42)
257
 
258
  # Define the search space
259
+ hidden_dim_options = [32, 64, 128, 256, 512]
260
+ batch_size_options = [128, 128] # [4, 8, 16, 32, 64, 128]
261
+ learning_rate_options = (1e-6, 1e-3) # min and max values for loguniform distribution
262
  smote_k_neighbors_options = list(range(3, 16))
263
+ # NOTE: We want Optuna to explore the combination (very low dropout, very
264
+ # small hidden_dim)
265
+ dropout_options = (0, 0.5)
266
 
267
  # Set the verbosity of Optuna
268
  optuna.logging.set_verbosity(optuna.logging.WARNING)
 
305
  cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
306
  hparam_report = pd.DataFrame([study.best_params])
307
 
308
+ # Train the best CV models and store their checkpoints by running the objective
309
+ pytorch_model_objective(
310
+ trial=study.best_trial,
311
+ protein2embedding=protein2embedding,
312
+ cell2embedding=cell2embedding,
313
+ smiles2fp=smiles2fp,
314
+ train_val_df=train_val_df,
315
+ kf=kf,
316
+ groups=groups,
317
+ test_df=test_df,
318
+ hidden_dim_options=hidden_dim_options,
319
+ batch_size_options=batch_size_options,
320
+ learning_rate_options=learning_rate_options,
321
+ smote_k_neighbors_options=smote_k_neighbors_options,
322
+ dropout_options=dropout_options,
323
+ fast_dev_run=fast_dev_run,
324
+ active_label=active_label,
325
+ max_epochs=max_epochs,
326
+ disabled_embeddings=[],
327
+ use_logger=True,
328
+ logger_save_dir=logger_save_dir,
329
+ logger_name=f'{logger_name}_{split_type}_cv_model',
330
+ enable_checkpointing=True,
331
+ )
332
+
333
  # Retrain N models with the best hyperparameters (measure model uncertainty)
334
  test_report = []
335
  test_preds = []
 
352
  enable_checkpointing=True,
353
  checkpoint_model_name=f'best_model_n{i}_{split_type}',
354
  return_predictions=True,
355
+ batch_size=128,
356
+ apply_scaling=True,
357
  **study.best_params,
358
  )
359
  # Rename the keys in the metrics dictionary
 
410
  logger_save_dir=logger_save_dir,
411
  logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
412
  disabled_embeddings=disabled_embeddings,
413
+ batch_size=128,
414
+ apply_scaling=True,
415
  **study.best_params,
416
  )
417
  # Rename the keys in the metrics dictionary
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -63,15 +63,29 @@ class PROTAC_Predictor(nn.Module):
63
  self.__dict__.update(locals())
64
 
65
  # Define "surrogate models" branches
 
 
66
  if self.join_embeddings != 'beginning':
67
  if 'poi' not in self.disabled_embeddings:
68
- self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
 
 
 
69
  if 'e3' not in self.disabled_embeddings:
70
- self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
 
 
 
71
  if 'cell' not in self.disabled_embeddings:
72
- self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
 
 
 
73
  if 'smiles' not in self.disabled_embeddings:
74
- self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
 
 
 
75
 
76
  # Define hidden dimension for joining layer
77
  if self.join_embeddings == 'beginning':
@@ -95,6 +109,7 @@ class PROTAC_Predictor(nn.Module):
95
  def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
96
  embeddings = []
97
  if self.join_embeddings == 'beginning':
 
98
  if 'poi' not in self.disabled_embeddings:
99
  embeddings.append(poi_emb)
100
  if 'e3' not in self.disabled_embeddings:
@@ -123,7 +138,6 @@ class PROTAC_Predictor(nn.Module):
123
  else:
124
  x = embeddings[0]
125
  x = self.dropout(F.relu(self.fc1(x)))
126
- x = self.dropout(F.relu(self.fc2(x)))
127
  x = self.fc3(x)
128
  return x
129
 
@@ -137,7 +151,7 @@ class PROTAC_Model(pl.LightningModule):
137
  poi_emb_dim: int = config.protein_embedding_size,
138
  e3_emb_dim: int = config.protein_embedding_size,
139
  cell_emb_dim: int = config.cell_embedding_size,
140
- batch_size: int = 32,
141
  learning_rate: float = 1e-3,
142
  dropout: float = 0.2,
143
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
@@ -145,7 +159,7 @@ class PROTAC_Model(pl.LightningModule):
145
  val_dataset: PROTAC_Dataset = None,
146
  test_dataset: PROTAC_Dataset = None,
147
  disabled_embeddings: list = [],
148
- apply_scaling: bool = False,
149
  ):
150
  """ Initialize the PROTAC Pytorch Lightning model.
151
 
@@ -388,7 +402,7 @@ def train_model(
388
  val_df: pd.DataFrame,
389
  test_df: Optional[pd.DataFrame] = None,
390
  hidden_dim: int = 768,
391
- batch_size: int = 8,
392
  learning_rate: float = 2e-5,
393
  dropout: float = 0.2,
394
  max_epochs: int = 50,
@@ -399,7 +413,7 @@ def train_model(
399
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
400
  smote_k_neighbors:int = 5,
401
  use_smote: bool = True,
402
- apply_scaling: bool = False,
403
  active_label: str = 'Active',
404
  fast_dev_run: bool = False,
405
  use_logger: bool = True,
@@ -508,7 +522,7 @@ def train_model(
508
  logger=loggers if use_logger else False,
509
  callbacks=callbacks,
510
  max_epochs=max_epochs,
511
- val_check_interval=0.5,
512
  fast_dev_run=fast_dev_run,
513
  enable_model_summary=False,
514
  enable_checkpointing=enable_checkpointing,
 
63
  self.__dict__.update(locals())
64
 
65
  # Define "surrogate models" branches
66
+ # NOTE: The softmax is used to ensure that the embeddings are normalized
67
+ # and can be summed on a "similar scale".
68
  if self.join_embeddings != 'beginning':
69
  if 'poi' not in self.disabled_embeddings:
70
+ self.poi_emb = nn.Sequential(
71
+ nn.Linear(poi_emb_dim, hidden_dim),
72
+ nn.Softmax(dim=1),
73
+ )
74
  if 'e3' not in self.disabled_embeddings:
75
+ self.e3_emb = nn.Sequential(
76
+ nn.Linear(e3_emb_dim, hidden_dim),
77
+ nn.Softmax(dim=1),
78
+ )
79
  if 'cell' not in self.disabled_embeddings:
80
+ self.cell_emb = nn.Sequential(
81
+ nn.Linear(cell_emb_dim, hidden_dim),
82
+ nn.Softmax(dim=1),
83
+ )
84
  if 'smiles' not in self.disabled_embeddings:
85
+ self.smiles_emb = nn.Sequential(
86
+ nn.Linear(smiles_emb_dim, hidden_dim),
87
+ nn.Softmax(dim=1),
88
+ )
89
 
90
  # Define hidden dimension for joining layer
91
  if self.join_embeddings == 'beginning':
 
109
  def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
110
  embeddings = []
111
  if self.join_embeddings == 'beginning':
112
+ # TODO: Remove this if-branch
113
  if 'poi' not in self.disabled_embeddings:
114
  embeddings.append(poi_emb)
115
  if 'e3' not in self.disabled_embeddings:
 
138
  else:
139
  x = embeddings[0]
140
  x = self.dropout(F.relu(self.fc1(x)))
 
141
  x = self.fc3(x)
142
  return x
143
 
 
151
  poi_emb_dim: int = config.protein_embedding_size,
152
  e3_emb_dim: int = config.protein_embedding_size,
153
  cell_emb_dim: int = config.cell_embedding_size,
154
+ batch_size: int = 128,
155
  learning_rate: float = 1e-3,
156
  dropout: float = 0.2,
157
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
 
159
  val_dataset: PROTAC_Dataset = None,
160
  test_dataset: PROTAC_Dataset = None,
161
  disabled_embeddings: list = [],
162
+ apply_scaling: bool = True,
163
  ):
164
  """ Initialize the PROTAC Pytorch Lightning model.
165
 
 
402
  val_df: pd.DataFrame,
403
  test_df: Optional[pd.DataFrame] = None,
404
  hidden_dim: int = 768,
405
+ batch_size: int = 128,
406
  learning_rate: float = 2e-5,
407
  dropout: float = 0.2,
408
  max_epochs: int = 50,
 
413
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
414
  smote_k_neighbors:int = 5,
415
  use_smote: bool = True,
416
+ apply_scaling: bool = True,
417
  active_label: str = 'Active',
418
  fast_dev_run: bool = False,
419
  use_logger: bool = True,
 
522
  logger=loggers if use_logger else False,
523
  callbacks=callbacks,
524
  max_epochs=max_epochs,
525
+ # val_check_interval=0.5,
526
  fast_dev_run=fast_dev_run,
527
  enable_model_summary=False,
528
  enable_checkpointing=enable_checkpointing,
src/plot_experiment_results.py CHANGED
@@ -12,7 +12,9 @@ import numpy as np
12
  palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
13
 
14
 
15
- def plot_training_curves(df, split_type):
 
 
16
  # Clean the data
17
  df = df.dropna(how='all', axis=1)
18
 
@@ -26,14 +28,14 @@ def plot_training_curves(df, split_type):
26
 
27
  # Plot training loss
28
  ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
29
- ax[0].plot(epoch_data.index, epoch_data['test_loss'], label='Test Loss', linestyle='--')
30
  ax[0].set_ylabel('Loss')
31
  ax[0].legend(loc='lower right')
32
  ax[0].grid(axis='both', alpha=0.5)
33
 
34
  # Plot training accuracy
35
  ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
36
- ax[1].plot(epoch_data.index, epoch_data['test_acc'], label='Test Accuracy', linestyle='--')
37
  ax[1].set_ylabel('Accuracy')
38
  ax[1].legend(loc='lower right')
39
  ax[1].grid(axis='both', alpha=0.5)
@@ -44,7 +46,7 @@ def plot_training_curves(df, split_type):
44
 
45
  # Plot training ROC-AUC
46
  ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
47
- ax[2].plot(epoch_data.index, epoch_data['test_roc_auc'], label='Test ROC-AUC', linestyle='--')
48
  ax[2].set_ylabel('ROC-AUC')
49
  ax[2].legend(loc='lower right')
50
  ax[2].grid(axis='both', alpha=0.5)
@@ -270,10 +272,18 @@ def plot_ablation_study(report):
270
  plt.savefig(f'plots/ablation_study_{group}.pdf', bbox_inches='tight')
271
 
272
 
 
 
 
 
 
 
 
273
  def main():
274
  active_col = 'Active (Dmax 0.6, pDC50 6.0)'
275
  test_split = 0.1
276
  n_models_for_test = 3
 
277
 
278
  active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
279
  report_base_name = f'{active_name}_test_split_{test_split}'
@@ -300,25 +310,36 @@ def main():
300
  pd.read_csv(f'reports/hparam_report_{report_base_name}_uniprot.csv'),
301
  pd.read_csv(f'reports/hparam_report_{report_base_name}_tanimoto.csv'),
302
  ]),
 
 
 
 
 
303
  }
304
 
305
 
306
- metrics = {}
307
- for i in range(n_models_for_test):
308
- for split_type in ['random', 'tanimoto', 'uniprot', 'e3_ligase']:
309
  logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
310
- metrics[f'{split_type}_{i}'] = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
311
- metrics[f'{split_type}_{i}']['model_id'] = i
312
  # Rename 'val_' columns to 'test_' columns
313
- metrics[f'{split_type}_{i}'] = metrics[f'{split_type}_{i}'].rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
314
-
315
- plot_training_curves(metrics[f'{split_type}_{i}'], f'{split_type}_{i}')
316
 
 
 
 
 
 
 
317
 
318
  df_val = reports['cv_train']
319
  df_test = reports['test']
320
  plot_performance_metrics(df_val, df_test, title=f'{active_name}_metrics')
321
 
 
 
322
  reports['test']['disabled_embeddings'] = pd.NA
323
  plot_ablation_study(pd.concat([
324
  reports['ablation'],
 
12
  palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
13
 
14
 
15
+ def plot_training_curves(df, split_type, stage='test'):
16
+ Stage = 'Test' if stage == 'test' else 'Validation'
17
+
18
  # Clean the data
19
  df = df.dropna(how='all', axis=1)
20
 
 
28
 
29
  # Plot training loss
30
  ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
31
+ ax[0].plot(epoch_data.index, epoch_data[f'{stage}_loss'], label=f'{Stage} Loss', linestyle='--')
32
  ax[0].set_ylabel('Loss')
33
  ax[0].legend(loc='lower right')
34
  ax[0].grid(axis='both', alpha=0.5)
35
 
36
  # Plot training accuracy
37
  ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
38
+ ax[1].plot(epoch_data.index, epoch_data[f'{stage}_acc'], label=f'{Stage} Accuracy', linestyle='--')
39
  ax[1].set_ylabel('Accuracy')
40
  ax[1].legend(loc='lower right')
41
  ax[1].grid(axis='both', alpha=0.5)
 
46
 
47
  # Plot training ROC-AUC
48
  ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
49
+ ax[2].plot(epoch_data.index, epoch_data[f'{stage}_roc_auc'], label=f'{Stage} ROC-AUC', linestyle='--')
50
  ax[2].set_ylabel('ROC-AUC')
51
  ax[2].legend(loc='lower right')
52
  ax[2].grid(axis='both', alpha=0.5)
 
272
  plt.savefig(f'plots/ablation_study_{group}.pdf', bbox_inches='tight')
273
 
274
 
275
+ def plot_majority_voting_performance(df):
276
+ # cv_models,test_acc,test_roc_auc,split_type
277
+ # Melt the dataframe
278
+ df = df.melt(id_vars=['cv_models', 'test_acc', 'test_roc_auc', 'split_type'], var_name='Metric', value_name='Score')
279
+ print(df)
280
+
281
+
282
  def main():
283
  active_col = 'Active (Dmax 0.6, pDC50 6.0)'
284
  test_split = 0.1
285
  n_models_for_test = 3
286
+ cv_n_folds = 5
287
 
288
  active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
289
  report_base_name = f'{active_name}_test_split_{test_split}'
 
310
  pd.read_csv(f'reports/hparam_report_{report_base_name}_uniprot.csv'),
311
  pd.read_csv(f'reports/hparam_report_{report_base_name}_tanimoto.csv'),
312
  ]),
313
+ 'majority_vote': pd.concat([
314
+ pd.read_csv(f'reports/majority_vote_report_{report_base_name}_random.csv'),
315
+ pd.read_csv(f'reports/majority_vote_report_{report_base_name}_uniprot.csv'),
316
+ pd.read_csv(f'reports/majority_vote_report_{report_base_name}_tanimoto.csv'),
317
+ ]),
318
  }
319
 
320
 
321
+ for split_type in ['random', 'tanimoto', 'uniprot']:
322
+ for i in range(n_models_for_test):
 
323
  logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
324
+ metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
325
+ metrics['model_id'] = i
326
  # Rename 'val_' columns to 'test_' columns
327
+ metrics = metrics.rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
328
+ plot_training_curves(metrics, f'{split_type}_best_model_n{i}')
 
329
 
330
+ for i in range(cv_n_folds):
331
+ # logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
332
+ logs_dir = f'{split_type}_cv_model_fold{i}'
333
+ metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
334
+ metrics['fold'] = i
335
+ plot_training_curves(metrics, f'{split_type}_cv_model_fold{i}', stage='val')
336
 
337
  df_val = reports['cv_train']
338
  df_test = reports['test']
339
  plot_performance_metrics(df_val, df_test, title=f'{active_name}_metrics')
340
 
341
+ plot_majority_voting_performance(reports['majority_vote'])
342
+
343
  reports['test']['disabled_embeddings'] = pd.NA
344
  plot_ablation_study(pd.concat([
345
  reports['ablation'],
src/run_experiments.py CHANGED
@@ -8,6 +8,7 @@ from typing import Literal
8
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
 
10
  import protac_degradation_predictor as pdp
 
11
 
12
  import pytorch_lightning as pl
13
  from rdkit import Chem
@@ -77,21 +78,38 @@ def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
77
  Returns:
78
  tuple: The SMILES to fingerprint dictionary and the average Tanimoto similarity.
79
  """
 
 
80
  smiles2fp = {}
81
- for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
82
  smiles2fp[smiles] = pdp.get_fingerprint(smiles)
83
 
84
- # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
 
 
 
 
 
 
 
 
 
 
 
 
85
  tanimoto_matrix = defaultdict(list)
86
- for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
87
- fp1 = smiles2fp[smiles1]
88
- # TODO: Use BulkTanimotoSimilarity for better performance
89
- for j, smiles2 in enumerate(protac_df['Smiles'].unique()):
90
- if j < i:
91
- continue
92
- fp2 = smiles2fp[smiles2]
93
- tanimoto_dist = DataStructs.TanimotoSimilarity(fp1, fp2)
94
- tanimoto_matrix[smiles1].append(tanimoto_dist)
 
 
 
95
  avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
96
  protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
97
 
@@ -256,7 +274,7 @@ def main(
256
  test_indeces['e3_ligase'] = get_e3_ligase_split_indices(active_df)
257
  if experiments == 'tanimoto' or experiments == 'all':
258
  test_indeces['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)
259
-
260
  # Make directory ../reports if it does not exist
261
  if not os.path.exists('../reports'):
262
  os.makedirs('../reports')
 
8
  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
 
10
  import protac_degradation_predictor as pdp
11
+ from protac_degradation_predictor.optuna_utils import get_dataframe_stats
12
 
13
  import pytorch_lightning as pl
14
  from rdkit import Chem
 
78
  Returns:
79
  tuple: The SMILES to fingerprint dictionary and the average Tanimoto distance.
80
  """
81
+ unique_smiles = protac_df['Smiles'].unique().tolist()
82
+
83
  smiles2fp = {}
84
+ for smiles in tqdm(unique_smiles, desc='Precomputing fingerprints'):
85
  smiles2fp[smiles] = pdp.get_fingerprint(smiles)
86
 
87
+ # # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
88
+ # tanimoto_matrix = defaultdict(list)
89
+ # for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
90
+ # fp1 = smiles2fp[smiles1]
91
+ # # TODO: Use BulkTanimotoSimilarity for better performance
92
+ # for j, smiles2 in enumerate(protac_df['Smiles'].unique()[i:]):
93
+ # fp2 = smiles2fp[smiles2]
94
+ # tanimoto_dist = 1 - DataStructs.TanimotoSimilarity(fp1, fp2)
95
+ # tanimoto_matrix[smiles1].append(tanimoto_dist)
96
+ # avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
97
+ # protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
98
+
99
+
100
  tanimoto_matrix = defaultdict(list)
101
+ fps = list(smiles2fp.values())
102
+
103
+ # Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity
104
+ for i, (smiles1, fp1) in enumerate(tqdm(zip(unique_smiles, fps), desc='Computing Tanimoto similarity', total=len(fps))):
105
+ similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:]) # Only compute for i to end, avoiding duplicates
106
+ for j, similarity in enumerate(similarities):
107
+ distance = 1 - similarity
108
+ tanimoto_matrix[smiles1].append(distance) # Store as distance
109
+ if i != i + j:
110
+ tanimoto_matrix[unique_smiles[i + j]].append(distance) # Symmetric filling
111
+
112
+ # Calculate average Tanimoto distance for each unique SMILES
113
  avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
114
  protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
115
 
 
274
  test_indeces['e3_ligase'] = get_e3_ligase_split_indices(active_df)
275
  if experiments == 'tanimoto' or experiments == 'all':
276
  test_indeces['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)
277
+
278
  # Make directory ../reports if it does not exist
279
  if not os.path.exists('../reports'):
280
  os.makedirs('../reports')