import os
import sys
import argparse
import math

import Bio.SeqIO as SeqIO
from loguru import logger
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import dotenv
import tqdm
import h5py
from sklearn.metrics import average_precision_score, roc_auc_score

import torch
import torchvision.ops.focal_loss as focal_loss
from torch.utils.data import Dataset, random_split
from torch.utils.data.sampler import Sampler
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from transformers import T5Tokenizer, T5EncoderModel

# Logit offset applied at inference time to calibrate the reported score.
MAX_LOGIT = 5.23


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def embed_sequences(sequences, gpu=True):
    """Embed (id, sequence) pairs with ProtT5 and return per-lysine site windows."""
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_bfd', do_lower_case=False)
    gpu_model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_bfd")
    cpu_model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_bfd")
    gpu_model.eval()
    cpu_model.eval()

    # ProtT5 expects space-separated amino acids.
    sequence_as_aas = [' '.join(list(seq[1])) for seq in sequences]
    ids = tokenizer.batch_encode_plus(sequence_as_aas, add_special_tokens=True, padding=True)
    input_ids = torch.tensor(ids['input_ids'])
    attention_mask = torch.tensor(ids['attention_mask'])

    if gpu:
        gpu_device = torch.device('cuda:0')
        gpu_model.to(gpu_device)
        try:
            input_ids = input_ids.to(gpu_device)
            attention_mask = attention_mask.to(gpu_device)
            with torch.no_grad():
                embeddings = gpu_model(input_ids, attention_mask=attention_mask)
        except torch.OutOfMemoryError:
            # Fall back to the CPU copy of the encoder if the batch does not fit on the GPU.
            input_ids = input_ids.to('cpu')
            attention_mask = attention_mask.to('cpu')
            with torch.no_grad():
                embeddings = cpu_model(input_ids, attention_mask=attention_mask)
    else:
        input_ids = input_ids.to('cpu')
        attention_mask = attention_mask.to('cpu')
        with torch.no_grad():
            embeddings = cpu_model(input_ids, attention_mask=attention_mask)

    embeddings = embeddings.last_hidden_state.cpu()

    reps = {}
    for i, sequence in enumerate(sequences):
        reps[sequence[0]] = {}
        # Trim padding/special-token positions before slicing site windows.
        sequence_embedding = embeddings[i][:len(sequence[1])]
        site_representations = extract_site_representations(sequence[1], sequence_embedding)
        for position, site_representation in site_representations.items():
            reps[sequence[0]][position] = site_representation
    return reps
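# Illustrative usage of embed_sequences (a sketch, not part of the pipeline; the
# accession below is a placeholder). It takes (id, sequence) pairs and returns
# {id: {position: window tensor}} with one entry per lysine, keyed by 1-based
# position:
#
#   seqs = [('P12345', 'MKTAYIAKQR')]
#   windows = embed_sequences(seqs, gpu=False)  # downloads ProtT5 on first use
#   sorted(windows['P12345'])                   # -> [2, 8]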
MODIFICATIONS = {
    'methylation': 0,
    'acetylation': 1,
    'ubiquitination': 2,
    'sumoylation': 3
}


def pr_at_re(x_hat, x_true, recall=0.5):
    """Precision at the (interpolated) threshold where recall first drops below `recall`."""
    df = pd.DataFrame({'score': x_hat, 'label': x_true})
    thresholds = []
    recalls = []
    index_past_threshold = -1
    for i, threshold in enumerate(np.linspace(df.score.min(), df.score.max(), num=1000)):
        thresholds.append(threshold)
        tp = len(df[(df['score'] >= threshold) & (df['label'] == 1)])
        fp = len(df[(df['score'] >= threshold) & (df['label'] == 0)])
        fn = len(df[(df['score'] < threshold) & (df['label'] == 1)])
        re = tp / (tp + fn)
        recalls.append(re)
        if re < recall:
            index_past_threshold = i
            break
    if index_past_threshold == -1:
        return 0
    # Linearly interpolate the threshold at which recall equals `recall`,
    # then report precision at that interpolated threshold.
    t = thresholds[index_past_threshold - 1] + (
        (recall - recalls[index_past_threshold - 1])
        * (thresholds[index_past_threshold] - thresholds[index_past_threshold - 1])
        / (recalls[index_past_threshold] - recalls[index_past_threshold - 1])
    )
    tp = len(df[(df['score'] >= t) & (df['label'] == 1)])
    fp = len(df[(df['score'] >= t) & (df['label'] == 0)])
    pr = tp / (tp + fp)
    return pr


class ClassificationHead(nn.Module):

    def __init__(self, d_model, window_size, layer_widths, dropout=0.15):
        super().__init__()
        layers = []
        input_dims = int(d_model * window_size)
        for i in range(len(layer_widths)):
            layers.append(nn.Sequential(
                nn.Linear(input_dims, layer_widths[i]),
                nn.ReLU(),
                nn.Dropout(dropout),
            ))
            input_dims = layer_widths[i]
        layers.append(nn.Sequential(
            nn.Linear(input_dims, 1),
            nn.ReLU(),
        ))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class MultitaskSampler(Sampler):
    """Batch sampler that builds every batch from a single modification task."""

    def __init__(self, data, batch_size) -> None:
        self.data = data.reset_index(drop=True)
        self.batch_size = batch_size
        self.methylation_indices = np.array(self.data[self.data['modification'] == 'methylation'].index, dtype=int)
        self.acetylation_indices = np.array(self.data[self.data['modification'] == 'acetylation'].index, dtype=int)
        self.ubiquitination_indices = np.array(self.data[self.data['modification'] == 'ubiquitination'].index, dtype=int)
        self.sumoylation_indices = np.array(self.data[self.data['modification'] == 'sumoylation'].index, dtype=int)
        self.num_methylation_batches = (len(self.methylation_indices) + self.batch_size - 1) // self.batch_size
        self.num_acetylation_batches = (len(self.acetylation_indices) + self.batch_size - 1) // self.batch_size
        self.num_ubiquitination_batches = (len(self.ubiquitination_indices) + self.batch_size - 1) // self.batch_size
        self.num_sumoylation_batches = (len(self.sumoylation_indices) + self.batch_size - 1) // self.batch_size

    def __len__(self) -> int:
        # Number of batches to be sampled.
        return (self.num_methylation_batches + self.num_acetylation_batches
                + self.num_ubiquitination_batches + self.num_sumoylation_batches)

    def __iter__(self):
        # Group indices into batches where all instances belong to the same task
        # and yield the batches (steps).
        methylation_indices = np.copy(self.methylation_indices)
        acetylation_indices = np.copy(self.acetylation_indices)
        ubiquitination_indices = np.copy(self.ubiquitination_indices)
        sumoylation_indices = np.copy(self.sumoylation_indices)
        np.random.shuffle(methylation_indices)
        np.random.shuffle(acetylation_indices)
        np.random.shuffle(ubiquitination_indices)
        np.random.shuffle(sumoylation_indices)
        methylation_batches = torch.chunk(torch.IntTensor(methylation_indices), self.num_methylation_batches)
        acetylation_batches = torch.chunk(torch.IntTensor(acetylation_indices), self.num_acetylation_batches)
        ubiquitination_batches = torch.chunk(torch.IntTensor(ubiquitination_indices), self.num_ubiquitination_batches)
        sumoylation_batches = torch.chunk(torch.IntTensor(sumoylation_indices), self.num_sumoylation_batches)
        for batch in methylation_batches + acetylation_batches + ubiquitination_batches + sumoylation_batches:
            yield batch.tolist()
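# MultitaskSampler yields whole batches (lists of indices), one task per batch,
# so it is presumably passed to the DataLoader as a batch_sampler rather than a
# sampler. A minimal sketch, assuming a `sites_df` DataFrame with a
# 'modification' column (names here are placeholders):
#
#   sampler = MultitaskSampler(sites_df, batch_size=64)
#   loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)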
class KmeDataset(Dataset):

    def __init__(self, embeddings, dataset, window_size):
        self.sites = dataset
        self.embeddings = embeddings
        self.labels = list(self.sites['label'])
        self.window_size = window_size

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Use zero padding (i.e. an embedding of all zeros) when the site is too
        # close to either end of the chain.
        protein = self.sites.iloc[idx]['protein']
        position = self.sites.iloc[idx]['uniprot_position']
        modification = self.sites.iloc[idx]['modification']
        position_index = position - 1
        representation = torch.Tensor(np.array(self.embeddings[protein]))
        protein_len = representation.shape[0]
        representation_dim = representation.shape[1]
        half_window = int((self.window_size - 1) / 2)
        padding_left_required = -1 * min(0, position_index - half_window)
        padding_right_required = max(0, position_index + half_window - protein_len + 1)
        if padding_left_required > 0:
            representation = torch.cat([torch.zeros((padding_left_required, representation_dim)), representation], dim=0)
        if padding_right_required > 0:
            representation = torch.cat([representation, torch.zeros((padding_right_required, representation_dim))], dim=0)
        representation = representation[position_index + padding_left_required - half_window:
                                        position_index + padding_left_required + half_window + 1]
        # Prepend task token (IN THIS IMPLEMENTATION OF MULTITASK LEARNING, WE DO NOT DO THIS)
        # representation = torch.cat([TOKEN_VALUE[modification] * torch.ones(1, representation_dim), representation])
        label = float(self.labels[idx])
        return modification, representation, label


def extract_site_representations(sequence, sequence_embedding, window_size=31):
    """Slice a zero-padded window of per-residue embeddings around every lysine."""
    lysine_indices = [i for i, aa in enumerate(sequence) if aa == 'K']
    representations = {}
    for position_index in lysine_indices:
        protein_len = sequence_embedding.shape[0]
        representation_dim = sequence_embedding.shape[1]
        half_window = int((window_size - 1) / 2)
        padding_left_required = -1 * min(0, position_index - half_window)
        padding_right_required = max(0, position_index + half_window - protein_len + 1)
        representation = sequence_embedding.clone().detach()
        if padding_left_required > 0:
            representation = torch.cat([torch.zeros((padding_left_required, representation_dim)), representation], dim=0)
        if padding_right_required > 0:
            representation = torch.cat([representation, torch.zeros((padding_right_required, representation_dim))], dim=0)
        representation = representation[position_index + padding_left_required - half_window:
                                        position_index + padding_left_required + half_window + 1]
        # Positions are reported 1-based, matching UniProt numbering.
        representations[position_index + 1] = representation
    return representations
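# Sketch of the windowing behaviour on toy inputs (illustrative only): each
# lysine gets a (window_size, d) slice centred on it, zero-padded when the site
# sits near either terminus, keyed by 1-based position.
#
#   toy_seq = 'MKAAAAAKAA'                   # lysines at positions 2 and 8
#   toy_emb = torch.randn(len(toy_seq), 8)   # stand-in for a ProtT5 embedding
#   reps = extract_site_representations(toy_seq, toy_emb, window_size=5)
#   assert set(reps) == {2, 8} and reps[2].shape == (5, 8)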
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))  # do i need to modify this?
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)  # modified axes here, was pe[:, 0, 0::2]
        pe[0, :, 1::2] = torch.cos(position * div_term)  # modified axes here, was pe[:, 0, 1::2]
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
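# The Model below reads its whole configuration from a single hparams dict.
# The keys are the ones consumed in __init__; the values here are placeholders
# for illustration, not a published configuration:
#
#   hparams = {
#       'input_dims': 1024,                  # ProtT5 embedding width
#       'embedder_width': 512, 'd_model': 256, 'dropout': 0.15,
#       'window_size': 31, 'n_heads': [4, 4], 'hidden_layer_widths': [256, 64],
#       'learning_rate': 1e-4, 'lr_step_size': 10, 'lr_gamma': 0.5,
#       'batch_size': 64, 'methylation_loss_factor': 1.0,
#       'loss_function': 'F.binary_cross_entropy_with_logits',
#   }
#   model = Model(hparams)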
#@torch.compile
class Model(L.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters()
        self.lr_step_size = hparams['lr_step_size']
        self.lr_gamma = hparams['lr_gamma']
        self.learning_rate = hparams['learning_rate']
        self.batch_size = hparams['batch_size']
        self.methylation_loss_factor = hparams['methylation_loss_factor']
        # The loss is configured as a Python expression and evaluated here.
        self.loss_function = eval(hparams['loss_function'])
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.embedding_layer = nn.Sequential(
            nn.Linear(hparams['input_dims'], hparams['embedder_width']),
            nn.Dropout(hparams['dropout']),
            nn.ReLU(),
            nn.Linear(hparams['embedder_width'], hparams['d_model']),
            nn.Dropout(hparams['dropout']),
            nn.ReLU(),
        )
        self.positional_encoder = PositionalEncoding(hparams['d_model'], dropout=hparams['dropout'], max_len=hparams['window_size'])
        transformer_layers = []
        for i in range(len(hparams['n_heads'])):
            transformer_layers.append(
                nn.TransformerEncoderLayer(
                    hparams["d_model"],
                    hparams["n_heads"][i],
                    dropout=hparams["dropout"],
                    activation=nn.ReLU()
                ),
            )
        self.transformer_layers = nn.Sequential(*transformer_layers)
        self.flatten = nn.Flatten()
        # One classification head per modification task.
        self.methylation_head = ClassificationHead(hparams['d_model'], hparams['window_size'], hparams['hidden_layer_widths'], dropout=hparams['dropout'])
        self.acetylation_head = ClassificationHead(hparams['d_model'], hparams['window_size'], hparams['hidden_layer_widths'], dropout=hparams['dropout'])
        self.ubiquitination_head = ClassificationHead(hparams['d_model'], hparams['window_size'], hparams['hidden_layer_widths'], dropout=hparams['dropout'])
        self.sumoylation_head = ClassificationHead(hparams['d_model'], hparams['window_size'], hparams['hidden_layer_widths'], dropout=hparams['dropout'])

    def forward(self, x, task):
        x = self.embedding_layer(x)
        x = self.positional_encoder(x)
        x = self.transformer_layers(x)
        x = self.flatten(x)
        if task == 'methylation':
            logits = self.methylation_head(x)
        elif task == 'acetylation':
            logits = self.acetylation_head(x)
        elif task == 'ubiquitination':
            logits = self.ubiquitination_head(x)
        elif task == 'sumoylation':
            logits = self.sumoylation_head(x)
        else:
            raise ValueError(f"Invalid task `{task}` provided.")
        return logits

    def training_step(self, batch, batch_idx):
        tasks, x, y = batch
        # The MultitaskSampler guarantees a single task per batch.
        task = tasks[0]
        logits = self.forward(x, task)
        y_cpu = y.cpu().detach().numpy()
        y_hat = logits.squeeze(-1)
        y_hat_cpu = y_hat.cpu().detach().numpy()
        loss = self.loss_function(y_hat, y).cpu()
        if task == 'methylation':
            loss *= self.methylation_loss_factor
        self.log('metrics/batch/loss', loss)
        metrics = {
            'loss': loss,
            'y': y_cpu,
            'y_hat': y_hat_cpu
        }
        self.training_step_outputs.append(metrics)
        return metrics

    def on_train_epoch_end(self):
        loss = np.array([])
        y = np.array([])
        y_hat = np.array([])
        for results_dict in self.training_step_outputs:
            loss = np.append(loss, results_dict["loss"].detach())
            y = np.append(y, results_dict["y"])
            y_hat = np.append(y_hat, results_dict["y_hat"])
        auprc = average_precision_score(y, y_hat)
        self.log("metrics/epoch/loss", loss.mean())
        self.log("metrics/epoch/auprc", auprc)
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        task, x, y = batch
        logits = self.forward(x, task[0])
        y_cpu = y.cpu().detach().numpy()
        y_hat = logits.squeeze(-1)
        y_hat_cpu = y_hat.cpu().detach().numpy()
        loss = self.loss_function(y_hat, y).cpu()
        metrics = {
            'loss': loss,
            'y': y_cpu,
            'y_hat': y_hat_cpu
        }
        self.validation_step_outputs.append(metrics)

    def on_validation_epoch_end(self):
        loss = np.array([])
        y = np.array([])
        y_hat = np.array([])
        for results_dict in self.validation_step_outputs:
            loss = np.append(loss, results_dict["loss"])
            y = np.append(y, results_dict["y"])
            y_hat = np.append(y_hat, results_dict["y_hat"])
        auprc = average_precision_score(y, y_hat)
        auroc = roc_auc_score(y, y_hat)
        prat50re = pr_at_re(y_hat, y, recall=0.5)
        self.logger.experiment["val/loss"] = loss.mean()
        self.logger.experiment["val/auprc"] = auprc
        self.logger.experiment["val/auroc"] = auroc
        self.logger.experiment["val/prat50re"] = prat50re
        self.log("val/loss", loss.mean())
        self.log("val/auprc", auprc)
        self.log("val/auroc", auroc)
        self.log("val/prat50re", prat50re)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.lr_step_size, gamma=self.lr_gamma),
            'name': 'linear_scheduler'
        }
        return [optimizer], [lr_scheduler]
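# Example invocation (the script and file names are placeholders):
#
#   python predict_kme.py -i proteins.fasta -w checkpoint.ckpt -o scores.csv
#
# The output CSV has one row per lysine with columns: protein, position, score.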
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', required=True, type=str, help='Path to a FASTA file with sequences.')
    parser.add_argument('-w', '--weights', required=True, type=str, help='Path to model checkpoints.')
    parser.add_argument('-o', '--output', required=True, type=str, help='Path to output file.')
    args = vars(parser.parse_args())

    # Load sequences
    sequences = [(x.id, str(x.seq)) for x in SeqIO.parse(args['input'], 'fasta')]

    # Load model
    model = Model.load_from_checkpoint(args['weights'])
    model.eval()

    # Perform inference
    if torch.cuda.is_available():
        model.to('cuda')
        scores = []
        logger.info("Embedding the sequences on the GPU...")
        site_embeddings = embed_sequences(sequences, gpu=True)
        logger.info("Making predictions on the GPU...")
        for protein, sites in tqdm.tqdm(site_embeddings.items()):
            try:
                for position, representation in sites.items():
                    emb = representation.to('cuda')
                    with torch.no_grad():
                        logit = model(emb, 'methylation').cpu().squeeze(-2).detach().numpy()[0]
                    scores.append({
                        'protein': protein,
                        'position': position,
                        #'logit': float(logit),
                        #'uncorrected_score': float(sigmoid(logit)),
                        'score': float(sigmoid(logit - MAX_LOGIT))
                    })
            except Exception as e:
                print(e)
                print(f"Could not do {protein}... skipping.")
                continue
        df = pd.DataFrame(scores)
        df.to_csv(args['output'], index=False)
    else:
        scores = []
        logger.info("Embedding the sequences on the CPU...")
        site_embeddings = embed_sequences(sequences, gpu=False)
        logger.info("Making predictions on the CPU...")
        for protein, sites in tqdm.tqdm(site_embeddings.items()):
            try:
                for position, representation in sites.items():
                    emb = representation
                    with torch.no_grad():
                        logit = model(emb, 'methylation').squeeze(-2)[0]
                    scores.append({
                        'protein': protein,
                        'position': position,
                        #'logit': float(logit),
                        #'uncorrected_score': float(sigmoid(logit)),
                        'score': float(sigmoid(logit - MAX_LOGIT))
                    })
            except Exception as e:
                print(e)
                print(f"Could not do {protein}... skipping.")
                continue
        df = pd.DataFrame(scores)
        df.to_csv(args['output'], index=False)