ribesstefano committed
Commit a589b70
1 Parent(s): 82509b6

Added confidence scores to reports
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -1,11 +1,12 @@
 import os
-from typing import Literal, List, Tuple, Optional, Dict
+from typing import Literal, List, Tuple, Optional, Dict, Any
 import logging
 
 from .pytorch_models import (
     train_model,
     PROTAC_Model,
     evaluate_model,
+    get_confidence_scores,
 )
 from .protac_dataset import get_datasets
@@ -81,6 +82,9 @@ def get_majority_vote_metrics(
         active_label: str = 'Active',
 ) -> Dict:
     """ Get the majority vote metrics. """
+    test_preds_mean = np.array(test_preds).mean(axis=0)
+    logging.info(f'Test predictions: {test_preds}')
+    logging.info(f'Test predictions mean: {test_preds_mean}')
     test_preds = torch.stack(test_preds)
     test_preds, _ = torch.mode(test_preds, dim=0)
     y = torch.tensor(test_df[active_label].tolist())
@@ -92,8 +96,23 @@ def get_majority_vote_metrics(
         'test_recall': Recall(task='binary')(test_preds, y).item(),
         'test_f1_score': F1Score(task='binary')(test_preds, y).item(),
     }
+
+    # Get mean predictions
+    fp_mean, fn_mean = get_confidence_scores(y, test_preds_mean)
+    majority_vote_metrics['test_false_negatives_mean'] = fn_mean
+    majority_vote_metrics['test_false_positives_mean'] = fp_mean
+
     return majority_vote_metrics
 
+
+def get_suggestion(trial, dtype, hparams_range):
+    if dtype == 'int':
+        return trial.suggest_int(**hparams_range)
+    elif dtype == 'float':
+        return trial.suggest_float(**hparams_range)
+    elif dtype == 'categorical':
+        return trial.suggest_categorical(**hparams_range)
+    else:
+        raise ValueError(f'Invalid dtype for trial.suggest: {dtype}')
 
 def pytorch_model_objective(
         trial: optuna.Trial,
@@ -104,11 +123,7 @@ def pytorch_model_objective(
         kf: StratifiedKFold | StratifiedGroupKFold,
         groups: Optional[np.array] = None,
         test_df: Optional[pd.DataFrame] = None,
-        hidden_dim_options: List[int] = [256, 512, 768],
-        batch_size_options: List[int] = [8, 16, 32],
-        learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
-        smote_k_neighbors_options: List[int] = list(range(3, 16)),
-        dropout_options: Tuple[float, float] = (0.1, 0.5),
+        hparams_ranges: Optional[List[Tuple[str, Dict[str, Any]]]] = None,
         fast_dev_run: bool = False,
         active_label: str = 'Active',
         disabled_embeddings: List[str] = [],
@@ -124,11 +139,8 @@ def pytorch_model_objective(
         trial (optuna.Trial): The Optuna trial object.
         train_df (pd.DataFrame): The training set.
         val_df (pd.DataFrame): The validation set.
-        hidden_dim_options (List[int]): The hidden dimension options.
-        batch_size_options (List[int]): The batch size options.
-        learning_rate_options (Tuple[float, float]): The learning rate options.
-        smote_k_neighbors_options (List[int]): The SMOTE k neighbors options.
-        dropout_options (Tuple[float, float]): The dropout options.
+        hparams_ranges (List[Dict[str, Any]]): NOT IMPLEMENTED YET. Hyperparameters ranges.
+            The list must be of a tuple of the type of hparam to suggest ('int', 'float', or 'categorical'), and the dictionary must contain the arguments of the corresponding trial.suggest method.
         fast_dev_run (bool): Whether to run a fast development run.
         active_label (str): The active label column.
         disabled_embeddings (List[str]): The list of disabled embeddings.
@@ -139,11 +151,14 @@ def pytorch_model_objective(
     use_batch_norm = True
 
     # Suggest hyperparameters to be used accross the CV folds
-    hidden_dim = trial.suggest_int('hidden_dim', 32, 512, step=32)
-    smote_k_neighbors = trial.suggest_int('smote_k_neighbors', 0, 12)
-    # hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
-    # smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
-    # dropout = trial.suggest_float('dropout', *dropout_options)
+    hidden_dim = trial.suggest_categorical('hidden_dim', [16, 32, 64, 128, 256, 512])
+    smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', [0] + list(range(3, 16)))
+    # hidden_dim = trial.suggest_int('hidden_dim', 32, 512, step=32)
+    # smote_k_neighbors = trial.suggest_int('smote_k_neighbors', 0, 12)
+
+    # use_smote = trial.suggest_categorical('use_smote', [True, False])
+    # smote_k_neighbors = smote_k_neighbors if use_smote else 0
+    # dropout = trial.suggest_float('dropout', 0, 0.5)
     # use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
 
     # Optimizer parameters
@@ -194,6 +209,7 @@ def pytorch_model_objective(
         beta2=beta2,
         eps=eps,
         use_batch_norm=use_batch_norm,
+        # dropout=dropout,
         max_epochs=max_epochs,
         smote_k_neighbors=smote_k_neighbors,
         apply_scaling=apply_scaling,
@@ -227,6 +243,9 @@ def pytorch_model_objective(
 
     # Get the average validation accuracy and ROC AUC accross the folds
     val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
+    val_acc = np.mean([r['val_acc'] for r in report])
+    logging.info(f'Average val accuracy: {val_acc}')
+    logging.info(f'Average val ROC AUC: {val_roc_auc}')
 
     # Optuna aims to minimize the pytorch_model_objective
     return - val_roc_auc
@@ -240,7 +259,7 @@ def hyperparameter_tuning_and_training(
         test_df: pd.DataFrame,
         kf: StratifiedKFold | StratifiedGroupKFold,
         groups: Optional[np.array] = None,
-        split_type: str = 'random',
+        split_type: str = 'standard',
         n_models_for_test: int = 3,
         fast_dev_run: bool = False,
         n_trials: int = 50,
@@ -279,21 +298,13 @@ def hyperparameter_tuning_and_training(
 
     # TODO: Make the following code more modular, i.e., the ranges shall be put
     # in dictionaries or config files or something like that.
-
-    # Define the search space
-    hidden_dim_options = [8, 16, 32, 64, 128, 256] #, 512]
-    batch_size_options = [128, 128] # [4, 8, 16, 32, 64, 128]
-    learning_rate_options = (1e-6, 1e-1) # min and max values for loguniform distribution
-    smote_k_neighbors_options = list(range(3, 16))
-    # NOTE: We want Optuna to explore the combination (very low dropout, very
-    # small hidden_dim)
-    dropout_options = (0, 0.5)
+    hparams_ranges = None
 
     # Set the verbosity of Optuna
     optuna.logging.set_verbosity(optuna.logging.WARNING)
     # Set a quasi-random sampler, as suggested in: https://github.com/google-research/tuning_playbook?tab=readme-ov-file#faqs
-    # sampler = TPESampler(seed=42, multivariate=True)
-    sampler = QMCSampler(qmc_type='halton', scramble=True, seed=42)
+    # sampler = QMCSampler(qmc_type='halton', scramble=True, seed=42)
+    sampler = TPESampler(seed=42, multivariate=True)
     # Create an Optuna study object
     study = optuna.create_study(direction='minimize', sampler=sampler)
 
@@ -316,11 +327,7 @@ def hyperparameter_tuning_and_training(
         kf=kf,
         groups=groups,
         test_df=test_df,
-        hidden_dim_options=hidden_dim_options,
-        batch_size_options=batch_size_options,
-        learning_rate_options=learning_rate_options,
-        smote_k_neighbors_options=smote_k_neighbors_options,
-        dropout_options=dropout_options,
+        hparams_ranges=hparams_ranges,
         fast_dev_run=fast_dev_run,
         active_label=active_label,
         max_epochs=max_epochs,
@@ -344,11 +351,7 @@ def hyperparameter_tuning_and_training(
         kf=kf,
         groups=groups,
         test_df=test_df,
-        hidden_dim_options=hidden_dim_options,
-        batch_size_options=batch_size_options,
-        learning_rate_options=learning_rate_options,
-        smote_k_neighbors_options=smote_k_neighbors_options,
-        dropout_options=dropout_options,
+        hparams_ranges=hparams_ranges,
         fast_dev_run=fast_dev_run,
         active_label=active_label,
         max_epochs=max_epochs,
@@ -384,7 +387,7 @@ def hyperparameter_tuning_and_training(
         return_predictions=True,
         batch_size=128,
         apply_scaling=True,
-        use_batch_norm=True,
+        # use_batch_norm=True,
         **study.best_params,
     )
     # Rename the keys in the metrics dictionary
@@ -464,34 +467,6 @@ def hyperparameter_tuning_and_training(
         majority_vote_metrics['disabled_embeddings'] = disabled_embeddings_str
         ablation_report.append(majority_vote_metrics.copy())
 
-        # _, _, metrics = train_model(
-        #     protein2embedding=protein2embedding,
-        #     cell2embedding=cell2embedding,
-        #     smiles2fp=smiles2fp,
-        #     train_df=train_val_df,
-        #     val_df=test_df,
-        #     fast_dev_run=fast_dev_run,
-        #     active_label=active_label,
-        #     max_epochs=max_epochs,
-        #     use_logger=False,
-        #     logger_save_dir=logger_save_dir,
-        #     logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
-        #     disabled_embeddings=disabled_embeddings,
-        #     batch_size=128,
-        #     apply_scaling=True,
-        #     **study.best_params,
-        # )
-        # # Rename the keys in the metrics dictionary
-        # metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
-        # metrics['disabled_embeddings'] = disabled_embeddings_str
-        # metrics['model_type'] = 'Pytorch'
-        # metrics.update(dfs_stats)
-
-        # # Add the training metrics
-        # train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
-        # metrics.update(train_metrics)
-        # ablation_report.append(metrics.copy())
-
     ablation_report = pd.DataFrame(ablation_report)
 
     # Add a column with the split_type to all reports
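The hparams_ranges argument introduced above is explicitly flagged NOT IMPLEMENTED YET, but its docstring together with the new get_suggestion dispatcher implies a concrete shape for the config. A minimal sketch of what wiring it up could look like; the hyperparameter names and ranges below are illustrative assumptions, not values from this commit:

    # Hypothetical search space: (dtype, kwargs-for-trial.suggest_<dtype>) tuples,
    # matching the format described in the pytorch_model_objective docstring.
    hparams_ranges = [
        ('categorical', {'name': 'hidden_dim', 'choices': [16, 32, 64, 128, 256, 512]}),
        ('int', {'name': 'smote_k_neighbors', 'low': 0, 'high': 15}),
        ('float', {'name': 'learning_rate', 'low': 1e-5, 'high': 1e-3, 'log': True}),
    ]

    def suggest_hparams(trial, hparams_ranges):
        # Resolve every entry through the get_suggestion dispatcher from the diff;
        # returns a {hparam name: suggested value} dict for one trial.
        return {kwargs['name']: get_suggestion(trial, dtype, kwargs)
                for dtype, kwargs in hparams_ranges}

Each dictionary is splatted directly into the corresponding trial.suggest_int / trial.suggest_float / trial.suggest_categorical call, so its keys must match that method's signature.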
protac_degradation_predictor/optuna_utils_xgboost.py CHANGED
@@ -24,6 +24,21 @@ import torch
 xgb.set_config(verbosity=0)
 
 
+def get_confidence_scores(y, y_pred, threshold=0.5):
+    # Calculate the likelihood for the false negative: get the mean value of
+    # the prediction for the false-positive and false-negatives
+
+    # Get the indices of the false positives and false negatives
+    false_positives = (y == 0) & ((y_pred > threshold).astype(int) == 1)
+    false_negatives = (y == 1) & ((y_pred > threshold).astype(int) == 0)
+
+    # Get the mean value of the predictions for the false positives and false negatives
+    false_positives_mean = y_pred[false_positives].mean()
+    false_negatives_mean = y_pred[false_negatives].mean()
+
+    return false_positives_mean, false_negatives_mean
+
+
 def train_and_evaluate_xgboost(
         protein2embedding: Dict,
         cell2embedding: Dict,
@@ -92,24 +107,34 @@ def train_and_evaluate_xgboost(
     # Evaluate model
     val_pred = model.predict(dval)
     val_pred_binary = (val_pred > 0.5).astype(int)
+
+    fp_mean, fn_mean = get_confidence_scores(y_val, val_pred)
+
     metrics = {
         'val_acc': accuracy_score(y_val, val_pred_binary),
         'val_roc_auc': roc_auc_score(y_val, val_pred),
         'val_precision': precision_score(y_val, val_pred_binary),
         'val_recall': recall_score(y_val, val_pred_binary),
         'val_f1_score': f1_score(y_val, val_pred_binary),
+        'val_false_positives_mean': fp_mean,
+        'val_false_negatives_mean': fn_mean,
     }
     preds = {'val_pred': val_pred}
 
     if test_df is not None:
         test_pred = model.predict(dtest)
-        test_pred_binary = (test_pred > 0.5).astype(int)
+        test_pred_binary = (test_pred > 0.5).astype(int)
+
+        fp_mean, fn_mean = get_confidence_scores(y_test, test_pred)
+
         metrics.update({
             'test_acc': accuracy_score(y_test, test_pred_binary),
             'test_roc_auc': roc_auc_score(y_test, test_pred),
             'test_precision': precision_score(y_test, test_pred_binary),
             'test_recall': recall_score(y_test, test_pred_binary),
             'test_f1_score': f1_score(y_test, test_pred_binary),
+            'test_false_positives_mean': fp_mean,
+            'test_false_negatives_mean': fn_mean,
         })
         preds.update({'test_pred': test_pred})
 
@@ -328,7 +353,7 @@ def xgboost_hyperparameter_tuning_and_training(
 
     # Save the trained model
     if model_name:
-        model_filename = f'{model_name}_best_model_{split_type}_n{i}.json'
+        model_filename = f'{model_name}_best_model_{split_type}_n{i}-test_acc={metrics["test_acc"]:.2f}-test_roc_auc={metrics["test_roc_auc"]:.3f}.json'
        model.save_model(model_filename)
        logging.info(f'Best XGBoost model saved to: {model_filename}')
    test_report = pd.DataFrame(test_report)
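As a sanity check on the get_confidence_scores helper added at the top of this file: it averages the predicted probabilities over the misclassified examples only, so a value near the 0.5 threshold means the model was unsure about its mistakes, while a value near 0 or 1 means it was confidently wrong. A toy run, with made-up arrays rather than repository data:

    import numpy as np

    y = np.array([0, 0, 1, 1, 1])
    y_pred = np.array([0.1, 0.8, 0.3, 0.9, 0.4])

    fp_mean, fn_mean = get_confidence_scores(y, y_pred)
    print(fp_mean)  # 0.8  -> mean score of the one false positive
    print(fn_mean)  # 0.35 -> mean score of the two false negatives (0.3, 0.4)

Note that when a split contains no false positives (or no false negatives), the mean is taken over an empty selection and NumPy returns nan, so the new report columns can legitimately hold nan for well-classified folds.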
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -336,21 +336,21 @@ class PROTAC_Model(pl.LightningModule):
         else:
             optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
         # Define LR scheduler
-        if self.trainer.max_epochs:
-            total_iters = self.trainer.max_epochs
-        elif self.trainer.max_steps:
-            total_iters = self.trainer.max_steps
-        else:
-            total_iters = 20
-        lr_scheduler = optim.lr_scheduler.LinearLR(
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
             optimizer=optimizer,
-            total_iters=total_iters,
+            mode='min',
+            factor=0.1,
+            patience=0,
         )
-        # lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        # if self.trainer.max_epochs:
+        #     total_iters = self.trainer.max_epochs
+        # elif self.trainer.max_steps:
+        #     total_iters = self.trainer.max_steps
+        # else:
+        #     total_iters = 20
+        # lr_scheduler = optim.lr_scheduler.LinearLR(
         #     optimizer=optimizer,
-        #     mode='min',
-        #     factor=0.01,
-        #     patience=0,
+        #     total_iters=total_iters,
         # )
         return {
             'optimizer': optimizer,
@@ -418,6 +418,44 @@ class PROTAC_Model(pl.LightningModule):
             logging.warning("Scalers not found in checkpoint. Consider re-fitting scalers if necessary.")
 
 
+def get_confidence_scores(true_ds, y_preds, threshold=0.5):
+    # Calculate the likelihood for the false negative: get the mean value of
+    # the prediction for the false-positive and false-negatives
+
+    # Convert PyTorch dataset labels to numpy array
+    if isinstance(true_ds, PROTAC_Dataset):
+        true_vals = np.array([x['active'] for x in true_ds]).flatten()
+    elif isinstance(true_ds, torch.Tensor):
+        true_vals = true_ds.numpy().flatten()
+    elif isinstance(true_ds, np.ndarray):
+        true_vals = true_ds.flatten()
+    else:
+        raise ValueError("Unknown type for true labels.")
+
+    if isinstance(y_preds, torch.Tensor):
+        preds = y_preds.numpy().flatten()
+    elif isinstance(y_preds, np.ndarray):
+        preds = y_preds.flatten()
+    else:
+        raise ValueError("Unknown type for predictions.")
+
+    logging.info(f"True values: {true_vals}")
+    logging.info(f"Predictions: {preds}")
+
+    # Get the indices of the false positives and false negatives
+    false_positives = (true_vals == 0) & ((preds > threshold).astype(int) == 1)
+    false_negatives = (true_vals == 1) & ((preds > threshold).astype(int) == 0)
+
+    logging.info(f"False positives: {false_positives}")
+    logging.info(f"False negatives: {false_negatives}")
+
+    # Get the mean value of the predictions for the false positives and false negatives
+    false_positives_mean = preds[false_positives].mean()
+    false_negatives_mean = preds[false_negatives].mean()
+
+    return false_positives_mean, false_negatives_mean
+
+
 # TODO: Use some sort of **kwargs to pass all the parameters to the model...
 def train_model(
         protein2embedding: Dict[str, np.ndarray],
@@ -448,6 +486,7 @@ def train_model(
         disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
         return_predictions: bool = False,
         shuffle_embedding_prob: float = 0.0,
+        use_smote: bool = False,
 ) -> tuple:
     """ Train a PROTAC model using the given datasets and hyperparameters.
 
@@ -532,7 +571,7 @@ def train_model(
         ),
         pl.callbacks.EarlyStopping(
             monitor='val_acc',
-            patience=10,
+            patience=10, # Original: 10
             mode='max',
             verbose=False,
         ),
@@ -604,9 +643,19 @@ def train_model(
         val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
         val_pred = trainer.predict(model, val_dl)
        val_pred = torch.concat(trainer.predict(model, val_dl)).squeeze()
+
+        fp_mean, fn_mean = get_confidence_scores(val_ds, val_pred)
+        metrics['val_false_positives_mean'] = fp_mean
+        metrics['val_false_negatives_mean'] = fn_mean
+
         if test_df is not None:
             test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
             test_pred = torch.concat(trainer.predict(model, test_dl)).squeeze()
+
+            fp_mean, fn_mean = get_confidence_scores(test_ds, test_pred)
+            metrics['test_false_positives_mean'] = fp_mean
+            metrics['test_false_negatives_mean'] = fn_mean
+
             return model, trainer, metrics, val_pred, test_pred
         return model, trainer, metrics, val_pred
     return model, trainer, metrics
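The PyTorch-side get_confidence_scores added in this file accepts a PROTAC_Dataset, a torch.Tensor, or a np.ndarray for the labels, and a tensor or array for the predictions. A quick tensor-based smoke test, using the same toy values as the XGBoost example above (note the tensors must live on the CPU, since the helper calls .numpy() on them):

    import torch

    y_true = torch.tensor([0., 0., 1., 1., 1.])
    y_prob = torch.tensor([0.1, 0.8, 0.3, 0.9, 0.4])

    # Dispatches through the torch.Tensor branch of the helper.
    fp_mean, fn_mean = get_confidence_scores(y_true, y_prob)
    print(fp_mean, fn_mean)  # ~0.8 and ~0.35, matching the XGBoost variant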