# ProteinMPNN training script.
# Origin: commit 00aa807 ("add proteinmpnn") by Simon Duerr.
import argparse
import os.path
def _format_float(value, precision):
    """Render *value* as a fixed-point string with exactly *precision* decimals.

    The value is cast to float32 first, matching how metrics are written to
    the training log.
    """
    import numpy as np

    return np.format_float_positional(np.float32(value), unique=False, precision=precision)


def _save_checkpoint(path, epoch, step, args, model, optimizer):
    """Write a training checkpoint dict to *path* with ``torch.save``.

    ``main`` reads the 'epoch', 'step', 'model_state_dict' and
    'optimizer_state_dict' keys back when resuming, so keep this layout stable.
    ``optimizer`` is the wrapper returned by ``get_std_opt``; the state saved is
    that of its inner ``.optimizer``.
    """
    import torch

    torch.save({
        'epoch': epoch,
        'step': step,
        'num_edges': args.num_neighbors,
        'noise_level': args.backbone_noise,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.optimizer.state_dict(),
    }, path)


def main(args):
    """Train a ProteinMPNN model configured by the parsed command-line *args*.

    Side effects:
      * creates ``<path_for_outputs>/model_weights/``;
      * appends one line per epoch (exp of mean NLL, i.e. perplexity, and
        accuracy for train/validation) to ``log.txt`` and prints it;
      * overwrites ``model_weights/epoch_last.pt`` after every epoch and keeps
        an ``epoch{E}_step{S}.pt`` snapshot every
        ``save_model_every_n_epochs`` epochs.

    When ``args.previous_checkpoint`` is set, the model/optimizer state and the
    epoch/step counters are restored from that file and training continues.

    Training/validation structures are sampled by ``get_pdbs`` in a process
    pool; a fresh sample is swapped in every ``reload_data_every_n_epochs``
    epochs.
    """
    import os
    import queue
    import time
    from concurrent.futures import ProcessPoolExecutor

    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    from utils import worker_init_fn, get_pdbs, loader_pdb, build_training_clusters, PDB_dataset, StructureDataset, StructureLoader
    from model_utils import featurize, loss_smoothed, loss_nll, get_std_opt, ProteinMPNN

    scaler = torch.cuda.amp.GradScaler()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # path_for_outputs is passed through strftime, so it may contain date codes.
    base_folder = time.strftime(args.path_for_outputs, time.localtime())
    if base_folder[-1] != '/':
        base_folder += '/'
    os.makedirs(base_folder + 'model_weights', exist_ok=True)

    checkpoint_path = args.previous_checkpoint
    logfile = base_folder + 'log.txt'
    if not checkpoint_path:
        # Fresh run: start a new log. When resuming, keep appending to it.
        with open(logfile, 'w') as f:
            f.write('Epoch\tTrain\tValidation\n')

    data_path = args.path_for_training_data
    params = {
        "LIST": f"{data_path}/list.csv",
        "VAL": f"{data_path}/valid_clusters.txt",
        "TEST": f"{data_path}/test_clusters.txt",
        "DIR": f"{data_path}",
        "DATCUT": "2030-Jan-01",
        "RESCUT": args.rescut,  # resolution cutoff for PDBs
        "HOMO": 0.70,  # min seq.id. to detect homo chains
    }
    loader_params = {'batch_size': 1, 'shuffle': True, 'pin_memory': False, 'num_workers': 4}

    if args.debug:
        # Minimal data loading for debugging runs.
        args.num_examples_per_epoch = 50
        args.max_protein_length = 1000
        args.batch_size = 1000

    train, valid, _ = build_training_clusters(params, args.debug)
    train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
    train_loader = DataLoader(train_set, worker_init_fn=worker_init_fn, **loader_params)
    valid_set = PDB_dataset(list(valid.keys()), loader_pdb, valid, params)
    valid_loader = DataLoader(valid_set, worker_init_fn=worker_init_fn, **loader_params)

    model = ProteinMPNN(node_features=args.hidden_dim,
                        edge_features=args.hidden_dim,
                        hidden_dim=args.hidden_dim,
                        num_encoder_layers=args.num_encoder_layers,
                        # Decoder depth comes from --num_decoder_layers, not the encoder count.
                        num_decoder_layers=args.num_decoder_layers,
                        k_neighbors=args.num_neighbors,
                        dropout=args.dropout,
                        augment_eps=args.backbone_noise)
    model.to(device)

    if checkpoint_path:
        # map_location lets a checkpoint written on a GPU be resumed on a CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        total_step = checkpoint['step']
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        total_step = 0
        epoch = 0
    # total_step is handed to get_std_opt, presumably so its schedule resumes at the right step.
    optimizer = get_std_opt(model.parameters(), args.hidden_dim, total_step)
    if checkpoint_path:
        optimizer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    def make_loader(pdb_dict):
        # Wrap a sampled set of structures into batches of ~args.batch_size tokens.
        dataset = StructureDataset(pdb_dict, truncate=None, max_length=args.max_protein_length)
        return StructureLoader(dataset, batch_size=args.batch_size)

    with ProcessPoolExecutor(max_workers=12) as executor:
        def submit_sample(loader):
            # Schedule a get_pdbs sampling pass over `loader` in a worker process.
            return executor.submit(get_pdbs, loader, 1, args.max_protein_length, args.num_examples_per_epoch)

        # Keep up to three samples per split in flight so the next one can be
        # prepared while training runs.
        train_samples = queue.Queue(maxsize=3)
        valid_samples = queue.Queue(maxsize=3)
        for _ in range(3):
            train_samples.put_nowait(submit_sample(train_loader))
            valid_samples.put_nowait(submit_sample(valid_loader))
        loader_train = make_loader(train_samples.get().result())
        loader_valid = make_loader(valid_samples.get().result())

        # The data just loaded serves the first reload epoch, so skip one swap.
        first_reload = True
        for e in range(epoch, epoch + args.num_epochs):
            t0 = time.time()
            if e % args.reload_data_every_n_epochs == 0:
                if first_reload:
                    first_reload = False
                else:
                    loader_train = make_loader(train_samples.get().result())
                    loader_valid = make_loader(valid_samples.get().result())
                    train_samples.put_nowait(submit_sample(train_loader))
                    valid_samples.put_nowait(submit_sample(valid_loader))

            model.train()
            train_sum, train_acc, train_weights = 0., 0., 0.
            for batch in loader_train:
                X, S, mask, _, chain_M, residue_idx, _, chain_encoding_all = featurize(batch, device)
                optimizer.zero_grad()
                mask_for_loss = mask * chain_M
                if args.mixed_precision:
                    with torch.cuda.amp.autocast():
                        log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                        _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)
                    scaler.scale(loss_av_smoothed).backward()
                    if args.gradient_norm > 0.0:
                        # The gradients are still multiplied by the loss scale here.
                        # Unscale them first so the clip threshold applies to the true norm.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                    _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)
                    loss_av_smoothed.backward()
                    if args.gradient_norm > 0.0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm)
                    optimizer.step()

                # Logged metrics come from loss_nll, not from the smoothed training loss.
                loss, _, true_false = loss_nll(S, log_probs, mask_for_loss)
                train_sum += torch.sum(loss * mask_for_loss).item()
                train_acc += torch.sum(true_false * mask_for_loss).item()
                train_weights += torch.sum(mask_for_loss).item()
                total_step += 1

            model.eval()
            validation_sum, validation_acc, validation_weights = 0., 0., 0.
            with torch.no_grad():
                for batch in loader_valid:
                    X, S, mask, _, chain_M, residue_idx, _, chain_encoding_all = featurize(batch, device)
                    log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                    mask_for_loss = mask * chain_M
                    loss, _, true_false = loss_nll(S, log_probs, mask_for_loss)
                    validation_sum += torch.sum(loss * mask_for_loss).item()
                    validation_acc += torch.sum(true_false * mask_for_loss).item()
                    validation_weights += torch.sum(mask_for_loss).item()

            # Per-position averages; perplexity = exp(mean NLL).
            train_perplexity = np.exp(train_sum / train_weights)
            validation_perplexity = np.exp(validation_sum / validation_weights)
            summary = (f'epoch: {e+1}, step: {total_step}, '
                       f'time: {_format_float(time.time() - t0, 1)}, '
                       f'train: {_format_float(train_perplexity, 3)}, '
                       f'valid: {_format_float(validation_perplexity, 3)}, '
                       f'train_acc: {_format_float(train_acc / train_weights, 3)}, '
                       f'valid_acc: {_format_float(validation_acc / validation_weights, 3)}')
            with open(logfile, 'a') as f:
                f.write(summary + '\n')
            print(summary)

            _save_checkpoint(base_folder + 'model_weights/epoch_last.pt',
                             e + 1, total_step, args, model, optimizer)
            if (e + 1) % args.save_model_every_n_epochs == 0:
                _save_checkpoint(f'{base_folder}model_weights/epoch{e + 1}_step{total_step}.pt',
                                 e + 1, total_step, args, model, optimizer)
