ribesstefano commited on
Commit
62ccb16
1 Parent(s): b86d3ec

Added LR scheduler + set default sum of embeddings

Browse files
protac_degradation_predictor/config.py CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
4
  class Config:
5
  # Embeddings information
6
  morgan_radius: int = 15
7
- fingerprint_size: int = 224
8
  protein_embedding_size: int = 1024
9
  cell_embedding_size: int = 768
10
 
 
4
  class Config:
5
  # Embeddings information
6
  morgan_radius: int = 15
7
+ fingerprint_size: int = 256 # 224
8
  protein_embedding_size: int = 1024
9
  cell_embedding_size: int = 768
10
 
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -118,7 +118,6 @@ def pytorch_model_objective(
118
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
119
  batch_size = trial.suggest_categorical('batch_size', batch_size_options)
120
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
121
- join_embeddings = trial.suggest_categorical('join_embeddings', ['beginning', 'concat', 'sum'])
122
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
123
  use_smote = trial.suggest_categorical('use_smote', [True, False])
124
  apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
@@ -161,7 +160,6 @@ def pytorch_model_objective(
161
  test_df=test_df,
162
  hidden_dim=hidden_dim,
163
  batch_size=batch_size,
164
- join_embeddings=join_embeddings,
165
  learning_rate=learning_rate,
166
  dropout=dropout,
167
  max_epochs=max_epochs,
@@ -177,7 +175,6 @@ def pytorch_model_objective(
177
  if test_df is not None:
178
  _, trainer, metrics, val_pred, test_pred = ret
179
  test_preds.append(test_pred)
180
- logging.info(f'Test predictions: {test_pred}')
181
  else:
182
  _, trainer, metrics, val_pred = ret
183
  train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
@@ -190,7 +187,7 @@ def pytorch_model_objective(
190
  trial.set_user_attr('report', report)
191
 
192
  # Get the majority vote for the test predictions
193
- if test_df is not None:
194
  # Get the majority vote for the test predictions
195
  test_preds = torch.stack(test_preds)
196
  test_preds, _ = torch.mode(test_preds, dim=0)
@@ -340,27 +337,28 @@ def hyperparameter_tuning_and_training(
340
  test_report = pd.DataFrame(test_report)
341
 
342
  # Get the majority vote for the test predictions
343
- test_preds = torch.stack(test_preds)
344
- test_preds, _ = torch.mode(test_preds, dim=0)
345
- y = torch.tensor(test_df[active_label].tolist())
346
- # Measure the test accuracy and ROC AUC
347
- majority_vote_metrics = {
348
- 'cv_models': False,
349
- 'test_acc': Accuracy(task='binary')(test_preds, y).item(),
350
- 'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
351
- 'test_precision': Precision(task='binary')(test_preds, y).item(),
352
- 'test_recall': Recall(task='binary')(test_preds, y).item(),
353
- 'test_f1': F1Score(task='binary')(test_preds, y).item(),
354
- }
355
- majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
356
- majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
357
- majority_vote_metrics_cv['cv_models'] = True
358
- majority_vote_report = pd.DataFrame([
359
- majority_vote_metrics,
360
- majority_vote_metrics_cv,
361
- ])
362
- majority_vote_report['model_type'] = 'Pytorch'
363
- majority_vote_report['split_type'] = split_type
 
364
 
365
  # Ablation study: disable embeddings at a time
366
  ablation_report = []
@@ -407,8 +405,9 @@ def hyperparameter_tuning_and_training(
407
  'hparam_report': hparam_report,
408
  'test_report': test_report,
409
  'ablation_report': ablation_report,
410
- 'majority_vote_report': majority_vote_report,
411
  }
 
 
412
  return ret
413
 
414
 
 
118
  hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
119
  batch_size = trial.suggest_categorical('batch_size', batch_size_options)
120
  learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
 
121
  smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
122
  use_smote = trial.suggest_categorical('use_smote', [True, False])
123
  apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
 
160
  test_df=test_df,
161
  hidden_dim=hidden_dim,
162
  batch_size=batch_size,
 
163
  learning_rate=learning_rate,
164
  dropout=dropout,
165
  max_epochs=max_epochs,
 
175
  if test_df is not None:
176
  _, trainer, metrics, val_pred, test_pred = ret
177
  test_preds.append(test_pred)
 
178
  else:
179
  _, trainer, metrics, val_pred = ret
180
  train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
 
187
  trial.set_user_attr('report', report)
188
 
189
  # Get the majority vote for the test predictions
190
+ if test_df is not None and not fast_dev_run:
191
  # Get the majority vote for the test predictions
192
  test_preds = torch.stack(test_preds)
193
  test_preds, _ = torch.mode(test_preds, dim=0)
 
337
  test_report = pd.DataFrame(test_report)
338
 
339
  # Get the majority vote for the test predictions
340
+ if not fast_dev_run:
341
+ test_preds = torch.stack(test_preds)
342
+ test_preds, _ = torch.mode(test_preds, dim=0)
343
+ y = torch.tensor(test_df[active_label].tolist())
344
+ # Measure the test accuracy and ROC AUC
345
+ majority_vote_metrics = {
346
+ 'cv_models': False,
347
+ 'test_acc': Accuracy(task='binary')(test_preds, y).item(),
348
+ 'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
349
+ 'test_precision': Precision(task='binary')(test_preds, y).item(),
350
+ 'test_recall': Recall(task='binary')(test_preds, y).item(),
351
+ 'test_f1': F1Score(task='binary')(test_preds, y).item(),
352
+ }
353
+ majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
354
+ majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
355
+ majority_vote_metrics_cv['cv_models'] = True
356
+ majority_vote_report = pd.DataFrame([
357
+ majority_vote_metrics,
358
+ majority_vote_metrics_cv,
359
+ ])
360
+ majority_vote_report['model_type'] = 'Pytorch'
361
+ majority_vote_report['split_type'] = split_type
362
 
363
  # Ablation study: disable embeddings at a time
364
  ablation_report = []
 
405
  'hparam_report': hparam_report,
406
  'test_report': test_report,
407
  'ablation_report': ablation_report,
 
408
  }
409
+ if not fast_dev_run:
410
+ ret['majority_vote_report'] = majority_vote_report
411
  return ret
412
 
413
 
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -36,7 +36,7 @@ class PROTAC_Predictor(nn.Module):
36
  e3_emb_dim: int = config.protein_embedding_size,
37
  cell_emb_dim: int = config.cell_embedding_size,
38
  dropout: float = 0.2,
39
- join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
40
  disabled_embeddings: list = [],
41
  ):
42
  """ Initialize the PROTAC model.
@@ -140,7 +140,7 @@ class PROTAC_Model(pl.LightningModule):
140
  batch_size: int = 32,
141
  learning_rate: float = 1e-3,
142
  dropout: float = 0.2,
143
- join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
144
  train_dataset: PROTAC_Dataset = None,
145
  val_dataset: PROTAC_Dataset = None,
146
  test_dataset: PROTAC_Dataset = None,
@@ -308,7 +308,19 @@ class PROTAC_Model(pl.LightningModule):
308
  return self.step(batch, batch_idx, 'test')
309
 
310
  def configure_optimizers(self):
311
- return optim.Adam(self.parameters(), lr=self.learning_rate)
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  def predict_step(self, batch, batch_idx):
314
  poi_emb = batch['poi_emb']
@@ -384,7 +396,7 @@ def train_model(
384
  poi_emb_dim: int = config.protein_embedding_size,
385
  e3_emb_dim: int = config.protein_embedding_size,
386
  cell_emb_dim: int = config.cell_embedding_size,
387
- join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
388
  smote_k_neighbors:int = 5,
389
  use_smote: bool = True,
390
  apply_scaling: bool = False,
@@ -482,6 +494,8 @@ def train_model(
482
  verbose=False,
483
  ),
484
  ]
 
 
485
  if enable_checkpointing:
486
  callbacks.append(pl.callbacks.ModelCheckpoint(
487
  monitor='val_acc',
 
36
  e3_emb_dim: int = config.protein_embedding_size,
37
  cell_emb_dim: int = config.cell_embedding_size,
38
  dropout: float = 0.2,
39
+ join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
40
  disabled_embeddings: list = [],
41
  ):
42
  """ Initialize the PROTAC model.
 
140
  batch_size: int = 32,
141
  learning_rate: float = 1e-3,
142
  dropout: float = 0.2,
143
+ join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
144
  train_dataset: PROTAC_Dataset = None,
145
  val_dataset: PROTAC_Dataset = None,
146
  test_dataset: PROTAC_Dataset = None,
 
308
  return self.step(batch, batch_idx, 'test')
309
 
310
  def configure_optimizers(self):
311
+ optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
312
+ return {
313
+ 'optimizer': optimizer,
314
+ 'lr_scheduler': optim.lr_scheduler.ReduceLROnPlateau(
315
+ optimizer=optimizer,
316
+ mode='min',
317
+ factor=0.5,
318
+ patience=2,
319
+ ),
320
+ 'interval': 'step', # or 'epoch'
321
+ 'frequency': 1,
322
+ 'monitor': 'val_loss',
323
+ }
324
 
325
  def predict_step(self, batch, batch_idx):
326
  poi_emb = batch['poi_emb']
 
396
  poi_emb_dim: int = config.protein_embedding_size,
397
  e3_emb_dim: int = config.protein_embedding_size,
398
  cell_emb_dim: int = config.cell_embedding_size,
399
+ join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
400
  smote_k_neighbors:int = 5,
401
  use_smote: bool = True,
402
  apply_scaling: bool = False,
 
494
  verbose=False,
495
  ),
496
  ]
497
+ if use_logger:
498
+ callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='step'))
499
  if enable_checkpointing:
500
  callbacks.append(pl.callbacks.ModelCheckpoint(
501
  monitor='val_acc',
src/run_experiments.py CHANGED
@@ -309,15 +309,8 @@ def main(
309
 
310
  # Save the reports to file
311
  for report_name, report in optuna_reports.items():
312
- report.to_csv(f'../reports/report_{report_name}_{experiment_name}.csv', index=False)
313
  reports[report_name].append(report.copy())
314
-
315
- # Save the reports to file after concatenating them
316
- for report_name, report in reports.items():
317
- report = pd.concat(report)
318
- report.to_csv(f'../reports/report_{report_name}_{active_name}_test_split_{test_split}.csv', index=False)
319
-
320
-
321
 
322
  # # Start the CV over the folds
323
  # X = train_val_df.drop(columns=active_col)
 
309
 
310
  # Save the reports to file
311
  for report_name, report in optuna_reports.items():
312
+ report.to_csv(f'../reports/{report_name}_{experiment_name}.csv', index=False)
313
  reports[report_name].append(report.copy())
 
 
 
 
 
 
 
314
 
315
  # # Start the CV over the folds
316
  # X = train_val_df.drop(columns=active_col)