import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import *
from omegaconf import OmegaConf
import os
import random
from aptatrans_pipeline import AptaTransPipeline
from mcts import MCTS
import esm
from encoders import AptaTrans
from utils import get_scores, API_Dataset, get_esm_dataset, rna2vec
from accelerate import Accelerator

# accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])  # NOTE: Buggy | disables the unused-parameter issue
accelerator = Accelerator()


class AptaTransPipeline_Dist(AptaTransPipeline):
    """In-house API prediction score pipeline, inheriting from AptaTrans (Shin et al., 2023)."""

    def __init__(self, lr, weight_decay, epochs, model_type, model_version, model_save_path,
                 accelerate_save_path, tensorboard_logdir, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = accelerator.device
        self.lr = lr
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.model_type = model_type
        self.model_version = model_version
        self.model_save_path = model_save_path
        self.accelerate_save_path = accelerate_save_path
        self.tensorboard_logdir = tensorboard_logdir

        # ESM-2 protein encoder
        esm_prot_encoder, self.esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()

        # Freeze ESM-2, then unfreeze the last three transformer layers
        for name, param in esm_prot_encoder.named_parameters():
            param.requires_grad = False
        for name, param in esm_prot_encoder.named_parameters():
            # if "layers.28" in name or "layers.29" in name or "layers.30" in name or "layers.31" in name or "layers.32" in name:
            if "layers.30" in name or "layers.31" in name or "layers.32" in name:
                param.requires_grad = True

        self.batch_converter = self.esm_alphabet.get_batch_converter()

        self.model = AptaTrans(
            apta_encoder=self.encoder_aptamer.encoder,
            prot_encoder=esm_prot_encoder,
            n_apta_vocabs=self.n_apta_vocabs,
            n_prot_vocabs=self.n_prot_vocabs,
            dropout=0.1,
            apta_max_len=self.apta_max_len,
            prot_max_len=self.prot_max_len,
        ).to(self.device)

        self.criterion = torch.nn.BCELoss().to(self.device)

    def train(self):
        print('Training the model!')

        # Initialize TensorBoard writer
        writer = SummaryWriter(log_dir=f"log/{self.model_type}/{self.model_version}")

        # Initialize early stopping (patience=3, min_delta=3)
        self.early_stopper = EarlyStopper(3, 3)

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # Configure PyTorch objects for the distributed environment (sharded dataloaders, model replicas, etc.)
        self.model, self.optimizer, self.train_loader, self.test_loader, self.bench_loader = accelerator.prepare(
            self.model, self.optimizer, self.train_loader, self.test_loader, self.bench_loader)

        best_auc = 0
        for epoch in range(1, self.epochs + 1):
            self.model.train()
            loss_train, _, _ = self.batch_step(self.train_loader, train_mode=True)

            self.model.eval()
            with torch.no_grad():
                loss_test, pred_test, target_test = self.batch_step(self.test_loader, train_mode=False)
            scores = get_scores(target_test, pred_test)
            print("\tTrain Loss: {:.6f}\tTest Loss: {:.6f}\tTest ACC: {:.6f}\tTest AUC: {:.6f}\tTest MCC: {:.6f}\tTest PR_AUC: {:.6f}\tF1: {:.6f}\n".format(
                loss_train, loss_test, scores['acc'], scores['roc_auc'], scores['mcc'], scores['pr_auc'], scores['f1']))

            # Early stop - the model has not improved on the eval set.
            stop_early = self.early_stopper.early_stop(loss_test)
            if stop_early:
                break

            if epoch > 15:
                with torch.no_grad():
                    loss_bench, pred_bench, target_bench = self.batch_step(self.bench_loader, train_mode=False)
                scores = get_scores(target_bench, pred_bench)
                print("\tBench Loss: {:.6f}\tBench ACC: {:.6f}\tBench AUC: {:.6f}\tBench MCC: {:.6f}\tBench PR_AUC: {:.6f}\tBench F1: {:.6f}\n".format(
                    loss_bench, scores['acc'], scores['roc_auc'], scores['mcc'], scores['pr_auc'], scores['f1']))

                # Checkpoint based on the benchmark criteria.
                # If the model has improved and the early-stopping patience counter was just reset:
                if scores['roc_auc'] > best_auc and self.early_stopper.counter == 0 and accelerator.is_main_process:
                    best_auc = scores['roc_auc']
                    accelerator.save_state(self.accelerate_save_path)
                    model = accelerator.unwrap_model(self.model)
                    torch.save(model.state_dict(), f'{self.model_save_path}/model_epoch={epoch}.pt')
                    print(f'Model saved at {self.model_save_path}')
                    print(f'Accelerate statistics saved at {self.accelerate_save_path}!')  # Access via accelerator.load_state("./output")

            # Logging
            writer.add_scalar("Loss/train", loss_train, epoch)
            writer.add_scalar("Loss/test", loss_test, epoch)
            for k, v in scores.items():
                if isinstance(v, float):
                    writer.add_scalar(f'{k}/test', v, epoch)

        print("Training finished | access tensorboard via 'tensorboard --logdir=log'.")
        writer.flush()
        writer.close()

    def batch_step(self, loader, train_mode=True):
        loss_total = 0
        pred = np.array([])
        target = np.array([])
        for batch_idx, (apta, esm_prot, y) in enumerate(loader):
            if train_mode:
                self.optimizer.zero_grad()
            y_pred = self.predict(apta, esm_prot)
            # .to(self.device) is not strictly needed: accelerator already maps dataloader batches to the correct device
            y_true = torch.tensor(y, dtype=torch.float32).to(self.device)
            loss = self.criterion(torch.flatten(y_pred), y_true)
            if train_mode:
                # accelerator.backward() scales gradients and dispatches the appropriate backward pass for the configured devices
                accelerator.backward(loss)
                self.optimizer.step()
            loss_total += loss.item()
            pred = np.append(pred, torch.flatten(y_pred).clone().detach().cpu().numpy())
            target = np.append(target, torch.flatten(y_true).clone().detach().cpu().numpy())
            mode = 'train' if train_mode else 'eval'
            print(mode + "[{}/{}({:.0f}%)]".format(batch_idx, len(loader), 100. * batch_idx / len(loader)), end="\r", flush=True)
        loss_total /= len(loader)
        return loss_total, pred, target

    def predict(self, apta, esm_prot):
        y_pred = self.model(apta, esm_prot)
        return y_pred

    def inference(self, apta, prot, labels):
        """Perform inference on a batch of aptamer/protein pairs."""
        print('Predicting the Aptamer-Protein Interaction')
        try:
            print("loading the best model for api!")
            self.model.load_state_dict(torch.load('./models/model.pt', map_location=self.device))
        except:
            print('there is no best model file.')
            print('You need to train the model for predicting API!')

        y_place = np.zeros((len(prot), 1))
        inputs = [(i, j) for i, j in zip(y_place, prot)]
        _, _, prot_tokens = self.batch_converter(inputs)
        apta_tokenized = rna2vec(np.array(apta))

        # Truncating
        prot_tokenized = prot_tokens[:, :1678]

        # Padding
        prot_ex = torch.ones((prot_tokenized.shape[0], 1678), dtype=torch.int64) * self.esm_alphabet.padding_idx
        prot_ex[:, :prot_tokenized.shape[1]] = prot_tokenized

        loader = DataLoader(API_Dataset(apta_tokenized, prot_ex, labels), batch_size=4, shuffle=False)
        self.model, loader = accelerator.prepare(self.model, loader)
        _, pred, _ = self.batch_step(loader, train_mode=False)
        # self.model.eval()
        # with torch.no_grad():
        #     y_pred = self.model(apta_tokenized.to('cuda'), prot_ex.to('cuda'))
        # score = y_pred.detach().cpu().numpy()
        # print('Score : ', score)
        return pred

    def recommend(self, target, n_aptamers, depth, iteration, verbose=True):
        try:
            print("load the best model for api!")
            self.model.load_state_dict(torch.load('models/aptaESM2_trainable_l5/test_lr=1e-06_batch_size=16_dropout=0.05_wd=1e-06/model.pt', map_location=self.device))
        except:
            print('there is no best model file.')
            print('You need to train the model for predicting API!')

        candidates = []
        _, _, prot_tokens = self.batch_converter([(1, target)])
        prot_tokenized = torch.tensor(prot_tokens, dtype=torch.int64)

        # Adjusting for the maximum protein sequence length used during model training
        encoded_targetprotein = torch.ones((prot_tokenized.shape[0], 1678), dtype=torch.int64) * self.esm_alphabet.padding_idx
        encoded_targetprotein[:, :prot_tokenized.shape[1]] = prot_tokenized
        encoded_targetprotein = encoded_targetprotein.to(self.device)

        mcts = MCTS(encoded_targetprotein, depth=depth, iteration=iteration, states=8, target_protein=target, device=self.device)
        for _ in range(n_aptamers):
            mcts.make_candidate(self.model)
            candidates.append(mcts.get_candidate())
            self.model.eval()
            with torch.no_grad():
                sim_seq = np.array([mcts.get_candidate()])
                apta = torch.tensor(rna2vec(sim_seq), dtype=torch.int64).to(self.device)
                score = self.model(apta.to(self.device), encoded_targetprotein)
            if verbose:
                print("candidate:\t", mcts.get_candidate(), "\tscore:\t", score)
                print("*" * 80)
            mcts.reset()

    def set_data_for_training(self, filepath, batch_size):
        ds_train, ds_test = get_esm_dataset(filepath, self.batch_converter)
        self.train_loader = DataLoader(API_Dataset(ds_train[0], ds_train[1], ds_train[2]), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(API_Dataset(ds_test[0], ds_test[1], ds_test[2]), batch_size=batch_size, shuffle=False)


class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
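
# Illustrative usage (a minimal, non-executed sketch): after training, the pipeline can score
# aptamer/protein pairs with inference() or propose new aptamers with recommend(). The sequences,
# labels, and search settings below are placeholders, not values from any reported run; note that
# inference() and recommend() load checkpoints from the hard-coded paths shown in their bodies.
#
#   pipeline = AptaTransPipeline_Dist(..., d_model=128, d_ff=512, n_layers=6, n_heads=8,
#                                     dropout=0.1, load_best_pt=True, device='cuda', seed=0)
#   scores = pipeline.inference(apta=["AGCUAGGUC..."], prot=["MKTAYIAKQR..."], labels=[1])
#   pipeline.recommend(target="MKTAYIAKQR...", n_aptamers=5, depth=10, iteration=100)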
def seed_torch(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def main():
    conf = OmegaConf.load('config.yaml')
    hyperparameters = conf.hyperparameters
    logging = conf.logging

    lr = hyperparameters['lr']
    wd = hyperparameters['weight_decay']
    dropout = hyperparameters['dropout']
    batch_size = hyperparameters['batch_size']
    epochs = hyperparameters['epochs']
    seed = hyperparameters['seed']

    model_type = logging['model_type']
    model_version = logging['model_version']
    model_save_path = logging['model_save_path']
    accelerate_save_path = logging['accelerate_save_path']
    tensorboard_logdir = logging['tensorboard_logdir']

    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

    seed_torch(seed=seed)

    pipeline = AptaTransPipeline_Dist(
        lr=lr,
        weight_decay=wd,
        epochs=epochs,
        model_type=model_type,
        model_version=model_version,
        model_save_path=model_save_path,
        accelerate_save_path=accelerate_save_path,
        tensorboard_logdir=tensorboard_logdir,
        d_model=128,
        d_ff=512,
        n_layers=6,
        n_heads=8,
        dropout=dropout,
        load_best_pt=True,  # already loads the pretrained model using the datasets included in the repo -- no need to run the bottom two cells
        device='cuda',
        seed=seed)

    datapath = "./data/dataset_li.pickle"
    pipeline.set_data_for_training(datapath, batch_size=batch_size)
    pipeline.train()
    return


if __name__ == "__main__":
    # Launch training with: "accelerate launch api_prediction.py [args]"
    main()
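
# Assumed config.yaml layout (a sketch inferred from the keys read in main(); the values shown
# here are placeholders, not the settings used for any reported result):
#
#   hyperparameters:
#     lr: 1e-5
#     weight_decay: 1e-6
#     dropout: 0.1
#     batch_size: 16
#     epochs: 50
#     seed: 1029
#   logging:
#     model_type: aptaESM2
#     model_version: v1
#     model_save_path: ./models/aptaESM2/v1
#     accelerate_save_path: ./accelerate/aptaESM2/v1
#     tensorboard_logdir: ./log/aptaESM2/v1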