import torch
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW
import torch.nn.functional as F
from datasets import load_from_disk
import esm
import numpy as np
import os
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from tqdm import tqdm
import pytorch_lightning as pl

max_epochs = 30
batch_size = 4
lr = 1e-4
dropout = 0.1
margin = 10

# VHSE8 descriptors: eight physicochemical principal-component scores per amino acid.
vhse8_values = {
    'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
    'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
    'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
    'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
    'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
    'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
    'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
    'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
    'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
    'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
    'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
    'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
    'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
    'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
    'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
    'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
    'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
    'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
    'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
    'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}

# ESM-2 alphabet indices of the 20 standard amino acids.
aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12,
             'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7}

# Lookup table mapping ESM-2 token ids to VHSE8 descriptors (zero rows for special/padding tokens).
vhse8_tensor = torch.zeros(24, 8)
for aa, values in vhse8_values.items():
    aa_index = aa_to_idx[aa]
    vhse8_tensor[aa_index] = torch.tensor(values)
vhse8_tensor.requires_grad = False

train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/ppiref_skempi_2')
val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/ppiref_skempi_2')

# Load the tokenizer once at module level instead of on every collate_fn call.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")


def collate_fn(batch):
    binders = []
    mutants = []
    wildtypes = []
    affs = []
    for b in batch:
        # Strip the <cls>/<eos> special tokens added by the tokenizer.
        binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
        mutant = torch.tensor(b['mutant_input_ids']['input_ids'][1:-1])
        wildtype = torch.tensor(b['wildtype_input_ids']['input_ids'][1:-1])

        # Skip malformed or empty sequences.
        if (binder.dim() == 0 or binder.numel() == 0 or mutant.dim() == 0 or mutant.numel() == 0
                or wildtype.dim() == 0 or wildtype.numel() == 0):
            continue

        binders.append(binder)
        mutants.append(mutant)
        wildtypes.append(wildtype)
        affs.append(b['aff'])

    # Right-pad each group of sequences to the longest sequence in the batch.
    binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
    mutant_input_ids = torch.nn.utils.rnn.pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
    wildtype_input_ids = torch.nn.utils.rnn.pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)
    affs = torch.tensor(affs)

    return {
        'binder_input_ids': binder_input_ids.int(),
        'mutant_input_ids': mutant_input_ids.int(),
        'wildtype_input_ids': wildtype_input_ids.int(),
        'aff': affs
    }
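
# A minimal, optional sanity check for collate_fn, assuming each dataset record stores the tokenizer
# output under '<name>_input_ids' -> 'input_ids' plus a scalar 'aff'. The token ids below are
# illustrative only (0 = <cls>, 2 = <eos> in the ESM-2 vocabulary); call this manually to verify
# special-token stripping and padding. It is not invoked by the training run.
def _collate_fn_sanity_check():
    dummy_batch = [
        {'binder_input_ids': {'input_ids': [0, 5, 10, 17, 2]},    # <cls> A R N <eos>
         'mutant_input_ids': {'input_ids': [0, 13, 23, 7, 2]},    # <cls> D C V <eos>
         'wildtype_input_ids': {'input_ids': [0, 13, 16, 7, 2]},  # <cls> D Q V <eos>
         'aff': 0.3},
        {'binder_input_ids': {'input_ids': [0, 6, 2]},            # <cls> G <eos>
         'mutant_input_ids': {'input_ids': [0, 21, 12, 2]},       # <cls> H I <eos>
         'wildtype_input_ids': {'input_ids': [0, 19, 12, 2]},     # <cls> Y I <eos>
         'aff': 0.7},
    ]
    out = collate_fn(dummy_batch)
    assert out['binder_input_ids'].shape == (2, 3)   # special tokens removed, padded to length 3
    assert out['wildtype_input_ids'].shape == (2, 3)
    assert out['aff'].shape == (2,)
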
class muPPItLightning(pl.LightningModule):
    def __init__(self, d_node, num_heads, dropout, margin, lr, train_dataset,
                 easy_example_indices, hard_example_indices):
        super().__init__()

        # Frozen ESM-2 650M encoder; only the attention block and the MLP head below are trained.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.attention = torch.nn.MultiheadAttention(embed_dim=d_node, num_heads=num_heads)
        self.layer_norm = torch.nn.LayerNorm(d_node)
        self.map = torch.nn.Sequential(
            torch.nn.Linear(d_node, d_node // 2),
            torch.nn.SiLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(d_node // 2, 1)
        )

        self.margin = margin
        self.learning_rate = lr
        self.loss_threshold = 0.5
        # Don't pickle the dataset or the large index lists into the checkpoint hparams.
        self.save_hyperparameters(ignore=['train_dataset', 'easy_example_indices', 'hard_example_indices'])

        # Register the VHSE8 lookup as a buffer so it follows the module across devices.
        self.register_buffer('vhse8', vhse8_tensor.clone())

        # Curriculum learning
        self.train_dataset = train_dataset
        self.easy_example_indices = easy_example_indices
        self.hard_example_indices = hard_example_indices
        self.current_subset_indices = easy_example_indices  # start with easy examples
        self.max_epochs = max_epochs

    def _embed(self, tokens):
        # Frozen per-residue ESM-2 embedding concatenated with VHSE8 descriptors
        # (1280 + 8 = 1288 dims), with padded positions zeroed out.
        pad_mask = (tokens != self.alphabet.padding_idx).int()
        esm_embed = self.esm(tokens, repr_layers=[33], return_contacts=False)["representations"][33]
        esm_embed = esm_embed * pad_mask.unsqueeze(-1)
        vhse8_embed = self.vhse8[tokens.long()]
        return torch.cat([esm_embed, vhse8_embed], dim=-1)

    def forward(self, binder_tokens, wt_tokens, mut_tokens):
        with torch.no_grad():
            binder_embed = self._embed(binder_tokens)
            mut_embed = self._embed(mut_tokens)
            wt_embed = self._embed(wt_tokens)

        # Concatenate the binder with the wildtype / mutant along the sequence dimension
        # (wildtype and mutant are assumed to have equal length, i.e. substitutions only).
        binder_wt = torch.cat([binder_embed, wt_embed], dim=1)
        binder_mut = torch.cat([binder_embed, mut_embed], dim=1)

        # nn.MultiheadAttention expects (seq, batch, dim) by default, hence the transposes.
        binder_wt = binder_wt.transpose(0, 1)
        binder_mut = binder_mut.transpose(0, 1)

        binder_wt_attn, _ = self.attention(binder_wt, binder_wt, binder_wt)
        binder_mut_attn, _ = self.attention(binder_mut, binder_mut, binder_mut)

        # Residual connection, back to (batch, seq, dim), then layer norm.
        binder_wt_attn = self.layer_norm((binder_wt + binder_wt_attn).transpose(0, 1))
        binder_mut_attn = self.layer_norm((binder_mut + binder_mut_attn).transpose(0, 1))

        # Project every position to a scalar and take the Euclidean distance between the
        # binder-wildtype and binder-mutant profiles.
        mapped_binder_wt = self.map(binder_wt_attn).squeeze(-1)
        mapped_binder_mut = self.map(binder_mut_attn).squeeze(-1)
        distance = torch.sqrt(torch.sum((mapped_binder_wt - mapped_binder_mut) ** 2, dim=-1))
        return distance

    def compute_loss(self, binder_tokens, wt_tokens, mut_tokens, aff):
        distance = self(binder_tokens, wt_tokens, mut_tokens)

        # Hinge the predicted distance into the band [margin * aff, margin * (aff + 1)].
        upper_loss = F.relu(distance - self.margin * (aff + 1))
        lower_loss = F.relu(self.margin * aff - distance)
        loss = upper_loss + lower_loss

        # Double the weight of hard examples whose loss exceeds the threshold.
        loss_weights = torch.ones_like(loss)
        loss_weights[loss > self.loss_threshold] = 2.0

        return (loss * loss_weights).mean()

    def training_step(self, batch, batch_idx):
        binder_tokens = batch['binder_input_ids']
        mut_tokens = batch['mutant_input_ids']
        wt_tokens = batch['wildtype_input_ids']
        aff = batch['aff']
        loss = self.compute_loss(binder_tokens, wt_tokens, mut_tokens, aff)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        binder_tokens = batch['binder_input_ids']
        mut_tokens = batch['mutant_input_ids']
        wt_tokens = batch['wildtype_input_ids']
        aff = batch['aff']
        val_loss = self.compute_loss(binder_tokens, wt_tokens, mut_tokens, aff)
        self.log("val_loss", val_loss)
        return val_loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.95))
        # Rough upper bound on the number of optimizer steps (curriculum subsets are smaller).
        total_steps = len(self.train_dataset) // batch_size * max_epochs
        warmup_steps = int(0.1 * total_steps)
        # Linear warmup followed by cosine decay, stepped once per optimizer step.
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def train_dataloader(self):
        # Curriculum learning: the first 5 epochs use only easy examples; afterwards an increasing
        # fraction of hard examples is mixed in each epoch. The Trainer must be created with
        # reload_dataloaders_every_n_epochs=1 so this method is called again at every epoch.
        epoch = self.current_epoch
        if epoch < 5:
            self.current_subset_indices = self.easy_example_indices
        else:
            num_hard_examples = int((epoch / self.max_epochs) * len(self.hard_example_indices))
            selected_hard_indices = self.hard_example_indices[:num_hard_examples]
            self.current_subset_indices = self.easy_example_indices + selected_hard_indices
        train_subset = Subset(self.train_dataset, self.current_subset_indices)
        return DataLoader(train_subset, batch_size=batch_size, collate_fn=collate_fn,
                          shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn,
                          shuffle=False, num_workers=4)
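
# A small, self-contained illustration of the margin objective used in compute_loss, with made-up
# numbers: the predicted binder-wildtype vs. binder-mutant distance is pushed into the band
# [margin * aff, margin * (aff + 1)]. This helper is illustrative only and is not called anywhere.
def _margin_band_example(margin=10.0):
    distance = torch.tensor([2.5, 8.0, 16.0])      # hypothetical model outputs
    aff = torch.tensor([0.4, 0.4, 0.4])            # band is [4.0, 14.0] for every element
    upper = F.relu(distance - margin * (aff + 1))  # penalizes distances above the band: [0.0, 0.0, 2.0]
    lower = F.relu(margin * aff - distance)        # penalizes distances below the band: [1.5, 0.0, 0.0]
    return upper + lower                           # only out-of-band predictions incur loss
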
# Index lists for curriculum learning (easy vs. hard examples).
easy_example_indices = np.load('/home/tc415/muPPIt_embedding/dataset/ppiref_index.npy').tolist()
hard_example_indices = np.load('/home/tc415/muPPIt_embedding/dataset/skempi_index.npy').tolist()

# Instantiate the model; d_node = 1288 matches the ESM-2 embedding (1280) plus VHSE8 (8).
model = muPPItLightning(
    d_node=1288,
    num_heads=8,
    dropout=dropout,
    margin=margin,
    lr=lr,
    train_dataset=train_dataset,
    easy_example_indices=easy_example_indices,
    hard_example_indices=hard_example_indices
)

# Trainer
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator='gpu',
    devices=-1,                           # use all available GPUs
    strategy='ddp',
    reload_dataloaders_every_n_epochs=1   # rebuild the train dataloader so the curriculum advances
)

# Train the model
trainer.fit(model)
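
# A rough inference sketch (not executed during training): load a trained checkpoint and score one
# binder / wildtype / mutant triple. The checkpoint path and the sequences are placeholders, and the
# constructor arguments excluded from save_hyperparameters must be supplied again at load time.
def score_triple(checkpoint_path, binder_seq, wildtype_seq, mutant_seq):
    scorer = muPPItLightning.load_from_checkpoint(
        checkpoint_path,
        train_dataset=train_dataset,
        easy_example_indices=easy_example_indices,
        hard_example_indices=hard_example_indices,
    )
    scorer.eval()

    def encode(seq):
        # Tokenize and strip <cls>/<eos>, mirroring collate_fn, then add a batch dimension.
        ids = tokenizer(seq)['input_ids'][1:-1]
        return torch.tensor(ids, dtype=torch.long).unsqueeze(0)

    with torch.no_grad():
        distance = scorer(encode(binder_seq), encode(wildtype_seq), encode(mutant_seq))
    return distance.item()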