| |
| """ |
| This module defines PyTorch Lightning modules for the Tahoeformer project. |
| It includes a base model class (`LitBaseModel`) and the main experimental model |
| (`LitEnformerSMILES`) which combines an Enformer-based DNA sequence model with |
| drug information (SMILES string processed into Morgan Fingerprints) and dose information |
| to predict gene expression. |
| |
| Key components: |
| - masked_mse: A utility loss function for Mean Squared Error that handles NaN targets. |
| - LitBaseModel: A base LightningModule providing common training, validation, test steps, |
| optimizer configuration, and basic metric logging hooks. |
| - LitEnformerSMILES: The primary model for predicting drug-induced gene expression changes, |
| using Enformer for DNA and Morgan fingerprints for drugs. |
| - MetricLogger: A PyTorch Lightning Callback for detailed logging of predictions. |
| """ |
|
|
| import pandas as pd |
| import os |
| import torch |
| import torch.nn as nn |
| import lightning.pytorch as pl |
| from enformer_pytorch.finetune import HeadAdapterWrapper |
| from enformer_pytorch import Enformer |
| from torchmetrics.regression import PearsonCorrCoef, R2Score |
| from warnings import warn |
| import wandb |
| import numpy as np |
|
|
| |
def masked_mse(y_hat, y):
    """
    Computes Mean Squared Error (MSE) while ignoring NaN values in the target tensor.

    Predictions and targets that hold the same number of elements but differ in
    shape (for example ``(batch, 1)`` predictions against ``(batch,)`` targets)
    are aligned to the target's shape before masking. Without this alignment the
    boolean-mask indexing yields ``(n, 1)`` and ``(n,)`` tensors whose difference
    broadcasts to ``(n, n)``, silently producing a wrong loss.

    Args:
        y_hat (torch.Tensor): The predicted values.
        y (torch.Tensor): The target values, which may contain NaNs.

    Returns:
        torch.Tensor: A scalar tensor representing the masked MSE. Returns 0.0 if
            all targets are NaN (or if there are no targets at all).

    Raises:
        ValueError: If `y_hat` and `y` contain a different number of elements.
    """
    if y_hat.shape != y.shape:
        if y_hat.numel() != y.numel():
            raise ValueError(
                f"masked_mse: incompatible shapes, y_hat {tuple(y_hat.shape)} "
                f"vs y {tuple(y.shape)}"
            )
        # Same element count, different layout: match the target's shape so the
        # element-wise difference below cannot broadcast.
        y_hat = y_hat.reshape(y.shape)

    valid = ~torch.isnan(y)
    if not valid.any():
        # No usable target: return a zero that still supports .backward().
        return torch.tensor(0.0, device=y_hat.device, requires_grad=True)
    return torch.mean((y[valid] - y_hat[valid]) ** 2)
