ribesstefano
commited on
Commit
•
62ccb16
1
Parent(s):
b86d3ec
Added LR scheduler + set default sum of embeddings
Browse files
protac_degradation_predictor/config.py
CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
|
|
4 |
class Config:
|
5 |
# Embeddings information
|
6 |
morgan_radius: int = 15
|
7 |
-
fingerprint_size: int = 224
|
8 |
protein_embedding_size: int = 1024
|
9 |
cell_embedding_size: int = 768
|
10 |
|
|
|
4 |
class Config:
|
5 |
# Embeddings information
|
6 |
morgan_radius: int = 15
|
7 |
+
fingerprint_size: int = 256 # 224
|
8 |
protein_embedding_size: int = 1024
|
9 |
cell_embedding_size: int = 768
|
10 |
|
protac_degradation_predictor/optuna_utils.py
CHANGED
@@ -118,7 +118,6 @@ def pytorch_model_objective(
|
|
118 |
hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
|
119 |
batch_size = trial.suggest_categorical('batch_size', batch_size_options)
|
120 |
learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
|
121 |
-
join_embeddings = trial.suggest_categorical('join_embeddings', ['beginning', 'concat', 'sum'])
|
122 |
smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
|
123 |
use_smote = trial.suggest_categorical('use_smote', [True, False])
|
124 |
apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
|
@@ -161,7 +160,6 @@ def pytorch_model_objective(
|
|
161 |
test_df=test_df,
|
162 |
hidden_dim=hidden_dim,
|
163 |
batch_size=batch_size,
|
164 |
-
join_embeddings=join_embeddings,
|
165 |
learning_rate=learning_rate,
|
166 |
dropout=dropout,
|
167 |
max_epochs=max_epochs,
|
@@ -177,7 +175,6 @@ def pytorch_model_objective(
|
|
177 |
if test_df is not None:
|
178 |
_, trainer, metrics, val_pred, test_pred = ret
|
179 |
test_preds.append(test_pred)
|
180 |
-
logging.info(f'Test predictions: {test_pred}')
|
181 |
else:
|
182 |
_, trainer, metrics, val_pred = ret
|
183 |
train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
|
@@ -190,7 +187,7 @@ def pytorch_model_objective(
|
|
190 |
trial.set_user_attr('report', report)
|
191 |
|
192 |
# Get the majority vote for the test predictions
|
193 |
-
if test_df is not None:
|
194 |
# Get the majority vote for the test predictions
|
195 |
test_preds = torch.stack(test_preds)
|
196 |
test_preds, _ = torch.mode(test_preds, dim=0)
|
@@ -340,27 +337,28 @@ def hyperparameter_tuning_and_training(
|
|
340 |
test_report = pd.DataFrame(test_report)
|
341 |
|
342 |
# Get the majority vote for the test predictions
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
|
|
364 |
|
365 |
# Ablation study: disable embeddings at a time
|
366 |
ablation_report = []
|
@@ -407,8 +405,9 @@ def hyperparameter_tuning_and_training(
|
|
407 |
'hparam_report': hparam_report,
|
408 |
'test_report': test_report,
|
409 |
'ablation_report': ablation_report,
|
410 |
-
'majority_vote_report': majority_vote_report,
|
411 |
}
|
|
|
|
|
412 |
return ret
|
413 |
|
414 |
|
|
|
118 |
hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
|
119 |
batch_size = trial.suggest_categorical('batch_size', batch_size_options)
|
120 |
learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
|
|
|
121 |
smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
|
122 |
use_smote = trial.suggest_categorical('use_smote', [True, False])
|
123 |
apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
|
|
|
160 |
test_df=test_df,
|
161 |
hidden_dim=hidden_dim,
|
162 |
batch_size=batch_size,
|
|
|
163 |
learning_rate=learning_rate,
|
164 |
dropout=dropout,
|
165 |
max_epochs=max_epochs,
|
|
|
175 |
if test_df is not None:
|
176 |
_, trainer, metrics, val_pred, test_pred = ret
|
177 |
test_preds.append(test_pred)
|
|
|
178 |
else:
|
179 |
_, trainer, metrics, val_pred = ret
|
180 |
train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
|
|
|
187 |
trial.set_user_attr('report', report)
|
188 |
|
189 |
# Get the majority vote for the test predictions
|
190 |
+
if test_df is not None and not fast_dev_run:
|
191 |
# Get the majority vote for the test predictions
|
192 |
test_preds = torch.stack(test_preds)
|
193 |
test_preds, _ = torch.mode(test_preds, dim=0)
|
|
|
337 |
test_report = pd.DataFrame(test_report)
|
338 |
|
339 |
# Get the majority vote for the test predictions
|
340 |
+
if not fast_dev_run:
|
341 |
+
test_preds = torch.stack(test_preds)
|
342 |
+
test_preds, _ = torch.mode(test_preds, dim=0)
|
343 |
+
y = torch.tensor(test_df[active_label].tolist())
|
344 |
+
# Measure the test accuracy and ROC AUC
|
345 |
+
majority_vote_metrics = {
|
346 |
+
'cv_models': False,
|
347 |
+
'test_acc': Accuracy(task='binary')(test_preds, y).item(),
|
348 |
+
'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
|
349 |
+
'test_precision': Precision(task='binary')(test_preds, y).item(),
|
350 |
+
'test_recall': Recall(task='binary')(test_preds, y).item(),
|
351 |
+
'test_f1': F1Score(task='binary')(test_preds, y).item(),
|
352 |
+
}
|
353 |
+
majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
|
354 |
+
majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
|
355 |
+
majority_vote_metrics_cv['cv_models'] = True
|
356 |
+
majority_vote_report = pd.DataFrame([
|
357 |
+
majority_vote_metrics,
|
358 |
+
majority_vote_metrics_cv,
|
359 |
+
])
|
360 |
+
majority_vote_report['model_type'] = 'Pytorch'
|
361 |
+
majority_vote_report['split_type'] = split_type
|
362 |
|
363 |
# Ablation study: disable embeddings at a time
|
364 |
ablation_report = []
|
|
|
405 |
'hparam_report': hparam_report,
|
406 |
'test_report': test_report,
|
407 |
'ablation_report': ablation_report,
|
|
|
408 |
}
|
409 |
+
if not fast_dev_run:
|
410 |
+
ret['majority_vote_report'] = majority_vote_report
|
411 |
return ret
|
412 |
|
413 |
|
protac_degradation_predictor/pytorch_models.py
CHANGED
@@ -36,7 +36,7 @@ class PROTAC_Predictor(nn.Module):
|
|
36 |
e3_emb_dim: int = config.protein_embedding_size,
|
37 |
cell_emb_dim: int = config.cell_embedding_size,
|
38 |
dropout: float = 0.2,
|
39 |
-
join_embeddings: Literal['beginning', 'concat', 'sum'] = '
|
40 |
disabled_embeddings: list = [],
|
41 |
):
|
42 |
""" Initialize the PROTAC model.
|
@@ -140,7 +140,7 @@ class PROTAC_Model(pl.LightningModule):
|
|
140 |
batch_size: int = 32,
|
141 |
learning_rate: float = 1e-3,
|
142 |
dropout: float = 0.2,
|
143 |
-
join_embeddings: Literal['beginning', 'concat', 'sum'] = '
|
144 |
train_dataset: PROTAC_Dataset = None,
|
145 |
val_dataset: PROTAC_Dataset = None,
|
146 |
test_dataset: PROTAC_Dataset = None,
|
@@ -308,7 +308,19 @@ class PROTAC_Model(pl.LightningModule):
|
|
308 |
return self.step(batch, batch_idx, 'test')
|
309 |
|
310 |
def configure_optimizers(self):
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
|
313 |
def predict_step(self, batch, batch_idx):
|
314 |
poi_emb = batch['poi_emb']
|
@@ -384,7 +396,7 @@ def train_model(
|
|
384 |
poi_emb_dim: int = config.protein_embedding_size,
|
385 |
e3_emb_dim: int = config.protein_embedding_size,
|
386 |
cell_emb_dim: int = config.cell_embedding_size,
|
387 |
-
join_embeddings: Literal['beginning', 'concat', 'sum'] = '
|
388 |
smote_k_neighbors:int = 5,
|
389 |
use_smote: bool = True,
|
390 |
apply_scaling: bool = False,
|
@@ -482,6 +494,8 @@ def train_model(
|
|
482 |
verbose=False,
|
483 |
),
|
484 |
]
|
|
|
|
|
485 |
if enable_checkpointing:
|
486 |
callbacks.append(pl.callbacks.ModelCheckpoint(
|
487 |
monitor='val_acc',
|
|
|
36 |
e3_emb_dim: int = config.protein_embedding_size,
|
37 |
cell_emb_dim: int = config.cell_embedding_size,
|
38 |
dropout: float = 0.2,
|
39 |
+
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
|
40 |
disabled_embeddings: list = [],
|
41 |
):
|
42 |
""" Initialize the PROTAC model.
|
|
|
140 |
batch_size: int = 32,
|
141 |
learning_rate: float = 1e-3,
|
142 |
dropout: float = 0.2,
|
143 |
+
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
|
144 |
train_dataset: PROTAC_Dataset = None,
|
145 |
val_dataset: PROTAC_Dataset = None,
|
146 |
test_dataset: PROTAC_Dataset = None,
|
|
|
308 |
return self.step(batch, batch_idx, 'test')
|
309 |
|
310 |
def configure_optimizers(self):
|
311 |
+
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
|
312 |
+
return {
|
313 |
+
'optimizer': optimizer,
|
314 |
+
'lr_scheduler': optim.lr_scheduler.ReduceLROnPlateau(
|
315 |
+
optimizer=optimizer,
|
316 |
+
mode='min',
|
317 |
+
factor=0.5,
|
318 |
+
patience=2,
|
319 |
+
),
|
320 |
+
'interval': 'step', # or 'epoch'
|
321 |
+
'frequency': 1,
|
322 |
+
'monitor': 'val_loss',
|
323 |
+
}
|
324 |
|
325 |
def predict_step(self, batch, batch_idx):
|
326 |
poi_emb = batch['poi_emb']
|
|
|
396 |
poi_emb_dim: int = config.protein_embedding_size,
|
397 |
e3_emb_dim: int = config.protein_embedding_size,
|
398 |
cell_emb_dim: int = config.cell_embedding_size,
|
399 |
+
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
|
400 |
smote_k_neighbors:int = 5,
|
401 |
use_smote: bool = True,
|
402 |
apply_scaling: bool = False,
|
|
|
494 |
verbose=False,
|
495 |
),
|
496 |
]
|
497 |
+
if use_logger:
|
498 |
+
callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='step'))
|
499 |
if enable_checkpointing:
|
500 |
callbacks.append(pl.callbacks.ModelCheckpoint(
|
501 |
monitor='val_acc',
|
src/run_experiments.py
CHANGED
@@ -309,15 +309,8 @@ def main(
|
|
309 |
|
310 |
# Save the reports to file
|
311 |
for report_name, report in optuna_reports.items():
|
312 |
-
report.to_csv(f'../reports/
|
313 |
reports[report_name].append(report.copy())
|
314 |
-
|
315 |
-
# Save the reports to file after concatenating them
|
316 |
-
for report_name, report in reports.items():
|
317 |
-
report = pd.concat(report)
|
318 |
-
report.to_csv(f'../reports/report_{report_name}_{active_name}_test_split_{test_split}.csv', index=False)
|
319 |
-
|
320 |
-
|
321 |
|
322 |
# # Start the CV over the folds
|
323 |
# X = train_val_df.drop(columns=active_col)
|
|
|
309 |
|
310 |
# Save the reports to file
|
311 |
for report_name, report in optuna_reports.items():
|
312 |
+
report.to_csv(f'../reports/{report_name}_{experiment_name}.csv', index=False)
|
313 |
reports[report_name].append(report.copy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
# # Start the CV over the folds
|
316 |
# X = train_val_df.drop(columns=active_col)
|