import argparse
import os.path


def main(args):
    import json, time, os, sys, glob
    import shutil
    import warnings
    import numpy as np
    import torch
    from torch import optim
    from torch.utils.data import DataLoader
    import queue
    import copy
    import torch.nn as nn
    import torch.nn.functional as F
    import random
    import subprocess
    from concurrent.futures import ProcessPoolExecutor
    from utils import worker_init_fn, get_pdbs, loader_pdb, build_training_clusters, PDB_dataset, StructureDataset, StructureLoader
    from model_utils import featurize, loss_smoothed, loss_nll, get_std_opt, ProteinMPNN

    scaler = torch.cuda.amp.GradScaler()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    base_folder = time.strftime(args.path_for_outputs, time.localtime())
    if base_folder[-1] != '/':
        base_folder += '/'
    if not os.path.exists(base_folder):
        os.makedirs(base_folder)
    subfolders = ['model_weights']
    for subfolder in subfolders:
        if not os.path.exists(base_folder + subfolder):
            os.makedirs(base_folder + subfolder)

    PATH = args.previous_checkpoint

    logfile = base_folder + 'log.txt'
    if not PATH:
        with open(logfile, 'w') as f:
            f.write('Epoch\tTrain\tValidation\n')

    data_path = args.path_for_training_data
    params = {
        "LIST": f"{data_path}/list.csv",
        "VAL": f"{data_path}/valid_clusters.txt",
        "TEST": f"{data_path}/test_clusters.txt",
        "DIR": f"{data_path}",
        "DATCUT": "2030-Jan-01",
        "RESCUT": args.rescut,  # resolution cutoff for PDBs
        "HOMO": 0.70  # min seq. id. to detect homo chains
    }

    LOAD_PARAM = {'batch_size': 1,
                  'shuffle': True,
                  'pin_memory': False,
                  'num_workers': 4}

    if args.debug:
        args.num_examples_per_epoch = 50
        args.max_protein_length = 1000
        args.batch_size = 1000

    train, valid, test = build_training_clusters(params, args.debug)

    train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
    train_loader = torch.utils.data.DataLoader(train_set, worker_init_fn=worker_init_fn, **LOAD_PARAM)
    valid_set = PDB_dataset(list(valid.keys()), loader_pdb, valid, params)
    valid_loader = torch.utils.data.DataLoader(valid_set, worker_init_fn=worker_init_fn, **LOAD_PARAM)

    model = ProteinMPNN(node_features=args.hidden_dim,
                        edge_features=args.hidden_dim,
                        hidden_dim=args.hidden_dim,
                        num_encoder_layers=args.num_encoder_layers,
                        num_decoder_layers=args.num_decoder_layers,
                        k_neighbors=args.num_neighbors,
                        dropout=args.dropout,
                        augment_eps=args.backbone_noise)
    model.to(device)

    if PATH:
        checkpoint = torch.load(PATH)
        total_step = checkpoint['step']  # restore total_step from the checkpoint
        epoch = checkpoint['epoch']  # restore epoch from the checkpoint
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        total_step = 0
        epoch = 0

    optimizer = get_std_opt(model.parameters(), args.hidden_dim, total_step)

    if PATH:
        optimizer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    with ProcessPoolExecutor(max_workers=12) as executor:
        # Prefetch parsed PDB data in background worker processes; the queues hold
        # futures so a fresh data chunk is ready whenever it is time to reload.
        q = queue.Queue(maxsize=3)
        p = queue.Queue(maxsize=3)
        for i in range(3):
            q.put_nowait(executor.submit(get_pdbs, train_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
            p.put_nowait(executor.submit(get_pdbs, valid_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
        pdb_dict_train = q.get().result()
        pdb_dict_valid = p.get().result()

        dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=args.max_protein_length)
        dataset_valid = StructureDataset(pdb_dict_valid, truncate=None, max_length=args.max_protein_length)

        loader_train = StructureLoader(dataset_train, batch_size=args.batch_size)
        loader_valid = StructureLoader(dataset_valid, batch_size=args.batch_size)
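        # Main training loop: each epoch trains on loader_train, evaluates on
        # loader_valid, logs perplexity and accuracy, and checkpoints the model.
        # Every args.reload_data_every_n_epochs epochs a freshly loaded set of
        # examples is swapped in from the prefetch queues above.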
        reload_c = 0
        for e in range(args.num_epochs):
            t0 = time.time()
            e = epoch + e
            model.train()
            train_sum, train_weights = 0., 0.
            train_acc = 0.
            if e % args.reload_data_every_n_epochs == 0:
                if reload_c != 0:
                    pdb_dict_train = q.get().result()
                    dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=args.max_protein_length)
                    loader_train = StructureLoader(dataset_train, batch_size=args.batch_size)
                    pdb_dict_valid = p.get().result()
                    dataset_valid = StructureDataset(pdb_dict_valid, truncate=None, max_length=args.max_protein_length)
                    loader_valid = StructureLoader(dataset_valid, batch_size=args.batch_size)
                    q.put_nowait(executor.submit(get_pdbs, train_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
                    p.put_nowait(executor.submit(get_pdbs, valid_loader, 1, args.max_protein_length, args.num_examples_per_epoch))
                reload_c += 1
            for _, batch in enumerate(loader_train):
                start_batch = time.time()
                X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
                elapsed_featurize = time.time() - start_batch
                optimizer.zero_grad()
                mask_for_loss = mask * chain_M

                if args.mixed_precision:
                    with torch.cuda.amp.autocast():
                        log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                        _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)

                    scaler.scale(loss_av_smoothed).backward()

                    if args.gradient_norm > 0.0:
                        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm)

                    scaler.step(optimizer)
                    scaler.update()
                else:
                    log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                    _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)
                    loss_av_smoothed.backward()

                    if args.gradient_norm > 0.0:
                        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm)

                    optimizer.step()

                loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)

                train_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()
                train_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()
                train_weights += torch.sum(mask_for_loss).cpu().data.numpy()

                total_step += 1

            model.eval()
            with torch.no_grad():
                validation_sum, validation_weights = 0., 0.
                validation_acc = 0.
                for _, batch in enumerate(loader_valid):
                    X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
                    log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                    mask_for_loss = mask * chain_M
                    loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)

                    validation_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()
                    validation_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()
                    validation_weights += torch.sum(mask_for_loss).cpu().data.numpy()

            train_loss = train_sum / train_weights
            train_accuracy = train_acc / train_weights
            train_perplexity = np.exp(train_loss)
            validation_loss = validation_sum / validation_weights
            validation_accuracy = validation_acc / validation_weights
            validation_perplexity = np.exp(validation_loss)

            train_perplexity_ = np.format_float_positional(np.float32(train_perplexity), unique=False, precision=3)
            validation_perplexity_ = np.format_float_positional(np.float32(validation_perplexity), unique=False, precision=3)
            train_accuracy_ = np.format_float_positional(np.float32(train_accuracy), unique=False, precision=3)
            validation_accuracy_ = np.format_float_positional(np.float32(validation_accuracy), unique=False, precision=3)

            t1 = time.time()
            dt = np.format_float_positional(np.float32(t1 - t0), unique=False, precision=1)
            with open(logfile, 'a') as f:
                f.write(f'epoch: {e+1}, step: {total_step}, time: {dt}, train: {train_perplexity_}, valid: {validation_perplexity_}, train_acc: {train_accuracy_}, valid_acc: {validation_accuracy_}\n')
            print(f'epoch: {e+1}, step: {total_step}, time: {dt}, train: {train_perplexity_}, valid: {validation_perplexity_}, train_acc: {train_accuracy_}, valid_acc: {validation_accuracy_}')

            # Always overwrite the "last" checkpoint, and additionally keep a
            # numbered checkpoint every args.save_model_every_n_epochs epochs.
            checkpoint_filename_last = base_folder + 'model_weights/epoch_last.pt'
            torch.save({
                'epoch': e+1,
                'step': total_step,
                'num_edges': args.num_neighbors,
                'noise_level': args.backbone_noise,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.optimizer.state_dict(),
            }, checkpoint_filename_last)

            if (e+1) % args.save_model_every_n_epochs == 0:
                checkpoint_filename = base_folder + 'model_weights/epoch{}_step{}.pt'.format(e+1, total_step)
                torch.save({
                    'epoch': e+1,
                    'step': total_step,
                    'num_edges': args.num_neighbors,
                    'noise_level': args.backbone_noise,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.optimizer.state_dict(),
                }, checkpoint_filename)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    argparser.add_argument("--path_for_training_data", type=str, default="my_path/pdb_2021aug02", help="path for loading training data")
    argparser.add_argument("--path_for_outputs", type=str, default="./exp_020", help="path for logs and model weights")
    argparser.add_argument("--previous_checkpoint", type=str, default="", help="path for previous model weights, e.g. file.pt")
file.pt") argparser.add_argument("--num_epochs", type=int, default=200, help="number of epochs to train for") argparser.add_argument("--save_model_every_n_epochs", type=int, default=10, help="save model weights every n epochs") argparser.add_argument("--reload_data_every_n_epochs", type=int, default=2, help="reload training data every n epochs") argparser.add_argument("--num_examples_per_epoch", type=int, default=1000000, help="number of training example to load for one epoch") argparser.add_argument("--batch_size", type=int, default=10000, help="number of tokens for one batch") argparser.add_argument("--max_protein_length", type=int, default=10000, help="maximum length of the protein complext") argparser.add_argument("--hidden_dim", type=int, default=128, help="hidden model dimension") argparser.add_argument("--num_encoder_layers", type=int, default=3, help="number of encoder layers") argparser.add_argument("--num_decoder_layers", type=int, default=3, help="number of decoder layers") argparser.add_argument("--num_neighbors", type=int, default=48, help="number of neighbors for the sparse graph") argparser.add_argument("--dropout", type=float, default=0.1, help="dropout level; 0.0 means no dropout") argparser.add_argument("--backbone_noise", type=float, default=0.2, help="amount of noise added to backbone during training") argparser.add_argument("--rescut", type=float, default=3.5, help="PDB resolution cutoff") argparser.add_argument("--debug", type=bool, default=False, help="minimal data loading for debugging") argparser.add_argument("--gradient_norm", type=float, default=-1.0, help="clip gradient norm, set to negative to omit clipping") argparser.add_argument("--mixed_precision", type=bool, default=True, help="train with mixed precision") args = argparser.parse_args() main(args)