|
|
| |
class LitBaseModel(pl.LightningModule):
    """
    A base PyTorch Lightning module providing common boilerplate for training and evaluation.

    This class implements a generic training/validation/test step, loss calculation using
    `masked_mse`, optimizer configuration (AdamW), and hooks for accumulating outputs
    for detailed metric logging via the `MetricLogger` callback.

    Derived classes are expected to implement the `forward` method.

    Hyperparameters:
        learning_rate (float): The learning rate for the optimizer.
        loss_alpha (float): A coefficient for the primary loss term (MSE). Useful if
            additional loss terms were to be added.
        weight_decay (float, optional): Weight decay for the AdamW optimizer. If None,
            AdamW's internal default is used.
        eval_gene_sets (dict, optional): A dictionary where keys are set names (e.g., 'oncogenes')
            and values are lists of gene IDs. Used by `MetricLogger`
            to compute metrics for specific gene subsets.
    """

    def __init__(self, learning_rate=5e-6, loss_alpha=1.0, weight_decay=None,
                 eval_gene_sets=None):
        """
        Initializes the LitBaseModel.

        Args:
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Alpha for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay for AdamW. If None, uses optimizer default.
                Defaults to None.
            eval_gene_sets (dict, optional): Dictionary of gene sets for targeted evaluation.
                Keys are names, values are lists of gene IDs.
                Defaults to None.
        """
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.loss_alpha = loss_alpha
        self.weight_decay = weight_decay
        # A falsy value (None or an empty dict) is normalised to an empty dict.
        self.eval_gene_sets = eval_gene_sets if eval_gene_sets else {}

        # Per-item records collected during validation/test; reset at the start
        # of each evaluation epoch.
        self.epoch_outputs = []

    def loss_fn(self, y_hat, y):
        """
        Calculates the loss for the model.

        Currently uses `masked_mse` scaled by `self.loss_alpha`.

        Args:
            y_hat (torch.Tensor): Predicted values from the model.
            y (torch.Tensor): Ground truth target values.

        Returns:
            torch.Tensor: The computed loss value.
        """
        mse_term = masked_mse(y_hat, y)
        return self.loss_alpha * mse_term

    def _common_step(self, batch, batch_idx, step_type):
        """
        A common step for training, validation, and testing.

        This method unpacks the batch, performs a forward pass, calculates the loss,
        logs the loss, and accumulates outputs for epoch-level metric calculation
        (for validation and test steps).

        Args:
            batch: The batch of data from the DataLoader. Expected to be a tuple containing
                DNA sequence, Morgan fingerprints, dose, target expression,
                and metadata (gene_id, drug_id, cell_line).
            batch_idx (int): The index of the current batch.
            step_type (str): A string indicating the type of step ('train', 'val', or 'test').

        Returns:
            torch.Tensor: The loss for the current batch.
        """
        dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line = batch

        y_hat = self(dna_seq, morgan_fingerprints, dose)
        loss = self.loss_fn(y_hat, target_expression)

        is_eval = step_type != 'train'
        self.log(
            f'{step_type}_loss',
            loss,
            batch_size=target_expression.shape[0],
            # The loss is only ever logged at epoch level (the previous
            # `step_type == 'train' and False` expression was always False).
            on_step=False,
            on_epoch=True,
            prog_bar=is_eval,
            # Reduce the epoch-level val/test loss across processes so it is not
            # rank-local in distributed runs.
            sync_dist=is_eval,
        )

        if is_eval:
            # Detach and move to CPU once per batch so the stored records do not
            # keep device tensors alive for the whole evaluation epoch.
            preds_cpu = y_hat.detach().cpu()
            targets_cpu = target_expression.detach().cpu()
            rank = self.trainer.global_rank
            for i in range(targets_cpu.shape[0]):
                self.epoch_outputs.append({
                    'pred': preds_cpu[i],
                    'target': targets_cpu[i],
                    'gene_id': gene_id[i],
                    'drug_id': drug_id[i],
                    'cell_line': cell_line[i],
                    'rank': rank,
                })
        return loss

    def training_step(self, batch, batch_idx):
        """PyTorch Lightning training step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning validation step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning test step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'test')

    def on_validation_epoch_start(self):
        """Clears accumulated outputs at the start of each validation epoch."""
        self.epoch_outputs = []

    def on_test_epoch_start(self):
        """Clears accumulated outputs at the start of each test epoch."""
        self.epoch_outputs = []

    def configure_optimizers(self):
        """
        Configures the optimizer for the model.

        Uses AdamW with the specified learning rate and weight decay.

        Returns:
            torch.optim.Optimizer: The configured AdamW optimizer.
        """
        optimizer_kwargs = {'lr': self.learning_rate}
        if self.weight_decay is not None:
            # Only override AdamW's own default when a value was supplied.
            optimizer_kwargs['weight_decay'] = self.weight_decay
        return torch.optim.AdamW(self.parameters(), **optimizer_kwargs)
|
|
| |
class LitEnformerSMILES(LitBaseModel):
    """
    A PyTorch Lightning module that combines genomic sequence information (via Enformer)
    with drug chemical structure (represented by Morgan fingerprints) and drug dose
    to predict gene expression changes.

    The model architecture consists of three main branches:
    1. DNA Module: Uses a pre-trained Enformer model (with an adapted head) to extract
       features from a one-hot encoded DNA sequence centered around a gene's TSS.
    2. Drug Module: Uses pre-computed Morgan fingerprints as the drug representation.
    3. Dose Module: Directly uses the numerical dose value.

    Features from these three branches are concatenated and passed through a multi-layer
    fusion head (MLP with ReLU, BatchNorm, Dropout) to produce the final prediction
    of gene expression.

    Inherits common training and evaluation logic from `LitBaseModel`.
    """

    def __init__(self,
                 enformer_model_name: str = 'EleutherAI/enformer-official-rough',
                 enformer_target_length: int = -1,
                 num_output_tracks_enformer_head: int = 1,
                 morgan_fingerprint_dim: int = 2048,
                 dose_input_dim: int = 1,
                 fusion_hidden_dim: int = 256,
                 final_output_tracks: int = 1,
                 learning_rate=5e-6,
                 loss_alpha=1.0,
                 weight_decay=None,
                 eval_gene_sets=None):
        """
        Initializes the LitEnformerSMILES (or LitEnformerMorgan) model.

        Args:
            enformer_model_name (str, optional): Name or path of the pre-trained Enformer model.
            enformer_target_length (int, optional): Target length for Enformer's internal pooling.
            num_output_tracks_enformer_head (int, optional): Output features from Enformer head.
            morgan_fingerprint_dim (int, optional): Dimensionality of the Morgan fingerprint vector
                (e.g., 2048 for ECFP4). Defaults to 2048.
            dose_input_dim (int, optional): Dimensionality of the drug dose input. Defaults to 1.
            fusion_hidden_dim (int, optional): Hidden dimension for the fusion MLP. Defaults to 256.
            final_output_tracks (int, optional): Number of final output values. Defaults to 1.
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Weight for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay. Defaults to None.
            eval_gene_sets (dict, optional): Gene sets for targeted evaluation. Defaults to None.
        """
        super().__init__(learning_rate, loss_alpha, weight_decay, eval_gene_sets)
        self.save_hyperparameters(
            "enformer_model_name", "enformer_target_length",
            "num_output_tracks_enformer_head", "morgan_fingerprint_dim",
            "dose_input_dim", "fusion_hidden_dim", "final_output_tracks",
            "learning_rate", "loss_alpha", "weight_decay"
        )

        # DNA branch: pre-trained Enformer wrapped with a linear head adapter.
        enformer_pretrained = Enformer.from_pretrained(
            self.hparams.enformer_model_name,
            target_length=self.hparams.enformer_target_length
        )
        self.dna_module = HeadAdapterWrapper(
            enformer=enformer_pretrained,
            num_tracks=self.hparams.num_output_tracks_enformer_head,
            post_transformer_embed=False,
            output_activation=nn.Identity()
        )

        # Fusion MLP over the concatenated [DNA features | fingerprint | dose] vector.
        hidden_dim = self.hparams.fusion_hidden_dim
        half_hidden_dim = hidden_dim // 2
        fusion_input_dim = (
            self.hparams.num_output_tracks_enformer_head
            + self.hparams.morgan_fingerprint_dim
            + self.hparams.dose_input_dim
        )
        self.fusion_head = nn.Sequential(
            nn.Linear(fusion_input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.25),
            nn.Linear(hidden_dim, half_hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(half_hidden_dim),
            nn.Dropout(0.1),
            nn.Linear(half_hidden_dim, self.hparams.final_output_tracks)
        )

    def forward(self, dna_seq, morgan_fingerprints, dose):
        """
        Defines the forward pass of the LitEnformerSMILES model using Morgan Fingerprints.

        Args:
            dna_seq (torch.Tensor): Batch of one-hot encoded DNA sequences.
                Shape: (batch_size, sequence_length, 4).
            morgan_fingerprints (torch.Tensor): Batch of pre-computed Morgan fingerprint vectors.
                Shape: (batch_size, morgan_fingerprint_dim).
            dose (torch.Tensor): Batch of drug dose values.
                Shape: (batch_size, dose_input_dim), or (batch_size,) which is
                expanded to (batch_size, 1).

        Returns:
            torch.Tensor: The model's prediction. Shape: (batch_size, final_output_tracks).
        """
        # Enformer is fine-tuned together with the head (not frozen).
        dna_out_intermediate = self.dna_module(dna_seq, freeze_enformer=False)
        # Keep only the central position along dimension 1 of the head output.
        center_seq_idx = dna_out_intermediate.shape[1] // 2
        dna_features = dna_out_intermediate[:, center_seq_idx, :]

        # Align dtypes before concatenation: torch.cat promotes mixed dtypes, so
        # e.g. integer fingerprints or float64 doses would otherwise yield an
        # input whose dtype does not match the fusion head's parameters.
        feature_dtype = dna_features.dtype
        smiles_features = morgan_fingerprints.to(dtype=feature_dtype)

        if dose.ndim == 1:
            dose = dose.unsqueeze(-1)
        dose = dose.to(dtype=feature_dtype)

        combined_features = torch.cat([dna_features, smiles_features, dose], dim=1)
        prediction = self.fusion_head(combined_features)
        return prediction
|
|
| |
class MetricLogger(pl.Callback):
    """
    A PyTorch Lightning Callback for comprehensive metric calculation and logging.

    This callback reads the predictions and targets that the LightningModule
    accumulated in `epoch_outputs` during validation and test epochs.
    At the end of these epochs, it:
    1. Processes the accumulated outputs into a pandas DataFrame.
    2. Saves the raw predictions and targets for the epoch to a CSV file.
    3. Logs a sample of these raw predictions as a W&B Table if WandbLogger is used.
    4. Calculates overall performance metrics (MSE, Pearson, R2) for the epoch.
    5. If `eval_gene_sets` are provided in the LightningModule, calculates metrics for these specific gene subsets.
    6. Calculates metrics per cell line if 'cell_line' information is available in the outputs.
    7. Logs all calculated metrics to the LightningModule's logger.

    Attributes:
        save_dir_prefix (str): Prefix for the directory where metric CSVs will be saved.
        current_epoch_data (list | pd.DataFrame): Empty list until the first evaluation
            epoch ends; afterwards the DataFrame built from the most recent epoch's outputs.
    """

    def __init__(self, save_dir_prefix="results"):
        """
        Initializes the MetricLogger callback.

        Args:
            save_dir_prefix (str, optional): Directory prefix for saving metrics files.
                Defaults to "results".
        """
        super().__init__()
        self.save_dir_prefix = save_dir_prefix
        self.current_epoch_data = []

    def _process_epoch_outputs(self, pl_module, stage):
        """
        Processes the raw outputs collected during an epoch into a pandas DataFrame.

        Converts tensor data for 'pred' and 'target' columns to NumPy/Python native types.

        Args:
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").

        Returns:
            pd.DataFrame: A DataFrame containing the processed epoch outputs.
                Returns an empty DataFrame if no outputs were collected.
        """
        if not getattr(pl_module, 'epoch_outputs', None):
            warn(f"No outputs collected (pl_module.epoch_outputs is missing or empty) during {stage} epoch for MetricLogger.")
            return pd.DataFrame()

        df = pd.DataFrame(pl_module.epoch_outputs)

        def _to_native(value):
            # Single-element tensors become Python floats, larger ones NumPy arrays.
            array = value.cpu().float().numpy()
            return array.item() if value.numel() == 1 else array

        for col in ('pred', 'target'):
            if col in df.columns and not df[col].empty:
                # The first entry decides whether the column holds tensors.
                if isinstance(df[col].iloc[0], torch.Tensor):
                    df[col] = df[col].apply(_to_native)
        return df

    def _handle_epoch_end(self, trainer, pl_module, stage, hook_name):
        """Shared body of the validation/test epoch-end hooks."""
        if not getattr(pl_module, 'epoch_outputs', None):
            warn(f"MetricLogger: pl_module.epoch_outputs not found or empty at {hook_name}.")
            return
        self.current_epoch_data = self._process_epoch_outputs(pl_module, stage)
        if not self.current_epoch_data.empty:
            self._log_and_save_metrics(trainer, pl_module, stage)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the validation epoch."""
        self._handle_epoch_end(trainer, pl_module, "validation", "on_validation_epoch_end")

    def on_test_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the test epoch."""
        self._handle_epoch_end(trainer, pl_module, "test", "on_test_epoch_end")

    def _log_and_save_metrics(self, trainer, pl_module, stage):
        """
        Calculates, logs, and saves metrics for the current stage and epoch.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").
        """
        data = self.current_epoch_data
        device = pl_module.device

        epoch = trainer.current_epoch if trainer.current_epoch is not None else -1
        # An explicit `save_dir` hyperparameter wins; otherwise use a per-run directory.
        default_dir = os.path.join(
            self.save_dir_prefix,
            f"run_{trainer.logger.version if trainer.logger else 'local'}",
        )
        save_dir = getattr(pl_module.hparams, 'save_dir', default_dir)
        os.makedirs(save_dir, exist_ok=True)

        # Each process only holds its own shard of predictions. In multi-process
        # runs give every rank its own file so ranks do not overwrite each other.
        # Single-process runs keep the original file name.
        rank_suffix = f"_rank_{trainer.global_rank}" if getattr(trainer, 'world_size', 1) > 1 else ""
        raw_preds_path = os.path.join(save_dir, f"{stage}_predictions_epoch_{epoch}{rank_suffix}.csv")
        data.to_csv(raw_preds_path, index=False)

        if trainer.logger and hasattr(trainer.logger, 'experiment') and isinstance(trainer.logger.experiment, wandb.sdk.wandb_run.Run):
            try:
                # Only a sample of rows is uploaded to keep the table small.
                trainer.logger.experiment.log({f"{stage}_raw_predictions_epoch_{epoch}": wandb.Table(dataframe=data.head(1000))})
            except Exception as e:
                # Best effort: a failed table upload must not abort evaluation.
                warn(f"MetricLogger: Failed to log raw predictions table to W&B: {e}")

        # NOTE(review): the log_dict calls below use sync_dist=True; if the set of
        # logged keys differs between ranks (e.g. a cell line present on only one
        # rank) the cross-process reduction may not line up - confirm for DDP runs.
        overall_metrics = self._calculate_metrics_for_group(data, device)
        if overall_metrics:
            pl_module.log_dict({f"{stage}_{k}_epoch": v for k, v in overall_metrics.items()}, sync_dist=True)

        gene_sets = getattr(pl_module, 'eval_gene_sets', None)
        if gene_sets and isinstance(gene_sets, dict) and 'gene_id' in data.columns:
            for split_name, gene_list in gene_sets.items():
                if not gene_list:
                    continue
                split_df = data[data['gene_id'].isin(gene_list)]
                if split_df.empty:
                    continue
                split_metrics = self._calculate_metrics_for_group(split_df, device)
                if split_metrics:
                    pl_module.log_dict({f"{stage}_{split_name}_genes_{k}_epoch": v for k, v in split_metrics.items()}, sync_dist=True)

        if 'cell_line' in data.columns:
            for cell_line, group_df in data.groupby('cell_line'):
                cl_metrics = self._calculate_metrics_for_group(group_df, device)
                if cl_metrics:
                    pl_module.log_dict({f"{stage}_{cell_line}_cell_line_{k}_epoch": v for k, v in cl_metrics.items()}, sync_dist=True)

    def _calculate_metrics_for_group(self, df_group, device):
        """
        Calculates regression metrics (MSE, Pearson, R2) for a given group of predictions.

        NaN targets are ignored by all three metrics: MSE through `masked_mse`,
        Pearson/R2 by dropping the corresponding entries before evaluation.

        Args:
            df_group (pd.DataFrame): DataFrame containing 'pred' and 'target' columns for the group.
            device (torch.device): The device to perform calculations on.

        Returns:
            dict: A dictionary of calculated metrics (mse, pearson, r2). Returns empty if data is insufficient.
        """
        if df_group.empty or 'pred' not in df_group.columns or 'target' not in df_group.columns:
            return {}

        preds_np = np.array(df_group['pred'].tolist(), dtype=np.float32)
        targets_np = np.array(df_group['target'].tolist(), dtype=np.float32)

        preds = torch.tensor(preds_np).to(device)
        targets = torch.tensor(targets_np).to(device)

        # With 1-D predictions, drop singleton dimensions from both tensors so
        # that e.g. (N, 1) targets compare equal in shape to (N,) predictions.
        if preds.ndim == 1:
            preds = preds.squeeze()
            targets = targets.squeeze()

        if preds.numel() == 0 or targets.numel() == 0 or preds.shape != targets.shape:
            warn(f"Skipping metrics calculation for a group due to mismatched or empty preds/targets. Pred shape: {preds.shape}, Target shape: {targets.shape}")
            return {}

        mse_val_tensor = masked_mse(preds.unsqueeze(-1) if preds.ndim == 1 else preds,
                                    targets.unsqueeze(-1) if targets.ndim == 1 else targets)
        calculated_metrics = {'mse': mse_val_tensor.item()}

        if preds.numel() < 2:
            warn(f"Skipping Pearson/R2 for a group with < 2 samples. Found {preds.numel()} samples. Only MSE will be reported.")
            return calculated_metrics

        preds_for_corr = preds.squeeze()
        targets_for_corr = targets.squeeze()

        # Correlation metrics are only computed for single-output (1-D) data.
        if preds_for_corr.shape != targets_for_corr.shape or (preds_for_corr.ndim > 1 and preds_for_corr.shape[1] > 1):
            warn(f"Skipping Pearson/R2 due to incompatible shapes after squeeze for correlation. Pred: {preds_for_corr.shape}, Target: {targets_for_corr.shape}")
            return calculated_metrics

        # Drop NaN targets so Pearson/R2 agree with the masked MSE instead of
        # turning into NaN as soon as a single target is missing.
        valid = ~torch.isnan(targets_for_corr)
        if not valid.all():
            preds_for_corr = preds_for_corr[valid]
            targets_for_corr = targets_for_corr[valid]
            if targets_for_corr.numel() < 2:
                warn(f"Skipping Pearson/R2 for a group with < 2 non-NaN targets. Found {targets_for_corr.numel()} valid samples. Only MSE will be reported.")
                return calculated_metrics

        # Each metric is best effort: a failure is reported and the remaining
        # metrics are still returned.
        try:
            pearson_fn = PearsonCorrCoef().to(device)
            pearson_val = pearson_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['pearson'] = pearson_val.item()
        except Exception as e:
            warn(f"Could not compute Pearson Correlation: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        try:
            r2_fn = R2Score().to(device)
            r2_val = r2_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['r2'] = r2_val.item()
        except Exception as e:
            warn(f"Could not compute R2 Score: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        return calculated_metrics