ribesstefano committed
Commit: a589b70
Parent(s): 82509b6
Added confidence scores to reports
protac_degradation_predictor/optuna_utils.py
CHANGED

@@ -1,11 +1,12 @@
 import os
-from typing import Literal, List, Tuple, Optional, Dict
+from typing import Literal, List, Tuple, Optional, Dict, Any
 import logging
 
 from .pytorch_models import (
     train_model,
     PROTAC_Model,
     evaluate_model,
+    get_confidence_scores,
 )
 from .protac_dataset import get_datasets
 
@@ -81,6 +82,9 @@ def get_majority_vote_metrics(
     active_label: str = 'Active',
 ) -> Dict:
     """ Get the majority vote metrics. """
+    test_preds_mean = np.array(test_preds).mean(axis=0)
+    logging.info(f'Test predictions: {test_preds}')
+    logging.info(f'Test predictions mean: {test_preds_mean}')
     test_preds = torch.stack(test_preds)
     test_preds, _ = torch.mode(test_preds, dim=0)
     y = torch.tensor(test_df[active_label].tolist())
@@ -92,8 +96,23 @@ def get_majority_vote_metrics(
         'test_recall': Recall(task='binary')(test_preds, y).item(),
         'test_f1_score': F1Score(task='binary')(test_preds, y).item(),
     }
+
+    # Get mean predictions
+    fp_mean, fn_mean = get_confidence_scores(y, test_preds_mean)
+    majority_vote_metrics['test_false_negatives_mean'] = fn_mean
+    majority_vote_metrics['test_false_positives_mean'] = fp_mean
+
     return majority_vote_metrics
 
+def get_suggestion(trial, dtype, hparams_range):
+    if dtype == 'int':
+        return trial.suggest_int(**hparams_range)
+    elif dtype == 'float':
+        return trial.suggest_float(**hparams_range)
+    elif dtype == 'categorical':
+        return trial.suggest_categorical(**hparams_range)
+    else:
+        raise ValueError(f'Invalid dtype for trial.suggest: {dtype}')
 
 def pytorch_model_objective(
     trial: optuna.Trial,
@@ -104,11 +123,7 @@ def pytorch_model_objective(
     kf: StratifiedKFold | StratifiedGroupKFold,
     groups: Optional[np.array] = None,
     test_df: Optional[pd.DataFrame] = None,
-    hidden_dim_options: List[int] = [...],
-    batch_size_options: List[int] = [8, 16, 32],
-    learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
-    smote_k_neighbors_options: List[int] = list(range(3, 16)),
-    dropout_options: Tuple[float, float] = (0.1, 0.5),
+    hparams_ranges: Optional[List[Tuple[str, Dict[str, Any]]]] = None,
     fast_dev_run: bool = False,
     active_label: str = 'Active',
     disabled_embeddings: List[str] = [],
@@ -124,11 +139,8 @@ def pytorch_model_objective(
         trial (optuna.Trial): The Optuna trial object.
         train_df (pd.DataFrame): The training set.
         val_df (pd.DataFrame): The validation set.
-        hidden_dim_options (List[int]): The hidden dimension options.
-        batch_size_options (List[int]): The batch size options.
-        learning_rate_options (Tuple[float, float]): The learning rate options.
-        smote_k_neighbors_options (List[int]): The SMOTE k neighbors options.
-        dropout_options (Tuple[float, float]): The dropout options.
+        hparams_ranges (List[Dict[str, Any]]): NOT IMPLEMENTED YET. Hyperparameters ranges.
+            The list must be of a tuple of the type of hparam to suggest ('int', 'float', or 'categorical'), and the dictionary must contain the arguments of the corresponding trial.suggest method.
         fast_dev_run (bool): Whether to run a fast development run.
         active_label (str): The active label column.
         disabled_embeddings (List[str]): The list of disabled embeddings.
@@ -139,11 +151,14 @@ def pytorch_model_objective(
     use_batch_norm = True
 
     # Suggest hyperparameters to be used accross the CV folds
-    hidden_dim = trial.
-    smote_k_neighbors = trial.
-    # hidden_dim = trial.
-    # smote_k_neighbors = trial.
-
+    hidden_dim = trial.suggest_categorical('hidden_dim', [16, 32, 64, 128, 256, 512])
+    smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', [0] + list(range(3, 16)))
+    # hidden_dim = trial.suggest_int('hidden_dim', 32, 512, step=32)
+    # smote_k_neighbors = trial.suggest_int('smote_k_neighbors', 0, 12)
+
+    # use_smote = trial.suggest_categorical('use_smote', [True, False])
+    # smote_k_neighbors = smote_k_neighbors if use_smote else 0
+    # dropout = trial.suggest_float('dropout', 0, 0.5)
     # use_batch_norm = trial.suggest_categorical('use_batch_norm', [True, False])
 
     # Optimizer parameters
@@ -194,6 +209,7 @@ def pytorch_model_objective(
         beta2=beta2,
         eps=eps,
         use_batch_norm=use_batch_norm,
+        # dropout=dropout,
         max_epochs=max_epochs,
         smote_k_neighbors=smote_k_neighbors,
         apply_scaling=apply_scaling,
@@ -227,6 +243,9 @@
 
     # Get the average validation accuracy and ROC AUC accross the folds
     val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
+    val_acc = np.mean([r['val_acc'] for r in report])
+    logging.info(f'Average val accuracy: {val_acc}')
+    logging.info(f'Average val ROC AUC: {val_roc_auc}')
 
     # Optuna aims to minimize the pytorch_model_objective
     return - val_roc_auc
@@ -240,7 +259,7 @@ def hyperparameter_tuning_and_training(
     test_df: pd.DataFrame,
     kf: StratifiedKFold | StratifiedGroupKFold,
     groups: Optional[np.array] = None,
-    split_type: str = '
+    split_type: str = 'standard',
     n_models_for_test: int = 3,
     fast_dev_run: bool = False,
     n_trials: int = 50,
@@ -279,21 +298,13 @@ def hyperparameter_tuning_and_training(
 
     # TODO: Make the following code more modular, i.e., the ranges shall be put
    # in dictionaries or config files or something like that.
-
-    # Define the search space
-    hidden_dim_options = [8, 16, 32, 64, 128, 256] #, 512]
-    batch_size_options = [128, 128] # [4, 8, 16, 32, 64, 128]
-    learning_rate_options = (1e-6, 1e-1) # min and max values for loguniform distribution
-    smote_k_neighbors_options = list(range(3, 16))
-    # NOTE: We want Optuna to explore the combination (very low dropout, very
-    # small hidden_dim)
-    dropout_options = (0, 0.5)
+    hparams_ranges = None
 
     # Set the verbosity of Optuna
     optuna.logging.set_verbosity(optuna.logging.WARNING)
     # Set a quasi-random sampler, as suggested in: https://github.com/google-research/tuning_playbook?tab=readme-ov-file#faqs
-    # sampler =
-    sampler =
+    # sampler = QMCSampler(qmc_type='halton', scramble=True, seed=42)
+    sampler = TPESampler(seed=42, multivariate=True)
     # Create an Optuna study object
     study = optuna.create_study(direction='minimize', sampler=sampler)
 
@@ -316,11 +327,7 @@
         kf=kf,
         groups=groups,
         test_df=test_df,
-        hidden_dim_options=hidden_dim_options,
-        batch_size_options=batch_size_options,
-        learning_rate_options=learning_rate_options,
-        smote_k_neighbors_options=smote_k_neighbors_options,
-        dropout_options=dropout_options,
+        hparams_ranges=hparams_ranges,
         fast_dev_run=fast_dev_run,
         active_label=active_label,
         max_epochs=max_epochs,
@@ -344,11 +351,7 @@
         kf=kf,
         groups=groups,
         test_df=test_df,
-        hidden_dim_options=hidden_dim_options,
-        batch_size_options=batch_size_options,
-        learning_rate_options=learning_rate_options,
-        smote_k_neighbors_options=smote_k_neighbors_options,
-        dropout_options=dropout_options,
+        hparams_ranges=hparams_ranges,
         fast_dev_run=fast_dev_run,
         active_label=active_label,
         max_epochs=max_epochs,
@@ -384,7 +387,7 @@
         return_predictions=True,
         batch_size=128,
         apply_scaling=True,
-        use_batch_norm=True,
+        # use_batch_norm=True,
         **study.best_params,
     )
     # Rename the keys in the metrics dictionary
@@ -464,34 +467,6 @@
         majority_vote_metrics['disabled_embeddings'] = disabled_embeddings_str
         ablation_report.append(majority_vote_metrics.copy())
 
-        # _, _, metrics = train_model(
-        #     protein2embedding=protein2embedding,
-        #     cell2embedding=cell2embedding,
-        #     smiles2fp=smiles2fp,
-        #     train_df=train_val_df,
-        #     val_df=test_df,
-        #     fast_dev_run=fast_dev_run,
-        #     active_label=active_label,
-        #     max_epochs=max_epochs,
-        #     use_logger=False,
-        #     logger_save_dir=logger_save_dir,
-        #     logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
-        #     disabled_embeddings=disabled_embeddings,
-        #     batch_size=128,
-        #     apply_scaling=True,
-        #     **study.best_params,
-        # )
-        # # Rename the keys in the metrics dictionary
-        # metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
-        # metrics['disabled_embeddings'] = disabled_embeddings_str
-        # metrics['model_type'] = 'Pytorch'
-        # metrics.update(dfs_stats)
-
-        # # Add the training metrics
-        # train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
-        # metrics.update(train_metrics)
-        # ablation_report.append(metrics.copy())
-
     ablation_report = pd.DataFrame(ablation_report)
 
     # Add a column with the split_type to all reports
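The new get_suggestion helper dispatches a (dtype, kwargs) pair to the matching trial.suggest_* method, which is exactly the shape the new hparams_ranges argument is meant to carry once it is wired in (its docstring still says NOT IMPLEMENTED YET). Below is a minimal, self-contained sketch of how such a declarative search space could drive an Optuna objective; the hyperparameter names and ranges are illustrative and not taken from this commit:

import optuna
from optuna.samplers import TPESampler


def get_suggestion(trial, dtype, hparams_range):
    # Same dispatcher as in the diff above.
    if dtype == 'int':
        return trial.suggest_int(**hparams_range)
    elif dtype == 'float':
        return trial.suggest_float(**hparams_range)
    elif dtype == 'categorical':
        return trial.suggest_categorical(**hparams_range)
    else:
        raise ValueError(f'Invalid dtype for trial.suggest: {dtype}')


# Hypothetical search space in the (dtype, suggest-kwargs) format the helper expects.
hparams_ranges = [
    ('categorical', {'name': 'hidden_dim', 'choices': [16, 32, 64, 128, 256, 512]}),
    ('float', {'name': 'learning_rate', 'low': 1e-5, 'high': 1e-3, 'log': True}),
    ('int', {'name': 'smote_k_neighbors', 'low': 3, 'high': 15}),
]


def objective(trial):
    hparams = {kwargs['name']: get_suggestion(trial, dtype, kwargs)
               for dtype, kwargs in hparams_ranges}
    # A real objective would train a model with `hparams`; a dummy score is
    # returned here so the sketch runs end-to-end.
    return hparams['learning_rate']


# Same sampler choice as the commit: seeded multivariate TPE.
study = optuna.create_study(direction='minimize', sampler=TPESampler(seed=42, multivariate=True))
study.optimize(objective, n_trials=5)

Declaring the space as data keeps the TODO above within reach: the ranges can move into a config file without touching the objective function.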
protac_degradation_predictor/optuna_utils_xgboost.py
CHANGED

@@ -24,6 +24,21 @@ import torch
 xgb.set_config(verbosity=0)
 
 
+def get_confidence_scores(y, y_pred, threshold=0.5):
+    # Calculate the likelihood for the false negative: get the mean value of
+    # the prediction for the false-positive and false-negatives
+
+    # Get the indices of the false positives and false negatives
+    false_positives = (y == 0) & ((y_pred > threshold).astype(int) == 1)
+    false_negatives = (y == 1) & ((y_pred > threshold).astype(int) == 0)
+
+    # Get the mean value of the predictions for the false positives and false negatives
+    false_positives_mean = y_pred[false_positives].mean()
+    false_negatives_mean = y_pred[false_negatives].mean()
+
+    return false_positives_mean, false_negatives_mean
+
+
 def train_and_evaluate_xgboost(
     protein2embedding: Dict,
     cell2embedding: Dict,
@@ -92,24 +107,34 @@ def train_and_evaluate_xgboost(
     # Evaluate model
     val_pred = model.predict(dval)
     val_pred_binary = (val_pred > 0.5).astype(int)
+
+    fp_mean, fn_mean = get_confidence_scores(y_val, val_pred)
+
     metrics = {
         'val_acc': accuracy_score(y_val, val_pred_binary),
         'val_roc_auc': roc_auc_score(y_val, val_pred),
         'val_precision': precision_score(y_val, val_pred_binary),
         'val_recall': recall_score(y_val, val_pred_binary),
         'val_f1_score': f1_score(y_val, val_pred_binary),
+        'val_false_positives_mean': fp_mean,
+        'val_false_negatives_mean': fn_mean,
     }
     preds = {'val_pred': val_pred}
 
     if test_df is not None:
         test_pred = model.predict(dtest)
-        test_pred_binary = (test_pred > 0.5).astype(int)
+        test_pred_binary = (test_pred > 0.5).astype(int)
+
+        fp_mean, fn_mean = get_confidence_scores(y_test, test_pred)
+
         metrics.update({
             'test_acc': accuracy_score(y_test, test_pred_binary),
             'test_roc_auc': roc_auc_score(y_test, test_pred),
             'test_precision': precision_score(y_test, test_pred_binary),
             'test_recall': recall_score(y_test, test_pred_binary),
             'test_f1_score': f1_score(y_test, test_pred_binary),
+            'test_false_positives_mean': fp_mean,
+            'test_false_negatives_mean': fn_mean,
         })
         preds.update({'test_pred': test_pred})
 
@@ -328,7 +353,7 @@ def xgboost_hyperparameter_tuning_and_training(
 
     # Save the trained model
     if model_name:
-        model_filename = f'{model_name}_best_model_{split_type}_n{i}.json'
+        model_filename = f'{model_name}_best_model_{split_type}_n{i}-test_acc={metrics["test_acc"]:.2f}-test_roc_auc={metrics["test_roc_auc"]:.3f}.json'
         model.save_model(model_filename)
         logging.info(f'Best XGBoost model saved to: {model_filename}')
     test_report = pd.DataFrame(test_report)
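The get_confidence_scores helper added above turns misclassifications into a rough confidence report: predictions are thresholded at 0.5, and the raw scores of the false positives and the false negatives are averaged separately. A high test_false_positives_mean means the model was confidently wrong about inactive compounds, while a test_false_negatives_mean close to the threshold means the missed actives were borderline calls. A toy NumPy run with made-up values (note that .mean() over an empty mask yields NaN, e.g. on a fold with no false positives):

import numpy as np

y = np.array([0, 0, 1, 1, 1])                 # true labels
y_pred = np.array([0.2, 0.9, 0.8, 0.4, 0.1])  # raw model scores

threshold = 0.5
false_positives = (y == 0) & ((y_pred > threshold).astype(int) == 1)
false_negatives = (y == 1) & ((y_pred > threshold).astype(int) == 0)

fp_mean = y_pred[false_positives].mean()  # 0.9: the one wrong "active" call was confident
fn_mean = y_pred[false_negatives].mean()  # 0.25: mean score of the missed actives (0.4, 0.1)
print(fp_mean, fn_mean)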
protac_degradation_predictor/pytorch_models.py
CHANGED

@@ -336,21 +336,21 @@ class PROTAC_Model(pl.LightningModule):
         else:
             optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
         # Define LR scheduler
-        if self.trainer.max_epochs:
-            total_iters = self.trainer.max_epochs
-        elif self.trainer.max_steps:
-            total_iters = self.trainer.max_steps
-        else:
-            total_iters = 20
-        lr_scheduler = optim.lr_scheduler.LinearLR(
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
             optimizer=optimizer,
-            total_iters=total_iters,
+            mode='min',
+            factor=0.1,
+            patience=0,
         )
-        # lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        # if self.trainer.max_epochs:
+        #     total_iters = self.trainer.max_epochs
+        # elif self.trainer.max_steps:
+        #     total_iters = self.trainer.max_steps
+        # else:
+        #     total_iters = 20
+        # lr_scheduler = optim.lr_scheduler.LinearLR(
             # optimizer=optimizer,
-            # mode='min',
-            # factor=0.01,
-            # patience=0,
+            # total_iters=total_iters,
             # )
         return {
             'optimizer': optimizer,
@@ -418,6 +418,44 @@
             logging.warning("Scalers not found in checkpoint. Consider re-fitting scalers if necessary.")
 
 
+def get_confidence_scores(true_ds, y_preds, threshold=0.5):
+    # Calculate the likelihood for the false negative: get the mean value of
+    # the prediction for the false-positive and false-negatives
+
+    # Convert PyTorch dataset labels to numpy array
+    if isinstance(true_ds, PROTAC_Dataset):
+        true_vals = np.array([x['active'] for x in true_ds]).flatten()
+    elif isinstance(true_ds, torch.Tensor):
+        true_vals = true_ds.numpy().flatten()
+    elif isinstance(true_ds, np.ndarray):
+        true_vals = true_ds.flatten()
+    else:
+        raise ValueError("Unknown type for true labels.")
+
+    if isinstance(y_preds, torch.Tensor):
+        preds = y_preds.numpy().flatten()
+    elif isinstance(y_preds, np.ndarray):
+        preds = y_preds.flatten()
+    else:
+        raise ValueError("Unknown type for predictions.")
+
+    logging.info(f"True values: {true_vals}")
+    logging.info(f"Predictions: {preds}")
+
+    # Get the indices of the false positives and false negatives
+    false_positives = (true_vals == 0) & ((preds > threshold).astype(int) == 1)
+    false_negatives = (true_vals == 1) & ((preds > threshold).astype(int) == 0)
+
+    logging.info(f"False positives: {false_positives}")
+    logging.info(f"False negatives: {false_negatives}")
+
+    # Get the mean value of the predictions for the false positives and false negatives
+    false_positives_mean = preds[false_positives].mean()
+    false_negatives_mean = preds[false_negatives].mean()
+
+    return false_positives_mean, false_negatives_mean
+
+
 # TODO: Use some sort of **kwargs to pass all the parameters to the model...
 def train_model(
     protein2embedding: Dict[str, np.ndarray],
@@ -448,6 +486,7 @@ def train_model(
     disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
     return_predictions: bool = False,
     shuffle_embedding_prob: float = 0.0,
+    use_smote: bool = False,
 ) -> tuple:
     """ Train a PROTAC model using the given datasets and hyperparameters.
 
@@ -532,7 +571,7 @@
         ),
         pl.callbacks.EarlyStopping(
             monitor='val_acc',
-            patience=10,
+            patience=10, # Original: 10
             mode='max',
             verbose=False,
         ),
@@ -604,9 +643,19 @@
         val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
         val_pred = trainer.predict(model, val_dl)
         val_pred = torch.concat(trainer.predict(model, val_dl)).squeeze()
+
+        fp_mean, fn_mean = get_confidence_scores(val_ds, val_pred)
+        metrics['val_false_positives_mean'] = fp_mean
+        metrics['val_false_negatives_mean'] = fn_mean
+
         if test_df is not None:
             test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
             test_pred = torch.concat(trainer.predict(model, test_dl)).squeeze()
+
+            fp_mean, fn_mean = get_confidence_scores(test_ds, test_pred)
+            metrics['test_false_positives_mean'] = fp_mean
+            metrics['test_false_negatives_mean'] = fn_mean
+
             return model, trainer, metrics, val_pred, test_pred
         return model, trainer, metrics, val_pred
     return model, trainer, metrics
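Besides the confidence scores, this file also swaps the learning-rate schedule: the LinearLR ramp is commented out in favor of ReduceLROnPlateau with patience=0, so the learning rate is multiplied by 0.1 as soon as the monitored quantity fails to improve for a single epoch (two flat epochs in a row already cut it by 100x). In PyTorch Lightning, a plateau scheduler returned from configure_optimizers has to be paired with the metric it watches; the sketch below shows the expected return shape, where the lr_scheduler dict with its monitor key and the 'val_loss' choice are assumptions, since the tail of the return statement is not visible in this diff:

import torch
import torch.optim as optim
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    # Hypothetical minimal module, only meant to show the configure_optimizers shape.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        self.learning_rate = 1e-3

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer,
            mode='min',    # the monitored value should decrease
            factor=0.1,    # new_lr = old_lr * factor on plateau
            patience=0,    # react after the first non-improving epoch
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': lr_scheduler,
                'monitor': 'val_loss',  # assumed metric name; Lightning requires one for plateau schedulers
            },
        }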