if __name__ == "__main__":
argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
argparser.add_argument("--path_for_training_data", type=str, default="my_path/pdb_2021aug02", help="path for loading training data")
argparser.add_argument("--path_for_outputs", type=str, default="./exp_020", help="path for logs and model weights")
argparser.add_argument("--previous_checkpoint", type=str, default="", help="path for previous model weights, e.g. file.pt")
argparser.add_argument("--num_epochs", type=int, default=200, help="number of epochs to train for")
argparser.add_argument("--save_model_every_n_epochs", type=int, default=10, help="save model weights every n epochs")
argparser.add_argument("--reload_data_every_n_epochs", type=int, default=2, help="reload training data every n epochs")
argparser.add_argument("--num_examples_per_epoch", type=int, default=1000000, help="number of training example to load for one epoch")
argparser.add_argument("--batch_size", type=int, default=10000, help="number of tokens for one batch")
argparser.add_argument("--max_protein_length", type=int, default=10000, help="maximum length of the protein complext")
argparser.add_argument("--hidden_dim", type=int, default=128, help="hidden model dimension")
argparser.add_argument("--num_encoder_layers", type=int, default=3, help="number of encoder layers")
argparser.add_argument("--num_decoder_layers", type=int, default=3, help="number of decoder layers")
argparser.add_argument("--num_neighbors", type=int, default=48, help="number of neighbors for the sparse graph")
argparser.add_argument("--dropout", type=float, default=0.1, help="dropout level; 0.0 means no dropout")
argparser.add_argument("--backbone_noise", type=float, default=0.2, help="amount of noise added to backbone during training")
argparser.add_argument("--rescut", type=float, default=3.5, help="PDB resolution cutoff")
argparser.add_argument("--debug", type=bool, default=False, help="minimal data loading for debugging")
argparser.add_argument("--gradient_norm", type=float, default=-1.0, help="clip gradient norm, set to negative to omit clipping")
argparser.add_argument("--mixed_precision", type=bool, default=True, help="train with mixed precision")
args = argparser.parse_args()
main(args)