import os
from typing import Literal, List, Tuple, Optional, Dict, Any
import logging

from .pytorch_models import (
    train_model,
    PROTAC_Model,
    evaluate_model,
    get_confidence_scores,
)
from .protac_dataset import get_datasets

import torch
import optuna
from optuna.samplers import TPESampler, QMCSampler
import joblib
import pandas as pd
from sklearn.model_selection import (
    StratifiedKFold,
    StratifiedGroupKFold,
)
import numpy as np
import pytorch_lightning as pl
from torchmetrics import (
    Accuracy,
    AUROC,
    Precision,
    Recall,
    F1Score,
)

def get_dataframe_stats(
        train_df: Optional[pd.DataFrame] = None,
        val_df: Optional[pd.DataFrame] = None,
        test_df: Optional[pd.DataFrame] = None,
        active_label: str = 'Active',
) -> Dict:
    """ Get some statistics from the dataframes.

    Args:
        train_df (pd.DataFrame): The training set.
        val_df (pd.DataFrame): The validation set.
        test_df (pd.DataFrame): The test set.
        active_label (str): The column containing the binary activity label.

    Returns:
        Dict: Per-split sizes, class balance, average Tanimoto distance, and
            train/val and train/test leakage statistics (shared Uniprot IDs and SMILES).
    """
    stats = {}
    if train_df is not None:
        stats['train_len'] = len(train_df)
        stats['train_active_perc'] = train_df[active_label].sum() / len(train_df)
        stats['train_inactive_perc'] = (len(train_df) - train_df[active_label].sum()) / len(train_df)
        stats['train_avg_tanimoto_dist'] = train_df['Avg Tanimoto'].mean()
    if val_df is not None:
        stats['val_len'] = len(val_df)
        stats['val_active_perc'] = val_df[active_label].sum() / len(val_df)
        stats['val_inactive_perc'] = (len(val_df) - val_df[active_label].sum()) / len(val_df)
        stats['val_avg_tanimoto_dist'] = val_df['Avg Tanimoto'].mean()
    if test_df is not None:
        stats['test_len'] = len(test_df)
        stats['test_active_perc'] = test_df[active_label].sum() / len(test_df)
        stats['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
        stats['test_avg_tanimoto_dist'] = test_df['Avg Tanimoto'].mean()
    if train_df is not None and val_df is not None:
        leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
        leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
        stats['num_leaking_uniprot_train_val'] = len(leaking_uniprot)
        stats['num_leaking_smiles_train_val'] = len(leaking_smiles)
        stats['perc_leaking_uniprot_train_val'] = len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df)
        stats['perc_leaking_smiles_train_val'] = len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df)
    if train_df is not None and test_df is not None:
        leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(test_df['Uniprot'])))
        leaking_smiles = list(set(train_df['Smiles']).intersection(set(test_df['Smiles'])))
        stats['num_leaking_uniprot_train_test'] = len(leaking_uniprot)
        stats['num_leaking_smiles_train_test'] = len(leaking_smiles)
        stats['perc_leaking_uniprot_train_test'] = len(train_df[train_df['Uniprot'].isin(leaking_uniprot)]) / len(train_df)
        stats['perc_leaking_smiles_train_test'] = len(train_df[train_df['Smiles'].isin(leaking_smiles)]) / len(train_df)
    return stats
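
# Usage sketch (hypothetical dataframes; the 'Active', 'Uniprot', 'Smiles', and
# 'Avg Tanimoto' columns are assumed, as above):
#   stats = get_dataframe_stats(train_df, val_df, test_df, active_label='Active')
#   print(stats['train_active_perc'], stats['num_leaking_smiles_train_val'])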


def get_majority_vote_metrics(
        test_preds: List,
        test_df: pd.DataFrame,
        active_label: str = 'Active',
) -> Dict:
    """ Get the majority vote metrics. """
    test_preds_mean = np.array(test_preds).mean(axis=0)
    logging.info(f'Test predictions: {test_preds}')
    logging.info(f'Test predictions mean: {test_preds_mean}')
    # Element-wise majority vote across models: the mode over the model axis.
    test_preds = torch.stack(test_preds)
    test_preds, _ = torch.mode(test_preds, dim=0)
    y = torch.tensor(test_df[active_label].tolist())

    majority_vote_metrics = {
        'test_acc': Accuracy(task='binary')(test_preds, y).item(),
        'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
        'test_precision': Precision(task='binary')(test_preds, y).item(),
        'test_recall': Recall(task='binary')(test_preds, y).item(),
        'test_f1_score': F1Score(task='binary')(test_preds, y).item(),
    }

    fp_mean, fn_mean = get_confidence_scores(y, test_preds_mean)
    majority_vote_metrics['test_false_negatives_mean'] = fn_mean
    majority_vote_metrics['test_false_positives_mean'] = fp_mean

    return majority_vote_metrics


def get_suggestion(trial, dtype, hparams_range):
    """ Suggest a hyperparameter value from a (dtype, kwargs) range definition. """
    if dtype == 'int':
        return trial.suggest_int(**hparams_range)
    elif dtype == 'float':
        return trial.suggest_float(**hparams_range)
    elif dtype == 'categorical':
        return trial.suggest_categorical(**hparams_range)
    else:
        raise ValueError(f'Invalid dtype for trial.suggest: {dtype}')
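
# Example of a range definition accepted by get_suggestion (hypothetical
# hyperparameter name and bounds; the kwargs are forwarded to the matching
# trial.suggest_* method):
#   get_suggestion(trial, 'float', {'name': 'dropout', 'low': 0.1, 'high': 0.5})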


def pytorch_model_objective(
        trial: optuna.Trial,
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        train_val_df: pd.DataFrame,
        kf: StratifiedKFold | StratifiedGroupKFold,
        groups: Optional[np.ndarray] = None,
        test_df: Optional[pd.DataFrame] = None,
        hparams_ranges: Optional[List[Tuple[str, Dict[str, Any]]]] = None,
        fast_dev_run: bool = False,
        active_label: str = 'Active',
        disabled_embeddings: List[str] = [],
        max_epochs: int = 100,
        use_logger: bool = False,
        logger_save_dir: str = 'logs',
        logger_name: str = 'cv_model',
        enable_checkpointing: bool = False,
) -> float:
    """ Objective function for hyperparameter optimization.

    Args:
        trial (optuna.Trial): The Optuna trial object.
        train_val_df (pd.DataFrame): The combined training and validation set.
        kf (StratifiedKFold | StratifiedGroupKFold): The KFold object used for cross-validation.
        hparams_ranges (List[Tuple[str, Dict[str, Any]]]): NOT IMPLEMENTED YET. Hyperparameter ranges.
            Each entry is a tuple of the type of hyperparameter to suggest ('int', 'float', or
            'categorical') and a dictionary with the arguments of the corresponding trial.suggest method.
        fast_dev_run (bool): Whether to run a fast development run.
        active_label (str): The active label column.
        disabled_embeddings (List[str]): The list of disabled embeddings.

    Returns:
        float: The negative of the average validation ROC AUC over the folds
            (the study is created with direction='minimize').
    """
    # Fixed hyperparameters (not tuned in this study).
    batch_size = 128
    apply_scaling = True
    use_batch_norm = True

    # Hyperparameters suggested by Optuna for this trial.
    hidden_dim = trial.suggest_categorical('hidden_dim', [16, 32, 64, 128, 256, 512])
    smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', [0] + list(range(3, 16)))
    learning_rate = trial.suggest_float('learning_rate', 1e-6, 1e-1, log=True)
    beta1 = trial.suggest_float('beta1', 0.1, 0.999)
    beta2 = trial.suggest_float('beta2', 0.1, 0.999)
    eps = trial.suggest_float('eps', 1e-9, 1.0, log=True)

    X = train_val_df.copy().drop(columns=active_label)
    y = train_val_df[active_label].tolist()
    report = []
    val_preds = []
    test_preds = []
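    # Cross-validation loop: train one model per fold with the trial's
    # hyperparameters and collect per-fold metrics and predictions.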
    for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
        logging.info(f'Fold {k + 1}/{kf.get_n_splits()}')

        train_df = train_val_df.iloc[train_index]
        val_df = train_val_df.iloc[val_index]

        stats = {
            'model_type': 'Pytorch',
            'fold': k,
            'train_len': len(train_df),
            'val_len': len(val_df),
            'train_perc': len(train_df) / len(train_val_df),
            'val_perc': len(val_df) / len(train_val_df),
        }
        stats.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
        if groups is not None:
            stats['train_unique_groups'] = len(np.unique(groups[train_index]))
            stats['val_unique_groups'] = len(np.unique(groups[val_index]))

        ret = train_model(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_df=train_df,
            val_df=val_df,
            test_df=test_df,
            hidden_dim=hidden_dim,
            batch_size=batch_size,
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
            eps=eps,
            use_batch_norm=use_batch_norm,
            max_epochs=max_epochs,
            smote_k_neighbors=smote_k_neighbors,
            apply_scaling=apply_scaling,
            fast_dev_run=fast_dev_run,
            active_label=active_label,
            return_predictions=True,
            disabled_embeddings=disabled_embeddings,
            use_logger=use_logger,
            logger_save_dir=logger_save_dir,
            logger_name=f'{logger_name}_fold{k}',
            enable_checkpointing=enable_checkpointing,
        )
        if test_df is not None:
            _, _, metrics, val_pred, test_pred = ret
            test_preds.append(test_pred)
        else:
            _, _, metrics, val_pred = ret
        stats.update(metrics)
        report.append(stats.copy())
        val_preds.append(val_pred)

    # Store the per-fold report on the trial so it can be retrieved later.
    trial.set_user_attr('report', report)

    if test_df is not None and not fast_dev_run:
        majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
        majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
        trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
        logging.info(f'Majority vote metrics: {majority_vote_metrics}')

    val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
    val_acc = np.mean([r['val_acc'] for r in report])
    logging.info(f'Average val accuracy: {val_acc}')
    logging.info(f'Average val ROC AUC: {val_roc_auc}')

    # The study minimizes its objective, so return the negative ROC AUC.
    return -val_roc_auc


def hyperparameter_tuning_and_training(
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        train_val_df: pd.DataFrame,
        test_df: pd.DataFrame,
        kf: StratifiedKFold | StratifiedGroupKFold,
        groups: Optional[np.ndarray] = None,
        split_type: str = 'standard',
        n_models_for_test: int = 3,
        fast_dev_run: bool = False,
        n_trials: int = 50,
        logger_save_dir: str = 'logs',
        logger_name: str = 'protac_hparam_search',
        active_label: str = 'Active',
        max_epochs: int = 100,
        study_filename: Optional[str] = None,
        force_study: bool = False,
) -> dict:
    """ Hyperparameter tuning and training of a PROTAC model.

    Args:
        protein2embedding (Dict): The protein to embedding dictionary.
        cell2embedding (Dict): The cell to embedding dictionary.
        smiles2fp (Dict): The SMILES to fingerprint dictionary.
        train_val_df (pd.DataFrame): The training and validation set.
        test_df (pd.DataFrame): The test set.
        kf (StratifiedKFold | StratifiedGroupKFold): The KFold object.
        groups (np.ndarray): The groups for the StratifiedGroupKFold.
        split_type (str): The split type of the current study. Used for reporting.
        n_models_for_test (int): The number of models to train for the test set.
        fast_dev_run (bool): Whether to run a fast development run.
        n_trials (int): The number of trials for the hyperparameter search.
        logger_save_dir (str): The logger save directory.
        logger_name (str): The logger name.
        active_label (str): The active label column.
        max_epochs (int): The maximum number of epochs.
        study_filename (str): The filename for loading/saving the Optuna study.
        force_study (bool): Whether to re-run the study even if a saved one is found.

    Returns:
        dict: A dictionary of report dataframes: 'cv_report', 'hparam_report',
            'test_report', 'ablation_report', and (unless fast_dev_run is set)
            'majority_vote_report'.
    """
    pl.seed_everything(42)

    # Custom hyperparameter ranges are not implemented yet.
    hparams_ranges = None

    optuna.logging.set_verbosity(optuna.logging.WARNING)

    sampler = TPESampler(seed=42, multivariate=True)
    study = optuna.create_study(direction='minimize', sampler=sampler)

    # Load a previously saved study, if available, instead of re-running it.
    study_loaded = False
    if study_filename and not force_study:
        if os.path.exists(study_filename):
            study = joblib.load(study_filename)
            study_loaded = True
            logging.info(f'Loaded study from {study_filename}')
            logging.info(f'Study best params: {study.best_params}')

    if not study_loaded or force_study:
        study.optimize(
            lambda trial: pytorch_model_objective(
                trial=trial,
                protein2embedding=protein2embedding,
                cell2embedding=cell2embedding,
                smiles2fp=smiles2fp,
                train_val_df=train_val_df,
                kf=kf,
                groups=groups,
                test_df=test_df,
                hparams_ranges=hparams_ranges,
                fast_dev_run=fast_dev_run,
                active_label=active_label,
                max_epochs=max_epochs,
                disabled_embeddings=[],
            ),
            n_trials=n_trials,
        )
        if study_filename:
            joblib.dump(study, study_filename)

    cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
    hparam_report = pd.DataFrame([study.best_params])
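
    # Re-run the best trial's cross-validation with logging and checkpointing
    # enabled, so that the CV models of the best hyperparameters are saved.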
    pytorch_model_objective(
        trial=study.best_trial,
        protein2embedding=protein2embedding,
        cell2embedding=cell2embedding,
        smiles2fp=smiles2fp,
        train_val_df=train_val_df,
        kf=kf,
        groups=groups,
        test_df=test_df,
        hparams_ranges=hparams_ranges,
        fast_dev_run=fast_dev_run,
        active_label=active_label,
        max_epochs=max_epochs,
        disabled_embeddings=[],
        use_logger=True,
        logger_save_dir=logger_save_dir,
        logger_name=f'cv_model_{logger_name}',
        enable_checkpointing=True,
    )
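
    # Train n_models_for_test final models on the full train+val set with
    # different seeds; each is evaluated on the held-out test set (passed as
    # the validation split of train_model).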
    best_models = []
    test_report = []
    test_preds = []
    dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
    for i in range(n_models_for_test):
        pl.seed_everything(42 + i + 1)
        model, trainer, metrics, test_pred = train_model(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_df=train_val_df,
            val_df=test_df,
            fast_dev_run=fast_dev_run,
            active_label=active_label,
            max_epochs=max_epochs,
            disabled_embeddings=[],
            use_logger=True,
            logger_save_dir=logger_save_dir,
            logger_name=f'best_model_n{i}_{logger_name}',
            enable_checkpointing=True,
            checkpoint_model_name=f'best_model_n{i}_{split_type}',
            return_predictions=True,
            batch_size=128,
            apply_scaling=True,
            **study.best_params,
        )

        # The test set is used as the validation split here, so rename the
        # reported metrics from 'val_*' to 'test_*'.
        metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
        metrics['model_type'] = 'Pytorch'
        metrics['test_model_id'] = i
        metrics.update(dfs_stats)

        test_report.append(metrics.copy())
        test_preds.append(test_pred)
        best_models.append({'model': model, 'trainer': trainer})
    test_report = pd.DataFrame(test_report)

    if not fast_dev_run:
        majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
        majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
        majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
        majority_vote_metrics_cv['cv_models'] = True
        majority_vote_report = pd.DataFrame([
            majority_vote_metrics,
            majority_vote_metrics_cv,
        ])
        majority_vote_report['model_type'] = 'Pytorch'
        majority_vote_report['split_type'] = split_type
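
    # Ablation study: re-evaluate the trained final models on the test set
    # while disabling different combinations of input embeddings.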
    ablation_report = []
    dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
    disabled_embeddings_combinations = [
        ['e3'],
        ['poi'],
        ['cell'],
        ['smiles'],
        ['e3', 'cell'],
        ['poi', 'e3'],
        ['poi', 'e3', 'cell'],
    ]
    for disabled_embeddings in disabled_embeddings_combinations:
        logging.info('-' * 100)
        logging.info(f'Ablation study with disabled embeddings: {disabled_embeddings}')
        logging.info('-' * 100)
        disabled_embeddings_str = 'disabled ' + ' '.join(disabled_embeddings)
        test_preds = []
        for i, model_trainer in enumerate(best_models):
            logging.info(f'Evaluating model n.{i} on {disabled_embeddings_str}.')
            model = model_trainer['model']
            trainer = model_trainer['trainer']
            _, test_ds, _ = get_datasets(
                protein2embedding=protein2embedding,
                cell2embedding=cell2embedding,
                smiles2fp=smiles2fp,
                train_df=train_val_df,
                val_df=test_df,
                disabled_embeddings=disabled_embeddings,
                active_label=active_label,
                scaler=model.scalers,
                use_single_scaler=model.join_embeddings == 'beginning',
            )
            ret = evaluate_model(model, trainer, test_ds, batch_size=128)

            test_preds.append(ret['val_pred'])
            ret['val_metrics'] = {k.replace('val_', 'test_'): v for k, v in ret['val_metrics'].items()}
            ret['val_metrics'].update(dfs_stats)
            ret['val_metrics']['majority_vote'] = False
            ret['val_metrics']['model_type'] = 'Pytorch'
            ret['val_metrics']['disabled_embeddings'] = disabled_embeddings_str
            ablation_report.append(ret['val_metrics'].copy())

        if not fast_dev_run:
            majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
            majority_vote_metrics.update(dfs_stats)
            majority_vote_metrics['majority_vote'] = True
            majority_vote_metrics['model_type'] = 'Pytorch'
            majority_vote_metrics['disabled_embeddings'] = disabled_embeddings_str
            ablation_report.append(majority_vote_metrics.copy())

    ablation_report = pd.DataFrame(ablation_report)

    # Tag all reports with the split type used for this run.
    for report in [cv_report, hparam_report, test_report, ablation_report]:
        report['split_type'] = split_type

    ret = {
        'cv_report': cv_report,
        'hparam_report': hparam_report,
        'test_report': test_report,
        'ablation_report': ablation_report,
    }
    if not fast_dev_run:
        ret['majority_vote_report'] = majority_vote_report
    return ret
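

# Hypothetical usage sketch (the embedding dictionaries, dataframes, and the
# choice of cross-validator are placeholders to be supplied by the caller):
#
#   from sklearn.model_selection import StratifiedKFold
#
#   kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
#   reports = hyperparameter_tuning_and_training(
#       protein2embedding=protein2embedding,
#       cell2embedding=cell2embedding,
#       smiles2fp=smiles2fp,
#       train_val_df=train_val_df,
#       test_df=test_df,
#       kf=kf,
#       split_type='standard',
#       n_trials=50,
#       study_filename='protac_study.pkl',
#   )
#   reports['cv_report'].to_csv('cv_report.csv', index=False)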