|
import warnings |
|
import pickle |
|
import logging |
|
from typing import Literal, List, Tuple, Optional, Dict |
|
|
|
from .protac_dataset import PROTAC_Dataset, get_datasets |
|
from .config import config |
|
|
|
import pandas as pd |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
import pytorch_lightning as pl |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchmetrics import ( |
|
Accuracy, |
|
AUROC, |
|
Precision, |
|
Recall, |
|
F1Score, |
|
MetricCollection, |
|
) |
|
from imblearn.over_sampling import SMOTE |
|
from sklearn.preprocessing import StandardScaler |
|
|
|
|
|
class PROTAC_Predictor(nn.Module): |
|
|
|
def __init__( |
|
self, |
|
hidden_dim: int, |
|
smiles_emb_dim: int = config.fingerprint_size, |
|
poi_emb_dim: int = config.protein_embedding_size, |
|
e3_emb_dim: int = config.protein_embedding_size, |
|
cell_emb_dim: int = config.cell_embedding_size, |
|
dropout: float = 0.2, |
|
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum', |
|
use_batch_norm: bool = False, |
|
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [], |
|
): |
|
""" Initialize the PROTAC model. |
|
|
|
Args: |
|
hidden_dim (int): The hidden dimension of the model |
|
smiles_emb_dim (int): The dimension of the SMILES embeddings |
|
poi_emb_dim (int): The dimension of the POI embeddings |
|
e3_emb_dim (int): The dimension of the E3 Ligase embeddings |
|
cell_emb_dim (int): The dimension of the cell line embeddings |
|
dropout (float): The dropout rate |
|
join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings |
|
            use_batch_norm (bool): Whether to use batch normalization
            disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
|
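        Example (illustrative sketch, using the default embedding sizes from `config`;
        the batch size and hidden dimension below are arbitrary):

            predictor = PROTAC_Predictor(hidden_dim=768)
            poi = torch.randn(4, config.protein_embedding_size)
            e3 = torch.randn(4, config.protein_embedding_size)
            cell = torch.randn(4, config.cell_embedding_size)
            smiles = torch.randn(4, config.fingerprint_size)
            logits = predictor(poi, e3, cell, smiles)  # shape: (4, 1), unnormalized logits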
""" |
|
super().__init__() |
|
|
|
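        # Store every constructor argument as an instance attribute
        # (e.g. self.hidden_dim, self.join_embeddings, self.disabled_embeddings).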
self.__dict__.update(locals()) |
|
|
|
|
|
|
|
|
|
if self.join_embeddings != 'beginning': |
|
if 'poi' not in self.disabled_embeddings: |
|
self.poi_fc = nn.Sequential( |
|
nn.Linear(poi_emb_dim, hidden_dim), |
|
nn.Softmax(dim=1), |
|
) |
|
if 'e3' not in self.disabled_embeddings: |
|
self.e3_fc = nn.Sequential( |
|
nn.Linear(e3_emb_dim, hidden_dim), |
|
nn.Softmax(dim=1), |
|
) |
|
if 'cell' not in self.disabled_embeddings: |
|
self.cell_fc = nn.Sequential( |
|
nn.Linear(cell_emb_dim, hidden_dim), |
|
nn.Softmax(dim=1), |
|
) |
|
if 'smiles' not in self.disabled_embeddings: |
|
self.smiles_emb = nn.Sequential( |
|
nn.Linear(smiles_emb_dim, hidden_dim), |
|
nn.Softmax(dim=1), |
|
) |
|
|
|
|
|
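        # How the input embeddings are combined into a single joint representation:
        #   'beginning': concatenate the raw inputs and project them with fc0;
        #   'concat':    concatenate the per-branch hidden representations;
        #   'sum':       element-wise sum of the per-branch hidden representations.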
if self.join_embeddings == 'beginning': |
|
joint_dim = smiles_emb_dim if 'smiles' not in self.disabled_embeddings else 0 |
|
joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0 |
|
joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0 |
|
joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0 |
|
self.fc0 = nn.Linear(joint_dim, joint_dim) |
|
elif self.join_embeddings == 'concat': |
|
joint_dim = hidden_dim * (4 - len(self.disabled_embeddings)) |
|
elif self.join_embeddings == 'sum': |
|
joint_dim = hidden_dim |
|
|
|
self.fc1 = nn.Linear(joint_dim, hidden_dim) |
|
self.fc2 = nn.Linear(hidden_dim, hidden_dim) |
|
self.fc3 = nn.Linear(hidden_dim, 1) |
|
|
|
self.bnorm = nn.BatchNorm1d(hidden_dim) |
|
self.dropout = nn.Dropout(p=dropout) |
|
|
|
|
|
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, return_embeddings=False): |
|
embeddings = [] |
|
if self.join_embeddings == 'beginning': |
|
|
|
if 'poi' not in self.disabled_embeddings: |
|
embeddings.append(poi_emb) |
|
if 'e3' not in self.disabled_embeddings: |
|
embeddings.append(e3_emb) |
|
if 'cell' not in self.disabled_embeddings: |
|
embeddings.append(cell_emb) |
|
if 'smiles' not in self.disabled_embeddings: |
|
embeddings.append(smiles_emb) |
|
x = torch.cat(embeddings, dim=1) |
|
x = self.dropout(F.relu(self.fc0(x))) |
|
else: |
|
if 'poi' not in self.disabled_embeddings: |
|
embeddings.append(self.poi_fc(poi_emb)) |
|
if torch.isnan(embeddings[-1]).any(): |
|
raise ValueError("NaN values found in POI embeddings.") |
|
|
|
if 'e3' not in self.disabled_embeddings: |
|
embeddings.append(self.e3_fc(e3_emb)) |
|
if torch.isnan(embeddings[-1]).any(): |
|
raise ValueError("NaN values found in E3 embeddings.") |
|
|
|
if 'cell' not in self.disabled_embeddings: |
|
embeddings.append(self.cell_fc(cell_emb)) |
|
if torch.isnan(embeddings[-1]).any(): |
|
raise ValueError("NaN values found in cell embeddings.") |
|
|
|
if 'smiles' not in self.disabled_embeddings: |
|
embeddings.append(self.smiles_emb(smiles_emb)) |
|
if torch.isnan(embeddings[-1]).any(): |
|
raise ValueError("NaN values found in SMILES embeddings.") |
|
|
|
if self.join_embeddings == 'concat': |
|
x = torch.cat(embeddings, dim=1) |
|
elif self.join_embeddings == 'sum': |
|
if len(embeddings) > 1: |
|
embeddings = torch.stack(embeddings, dim=1) |
|
x = torch.sum(embeddings, dim=1) |
|
else: |
|
x = embeddings[0] |
|
if torch.isnan(x).any(): |
|
raise ValueError("NaN values found in sum of softmax-ed embeddings.") |
|
x = F.relu(self.fc1(x)) |
|
        h = self.bnorm(x) if self.use_batch_norm else self.dropout(x)
|
x = self.fc3(h) |
|
if return_embeddings: |
|
return x, h |
|
return x |
|
|
|
|
|
class PROTAC_Model(pl.LightningModule): |
|
|
|
def __init__( |
|
self, |
|
hidden_dim: int, |
|
smiles_emb_dim: int = config.fingerprint_size, |
|
poi_emb_dim: int = config.protein_embedding_size, |
|
e3_emb_dim: int = config.protein_embedding_size, |
|
cell_emb_dim: int = config.cell_embedding_size, |
|
batch_size: int = 128, |
|
learning_rate: float = 1e-3, |
|
dropout: float = 0.2, |
|
use_batch_norm: bool = False, |
|
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum', |
|
train_dataset: PROTAC_Dataset = None, |
|
val_dataset: PROTAC_Dataset = None, |
|
test_dataset: PROTAC_Dataset = None, |
|
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [], |
|
apply_scaling: bool = True, |
|
extra_optim_params: Optional[dict] = None, |
|
): |
|
""" Initialize the PROTAC Pytorch Lightning model. |
|
|
|
Args: |
|
hidden_dim (int): The hidden dimension of the model |
|
smiles_emb_dim (int): The dimension of the SMILES embeddings |
|
poi_emb_dim (int): The dimension of the POI embeddings |
|
e3_emb_dim (int): The dimension of the E3 Ligase embeddings |
|
cell_emb_dim (int): The dimension of the cell line embeddings |
|
batch_size (int): The batch size |
|
learning_rate (float): The learning rate |
|
dropout (float): The dropout rate |
|
            use_batch_norm (bool): Whether to use batch normalization
            join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings
|
train_dataset (PROTAC_Dataset): The training dataset |
|
val_dataset (PROTAC_Dataset): The validation dataset |
|
test_dataset (PROTAC_Dataset): The test dataset |
|
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles' |
|
apply_scaling (bool): Whether to apply scaling to the embeddings |
|
extra_optim_params (dict): Extra parameters for the optimizer |
|
""" |
|
super().__init__() |
|
|
|
self.__dict__.update(locals()) |
|
|
|
ignore_args_as_hyperparams = [ |
|
'train_dataset', |
|
'test_dataset', |
|
'val_dataset', |
|
] |
|
self.save_hyperparameters(ignore=ignore_args_as_hyperparams) |
|
|
|
self.model = PROTAC_Predictor( |
|
hidden_dim=hidden_dim, |
|
smiles_emb_dim=smiles_emb_dim, |
|
poi_emb_dim=poi_emb_dim, |
|
e3_emb_dim=e3_emb_dim, |
|
cell_emb_dim=cell_emb_dim, |
|
dropout=dropout, |
|
join_embeddings=join_embeddings, |
|
use_batch_norm=use_batch_norm, |
|
disabled_embeddings=[], |
|
) |
|
|
|
stages = ['train_metrics', 'val_metrics', 'test_metrics'] |
|
self.metrics = nn.ModuleDict({s: MetricCollection({ |
|
'acc': Accuracy(task='binary'), |
|
'roc_auc': AUROC(task='binary'), |
|
'precision': Precision(task='binary'), |
|
'recall': Recall(task='binary'), |
|
'f1_score': F1Score(task='binary'), |
|
}, prefix=s.replace('metrics', '')) for s in stages}) |
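        # The prefixes yield metric names such as 'train_acc', 'val_roc_auc' and
        # 'test_f1_score', which are logged in `step()` and monitored by the
        # callbacks set up in `train_model()`.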
|
|
|
|
|
self.missing_dataset_error = \ |
|
'''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually: |
|
|
|
model = {1}.load_from_checkpoint('checkpoint.ckpt') |
|
model.{0} = my_{0} |
|
''' |
|
|
|
|
|
self.scalers = None |
|
if self.apply_scaling and self.train_dataset is not None: |
|
self.initialize_scalers() |
|
|
|
def initialize_scalers(self): |
|
"""Initialize or reinitialize scalers based on dataset properties.""" |
|
if self.scalers is None: |
|
use_single_scaler = self.join_embeddings == 'beginning' |
|
self.scalers = self.train_dataset.fit_scaling(use_single_scaler) |
|
self.apply_scalers() |
|
|
|
def apply_scalers(self): |
|
"""Apply scalers to all datasets.""" |
|
use_single_scaler = self.join_embeddings == 'beginning' |
|
if self.train_dataset: |
|
self.train_dataset.apply_scaling(self.scalers, use_single_scaler) |
|
if self.val_dataset: |
|
self.val_dataset.apply_scaling(self.scalers, use_single_scaler) |
|
if self.test_dataset: |
|
self.test_dataset.apply_scaling(self.scalers, use_single_scaler) |
|
|
|
def scale_tensor( |
|
self, |
|
tensor: torch.Tensor, |
|
scaler: StandardScaler, |
|
alpha: float = 1e-10, |
|
) -> torch.Tensor: |
|
"""Scale a tensor using a scaler. This is done to avoid using numpy |
|
arrays (and stay on the same device). |
|
|
|
Args: |
|
tensor (torch.Tensor): The tensor to scale. |
|
            scaler (StandardScaler): The scaler to use.
            alpha (float): Small constant added to the scale to avoid division by zero.
|
|
|
Returns: |
|
torch.Tensor: The scaled tensor. |
|
""" |
|
tensor = tensor.float() |
|
if scaler.with_mean: |
|
tensor -= torch.tensor(scaler.mean_, dtype=tensor.dtype, device=tensor.device) |
|
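        # `alpha` guards against division by zero for features with (near-)zero variance.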
if scaler.with_std: |
|
tensor /= torch.tensor(scaler.scale_, dtype=tensor.dtype, device=tensor.device) + alpha |
|
return tensor |
|
|
|
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb, prescaled_embeddings=True, return_embeddings=False): |
|
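        # Embeddings served by PROTAC_Dataset are already scaled via `apply_scalers()`.
        # Raw embeddings (e.g. at inference time) must be passed with
        # `prescaled_embeddings=False` so they are scaled here first.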
if not prescaled_embeddings: |
|
if self.apply_scaling: |
|
if self.join_embeddings == 'beginning': |
|
embeddings = self.scale_tensor( |
|
torch.hstack([smiles_emb, poi_emb, e3_emb, cell_emb]), |
|
self.scalers, |
|
) |
|
                    # Split the jointly scaled tensor back into its per-input blocks:
                    # [ SMILES | POI | E3 | cell line ]
                    smiles_emb = embeddings[:, :self.smiles_emb_dim]
                    poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim + self.poi_emb_dim]
                    e3_emb = embeddings[:, self.smiles_emb_dim + self.poi_emb_dim:self.smiles_emb_dim + self.poi_emb_dim + self.e3_emb_dim]
                    cell_emb = embeddings[:, -self.cell_emb_dim:]
|
else: |
|
poi_emb = self.scale_tensor(poi_emb, self.scalers['Uniprot']) |
|
e3_emb = self.scale_tensor(e3_emb, self.scalers['E3 Ligase Uniprot']) |
|
cell_emb = self.scale_tensor(cell_emb, self.scalers['Cell Line Identifier']) |
|
smiles_emb = self.scale_tensor(smiles_emb, self.scalers['Smiles']) |
|
if torch.isnan(poi_emb).any(): |
|
raise ValueError("NaN values found in POI embeddings.") |
|
if torch.isnan(e3_emb).any(): |
|
raise ValueError("NaN values found in E3 embeddings.") |
|
if torch.isnan(cell_emb).any(): |
|
raise ValueError("NaN values found in cell embeddings.") |
|
if torch.isnan(smiles_emb).any(): |
|
raise ValueError("NaN values found in SMILES embeddings.") |
|
return self.model(poi_emb, e3_emb, cell_emb, smiles_emb, return_embeddings) |
|
|
|
def step(self, batch, batch_idx, stage): |
|
poi_emb = batch['poi_emb'] |
|
e3_emb = batch['e3_emb'] |
|
cell_emb = batch['cell_emb'] |
|
smiles_emb = batch['smiles_emb'] |
|
y = batch['active'].float().unsqueeze(1) |
|
|
|
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb) |
|
loss = F.binary_cross_entropy_with_logits(y_hat, y) |
|
|
|
self.metrics[f'{stage}_metrics'].update(y_hat, y) |
|
self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True) |
|
self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True) |
|
|
|
return loss |
|
|
|
def training_step(self, batch, batch_idx): |
|
return self.step(batch, batch_idx, 'train') |
|
|
|
def validation_step(self, batch, batch_idx): |
|
return self.step(batch, batch_idx, 'val') |
|
|
|
def test_step(self, batch, batch_idx): |
|
return self.step(batch, batch_idx, 'test') |
|
|
|
def configure_optimizers(self): |
|
|
|
if self.extra_optim_params is not None: |
|
optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate, **self.extra_optim_params) |
|
else: |
|
optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate) |
|
|
|
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( |
|
optimizer=optimizer, |
|
mode='min', |
|
factor=0.1, |
|
patience=0, |
|
) |
|
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': lr_scheduler,
                # ReduceLROnPlateau needs the monitored metric. 'val_loss' is only
                # logged once per epoch (see `step()`), so the scheduler is stepped
                # at epoch level.
                'monitor': 'val_loss',
                'interval': 'epoch',
                'frequency': 1,
            },
        }
|
|
|
def predict_step(self, batch, batch_idx): |
|
poi_emb = batch['poi_emb'] |
|
e3_emb = batch['e3_emb'] |
|
cell_emb = batch['cell_emb'] |
|
smiles_emb = batch['smiles_emb'] |
|
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb) |
|
return torch.sigmoid(y_hat) |
|
|
|
def train_dataloader(self): |
|
if self.train_dataset is None: |
|
format = 'train_dataset', self.__class__.__name__ |
|
raise ValueError(self.missing_dataset_error.format(*format)) |
|
|
|
return DataLoader( |
|
self.train_dataset, |
|
batch_size=self.batch_size, |
|
shuffle=True, |
|
|
|
) |
|
|
|
def val_dataloader(self): |
|
if self.val_dataset is None: |
|
format = 'val_dataset', self.__class__.__name__ |
|
raise ValueError(self.missing_dataset_error.format(*format)) |
|
return DataLoader( |
|
self.val_dataset, |
|
batch_size=self.batch_size, |
|
shuffle=False, |
|
) |
|
|
|
def test_dataloader(self): |
|
if self.test_dataset is None: |
|
format = 'test_dataset', self.__class__.__name__ |
|
raise ValueError(self.missing_dataset_error.format(*format)) |
|
return DataLoader( |
|
self.test_dataset, |
|
batch_size=self.batch_size, |
|
shuffle=False, |
|
) |
|
|
|
def on_save_checkpoint(self, checkpoint): |
|
""" Serialize the scalers to the checkpoint. """ |
|
checkpoint['scalers'] = pickle.dumps(self.scalers) |
|
|
|
def on_load_checkpoint(self, checkpoint): |
|
"""Deserialize the scalers from the checkpoint.""" |
|
if 'scalers' in checkpoint: |
|
self.scalers = pickle.loads(checkpoint['scalers']) |
|
else: |
|
self.scalers = None |
|
if self.apply_scaling: |
|
if self.scalers is not None: |
|
|
|
self.apply_scalers() |
|
else: |
|
logging.warning("Scalers not found in checkpoint. Consider re-fitting scalers if necessary.") |
|
|
|
|
|
def get_confidence_scores( |
|
true_ds: PROTAC_Dataset | torch.Tensor | np.ndarray, |
|
y_preds: torch.Tensor | np.ndarray, |
|
threshold: float = 0.5, |
|
) -> Tuple[float, float]: |
|
""" Get the mean value of the predictions for the false positives and false negatives. |
|
|
|
Args: |
|
true_ds (PROTAC_Dataset | torch.Tensor | np.ndarray): The true labels |
|
y_preds (torch.Tensor | np.ndarray): The predictions |
|
threshold (float): The threshold to use for the predictions |
|
|
|
Returns: |
|
Tuple[float, float]: The mean value of the predictions for the false positives and false negatives. |
|
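    Example (illustrative, with plain numpy arrays; values chosen arbitrarily):

        true_vals = np.array([0, 1, 0, 1])
        preds = np.array([0.8, 0.3, 0.2, 0.9])
        fp_mean, fn_mean = get_confidence_scores(true_vals, preds)
        # fp_mean == 0.8 (mean prediction of the false positives)
        # fn_mean == 0.3 (mean prediction of the false negatives)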
""" |
|
|
|
|
|
if isinstance(true_ds, PROTAC_Dataset): |
|
true_vals = np.array([x['active'] for x in true_ds]).flatten() |
|
elif isinstance(true_ds, torch.Tensor): |
|
        true_vals = true_ds.detach().cpu().numpy().flatten()
|
elif isinstance(true_ds, np.ndarray): |
|
true_vals = true_ds.flatten() |
|
else: |
|
raise ValueError("Unknown type for true labels.") |
|
|
|
if isinstance(y_preds, torch.Tensor): |
|
        preds = y_preds.detach().cpu().numpy().flatten()
|
elif isinstance(y_preds, np.ndarray): |
|
preds = y_preds.flatten() |
|
else: |
|
raise ValueError("Unknown type for predictions.") |
|
|
|
|
|
false_positives = (true_vals == 0) & ((preds > threshold).astype(int) == 1) |
|
false_negatives = (true_vals == 1) & ((preds > threshold).astype(int) == 0) |
|
|
|
|
|
false_positives_mean = preds[false_positives].mean() |
|
false_negatives_mean = preds[false_negatives].mean() |
|
|
|
return false_positives_mean, false_negatives_mean |
|
|
|
|
|
|
|
def train_model( |
|
protein2embedding: Dict[str, np.ndarray], |
|
cell2embedding: Dict[str, np.ndarray], |
|
smiles2fp: Dict[str, np.ndarray], |
|
train_df: pd.DataFrame, |
|
val_df: pd.DataFrame, |
|
test_df: Optional[pd.DataFrame] = None, |
|
hidden_dim: int = 768, |
|
batch_size: int = 128, |
|
learning_rate: float = 2e-5, |
|
beta1: float = 0.9, |
|
beta2: float = 0.999, |
|
eps: float = 1e-8, |
|
dropout: float = 0.2, |
|
max_epochs: int = 50, |
|
use_batch_norm: bool = False, |
|
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum', |
|
smote_k_neighbors: int = 5, |
|
apply_scaling: bool = True, |
|
active_label: str = 'Active', |
|
fast_dev_run: bool = False, |
|
use_logger: bool = True, |
|
logger_save_dir: str = '../logs', |
|
logger_name: str = 'protac', |
|
enable_checkpointing: bool = False, |
|
checkpoint_model_name: str = 'protac', |
|
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [], |
|
return_predictions: bool = False, |
|
shuffle_embedding_prob: float = 0.0, |
|
use_smote: bool = False, |
|
) -> tuple: |
|
""" Train a PROTAC model using the given datasets and hyperparameters. |
|
|
|
Args: |
|
protein2embedding (dict): A dictionary mapping protein identifiers to embeddings. |
|
cell2embedding (dict): A dictionary mapping cell line identifiers to embeddings. |
|
smiles2fp (dict): A dictionary mapping SMILES strings to fingerprints. |
|
train_df (pd.DataFrame): The training dataframe. |
|
val_df (pd.DataFrame): The validation dataframe. |
|
test_df (Optional[pd.DataFrame]): The test dataframe. |
|
hidden_dim (int): The hidden dimension of the model |
|
batch_size (int): The batch size |
|
        learning_rate (float): The learning rate
        beta1 (float): The beta1 coefficient of the AdamW optimizer
        beta2 (float): The beta2 coefficient of the AdamW optimizer
        eps (float): The eps term of the AdamW optimizer
|
dropout (float): The dropout rate |
|
max_epochs (int): The maximum number of epochs |
|
use_batch_norm (bool): Whether to use batch normalization |
|
join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings |
|
smote_k_neighbors (int): The number of neighbors to use in SMOTE |
|
use_smote (bool): Whether to use SMOTE |
|
apply_scaling (bool): Whether to apply scaling to the embeddings |
|
active_label (str): The name of the active label. Default: 'Active' |
|
fast_dev_run (bool): Whether to run a fast development run (see PyTorch Lightning documentation) |
|
use_logger (bool): Whether to use a logger |
|
logger_save_dir (str): The directory to save the logs |
|
logger_name (str): The name of the logger |
|
enable_checkpointing (bool): Whether to enable checkpointing |
|
checkpoint_model_name (str): The name of the model for checkpointing |
|
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles' |
|
        return_predictions (bool): Whether to return predictions on the validation and test sets
        shuffle_embedding_prob (float): Probability of shuffling embeddings, forwarded to `get_datasets`
|
|
|
Returns: |
|
tuple: The trained model, the trainer, and the metrics over the validation and test sets. |
|
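    Example (illustrative sketch; the embedding dictionaries and dataframes are
    assumed to come from the package's data-loading utilities):

        model, trainer, metrics = train_model(
            protein2embedding,
            cell2embedding,
            smiles2fp,
            train_df,
            val_df,
            fast_dev_run=True,  # single-batch smoke test
        )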
""" |
|
train_ds, val_ds, test_ds = get_datasets( |
|
train_df, |
|
val_df, |
|
test_df, |
|
protein2embedding, |
|
cell2embedding, |
|
smiles2fp, |
|
smote_k_neighbors=smote_k_neighbors, |
|
active_label=active_label, |
|
disabled_embeddings=disabled_embeddings, |
|
shuffle_embedding_prob=shuffle_embedding_prob, |
|
) |
|
|
|
smiles_emb_dim = train_ds.get_smiles_emb_dim() |
|
poi_emb_dim = train_ds.get_protein_emb_dim() |
|
e3_emb_dim = train_ds.get_protein_emb_dim() |
|
cell_emb_dim = train_ds.get_cell_emb_dim() |
|
|
|
loggers = [ |
|
pl.loggers.TensorBoardLogger( |
|
save_dir=logger_save_dir, |
|
version=logger_name, |
|
name=logger_name, |
|
), |
|
pl.loggers.CSVLogger( |
|
save_dir=logger_save_dir, |
|
version=logger_name, |
|
name=logger_name, |
|
), |
|
] |
|
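    # Several early-stopping criteria are stacked: training stops as soon as any
    # one of them triggers (Lightning stops when any callback requests it).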
callbacks = [ |
|
pl.callbacks.EarlyStopping( |
|
monitor='train_loss', |
|
patience=10, |
|
mode='min', |
|
verbose=False, |
|
), |
|
pl.callbacks.EarlyStopping( |
|
monitor='train_acc', |
|
patience=10, |
|
mode='max', |
|
verbose=False, |
|
), |
|
pl.callbacks.EarlyStopping( |
|
monitor='val_loss', |
|
patience=5, |
|
mode='min', |
|
verbose=False, |
|
), |
|
pl.callbacks.EarlyStopping( |
|
monitor='val_acc', |
|
patience=10, |
|
mode='max', |
|
verbose=False, |
|
), |
|
] |
|
if use_logger: |
|
callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='step')) |
|
if enable_checkpointing: |
|
callbacks.append(pl.callbacks.ModelCheckpoint( |
|
monitor='val_acc', |
|
mode='max', |
|
verbose=False, |
|
filename=checkpoint_model_name + '-{epoch}-{val_acc:.2f}-{val_roc_auc:.3f}', |
|
)) |
|
|
|
trainer = pl.Trainer( |
|
logger=loggers if use_logger else False, |
|
callbacks=callbacks, |
|
max_epochs=max_epochs, |
|
|
|
fast_dev_run=fast_dev_run, |
|
enable_model_summary=False, |
|
enable_checkpointing=enable_checkpointing, |
|
enable_progress_bar=False, |
|
devices=1, |
|
num_nodes=1, |
|
) |
|
extra_optim_params = { |
|
'betas': (beta1, beta2), |
|
'eps': eps, |
|
} |
|
model = PROTAC_Model( |
|
hidden_dim=hidden_dim, |
|
smiles_emb_dim=smiles_emb_dim, |
|
poi_emb_dim=poi_emb_dim, |
|
e3_emb_dim=e3_emb_dim, |
|
cell_emb_dim=cell_emb_dim, |
|
batch_size=batch_size, |
|
join_embeddings=join_embeddings, |
|
dropout=dropout, |
|
use_batch_norm=use_batch_norm, |
|
learning_rate=learning_rate, |
|
apply_scaling=apply_scaling, |
|
train_dataset=train_ds, |
|
val_dataset=val_ds, |
|
test_dataset=test_ds if test_df is not None else None, |
|
disabled_embeddings=disabled_embeddings, |
|
extra_optim_params=extra_optim_params, |
|
) |
|
with warnings.catch_warnings(): |
|
warnings.simplefilter("ignore") |
|
trainer.fit(model) |
|
metrics = {} |
|
|
|
train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m} |
|
metrics.update(train_metrics) |
|
|
|
val_metrics = trainer.validate(model, verbose=False)[0] |
|
val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m} |
|
metrics.update(val_metrics) |
|
|
|
|
|
if test_df is not None: |
|
test_metrics = trainer.test(model, verbose=False)[0] |
|
test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m} |
|
metrics.update(test_metrics) |
|
|
|
|
|
if return_predictions: |
|
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False) |
|
        val_pred = torch.cat(trainer.predict(model, val_dl)).squeeze()
|
|
|
fp_mean, fn_mean = get_confidence_scores(val_ds, val_pred) |
|
metrics['val_false_positives_mean'] = fp_mean |
|
metrics['val_false_negatives_mean'] = fn_mean |
|
|
|
if test_df is not None: |
|
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False) |
|
test_pred = torch.concat(trainer.predict(model, test_dl)).squeeze() |
|
|
|
fp_mean, fn_mean = get_confidence_scores(test_ds, test_pred) |
|
metrics['test_false_positives_mean'] = fp_mean |
|
metrics['test_false_negatives_mean'] = fn_mean |
|
|
|
return model, trainer, metrics, val_pred, test_pred |
|
return model, trainer, metrics, val_pred |
|
return model, trainer, metrics |
|
|
|
|
|
def evaluate_model( |
|
model: PROTAC_Model, |
|
trainer: pl.Trainer, |
|
val_ds: PROTAC_Dataset, |
|
test_ds: Optional[PROTAC_Dataset] = None, |
|
batch_size: int = 128, |
|
) -> tuple: |
|
""" Evaluate a PROTAC model using the given datasets. """ |
|
ret = {} |
|
|
|
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False) |
|
val_metrics = trainer.validate(model, val_dl, verbose=False)[0] |
|
val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m} |
|
|
|
val_pred = torch.cat(trainer.predict(model, val_dl)).squeeze() |
|
ret['val_metrics'] = val_metrics |
|
ret['val_pred'] = val_pred |
|
|
|
if test_ds is not None: |
|
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False) |
|
test_metrics = trainer.test(model, test_dl, verbose=False)[0] |
|
test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m} |
|
|
|
test_pred = torch.cat(trainer.predict(model, test_dl)).squeeze() |
|
ret['test_metrics'] = test_metrics |
|
ret['test_pred'] = test_pred |
|
|
|
return ret |
|
|
|
|
|
def load_model( |
|
ckpt_path: str, |
|
) -> PROTAC_Model: |
|
""" Load a PROTAC model from a checkpoint. |
|
|
|
Args: |
|
ckpt_path (str): The path to the checkpoint. |
|
|
|
Returns: |
|
PROTAC_Model: The loaded model. |
|
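    Example (illustrative; `my_val_dataset` is a hypothetical PROTAC_Dataset that
    must be re-attached manually before using the dataloaders, as described in
    `missing_dataset_error`):

        model = load_model('checkpoint.ckpt')
        model.val_dataset = my_val_dataset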
""" |
|
|
|
|
|
model = PROTAC_Model.load_from_checkpoint( |
|
ckpt_path, |
|
map_location=torch.device('cpu') if not torch.cuda.is_available() else None, |
|
) |
|
|
|
|
|
|
|
|
|
return model.eval() |