# Provenance (file-page header): author Simon Duerr, commit 00aa807 "add proteinmpnn".
# Page residue kept for reference: "raw history blame", "No virus", "12.4 kB".
import torch
from torch.utils.data import DataLoader
import csv
from dateutil import parser
import numpy as np
import time
import random
import os
class StructureDataset():
    """In-memory list of structure dicts filtered by alphabet and length.

    Every element of ``pdb_dict_list`` must be a dict with at least ``'seq'``
    and ``'name'`` keys. An entry is discarded when its sequence contains a
    character not in ``alphabet`` (counted as ``'bad_chars'``) or is longer
    than ``max_length`` (counted as ``'too_long'``); all other entries are kept
    in ``self.data`` in input order. If ``truncate`` is given, construction
    stops as soon as exactly that many entries have been kept.

    ``verbose`` only enables a periodic elapsed-time computation whose report
    is currently disabled, and the discard counters are not exposed.
    """

    def __init__(self, pdb_dict_list, verbose=True, truncate=None, max_length=100,
                 alphabet='ACDEFGHIKLMNPQRSTVWYX'):
        allowed = set(alphabet)
        discarded = dict(bad_chars=0, too_long=0, bad_seq_length=0)
        self.data = []
        t_start = time.time()
        for count, record in enumerate(pdb_dict_list, start=1):
            sequence = record['seq']
            name = record['name']  # required key; otherwise unused
            if not set(sequence) <= allowed:
                discarded['bad_chars'] += 1
            elif len(sequence) > max_length:
                discarded['too_long'] += 1
            else:
                self.data.append(record)
            # Stop early once the requested number of entries has been kept.
            if truncate is not None and len(self.data) == truncate:
                return
            if verbose and count % 1000 == 0:
                elapsed = time.time() - t_start  # progress report disabled

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
class StructureLoader():
    """Yield batches of dataset entries grouped by similar sequence length.

    Entries are sorted by ``len(entry['seq'])`` and packed greedily so that
    ``longest_length_in_batch * entries_in_batch <= batch_size``; i.e.
    ``batch_size`` is a budget of padded residues, not a count of entries.

    Args:
        dataset: indexable collection of dicts with a ``'seq'`` key.
        batch_size: padded-residue budget per batch.
        shuffle: when True (default), the batch order is reshuffled with
            ``np.random.shuffle`` at the start of every iteration.
        collate_fn, drop_last: accepted for a DataLoader-like signature;
            currently unused.

    An entry whose own length exceeds ``batch_size`` cannot fit in any batch
    and is skipped; no empty batches are produced.
    """

    def __init__(self, dataset, batch_size=100, shuffle=True,
                 collate_fn=lambda x: x, drop_last=False):
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[i]['seq']) for i in range(self.size)]
        self.batch_size = batch_size
        self.shuffle = shuffle
        sorted_ix = np.argsort(self.lengths)

        # Cluster into batches of similar sizes. Indices are visited in
        # ascending length order, so `size` is the longest entry in `batch`.
        clusters, batch = [], []
        for ix in sorted_ix:
            size = self.lengths[ix]
            if size * (len(batch) + 1) <= self.batch_size:
                batch.append(ix)
                continue
            # The current batch is full: flush it (never an empty one) and
            # start the next batch with this entry instead of dropping it.
            if batch:
                clusters.append(batch)
            batch = [ix] if size <= self.batch_size else []
        if batch:
            clusters.append(batch)
        self.clusters = clusters

    def __len__(self):
        return len(self.clusters)

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.clusters)
        for b_idx in self.clusters:
            yield [self.dataset[i] for i in b_idx]
def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG from fresh OS entropy.

    Presumably passed as a DataLoader ``worker_init_fn`` (the name and
    one-argument signature match that hook). ``worker_id`` is required by
    that signature but not used.
    """
    del worker_id  # required by the hook signature; not needed here
    np.random.seed(seed=None)
class NoamOpt:
    """Optimizer wrapper implementing the "Noam" learning-rate schedule.

    The rate at step ``s`` is
    ``factor * model_size**-0.5 * min(s**-0.5, s * warmup**-1.5)``,
    which grows linearly up to ``s == warmup`` and then decays as ``s**-0.5``.
    Each ``step()`` advances the counter, writes the new rate into the ``lr``
    of every param group of the wrapped optimizer, then steps it.

    Args:
        model_size: model width used for the ``model_size**-0.5`` scaling.
        factor: overall multiplier on the rate.
        warmup: number of linear warm-up steps.
        optimizer: wrapped optimizer exposing ``param_groups``, ``step()``
            and ``zero_grad()``.
        step: initial step count (e.g. when resuming training).
    """

    def __init__(self, model_size, factor, warmup, optimizer, step):
        self.optimizer = optimizer
        self._step = step
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return the wrapped optimizer's ``param_groups``."""
        return self.optimizer.param_groups

    def step(self):
        """Advance one step: set every group's ``lr``, then step the optimizer."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the scheduled learning rate at ``step`` (default: current step).

        Step 0 returns 0.0, the limit of the linear warm-up; evaluating the
        formula directly would raise ``ZeroDivisionError`` from ``0 ** -0.5``.
        """
        if step is None:
            step = self._step
        if step == 0:
            return 0.0
        return self.factor * (
            self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )

    def zero_grad(self):
        """Clear gradients of the wrapped optimizer."""
        self.optimizer.zero_grad()
def get_std_opt(parameters, d_model, step):
    """Build the default Noam-scheduled Adam optimizer.

    Adam is created with ``lr=0`` (``NoamOpt.step`` overwrites every group's
    ``lr`` before each update), ``betas=(0.9, 0.98)`` and ``eps=1e-9``; the
    schedule uses factor 2 and 4000 warm-up steps.

    Args:
        parameters: iterable of parameters to optimize.
        d_model: model width passed to the schedule.
        step: starting step count (e.g. when resuming).
    """
    adam = torch.optim.Adam(parameters, lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, 2, 4000, adam, step)
def get_pdbs(data_loader, repeat=1, max_length=10000, num_units=1000000):
    """Convert collated assembly samples into per-chain structure dicts.

    Every item yielded by ``data_loader`` is a dict whose values carry a
    leading batch dimension; only element ``[0]`` is used (presumably a
    DataLoader with batch size 1 over ``loader_pdb`` outputs). After
    un-batching, the keys read are ``'seq'`` (string), ``'xyz'`` (per-residue
    atom coordinates; atoms 0-3 are stored as N, CA, C, O), ``'idx'``
    (per-residue chain counter), ``'masked'`` (chain counters to mask) and
    ``'label'``. Items without ``'label'`` are skipped.

    For each chain, a terminal His tag is trimmed (see ``_his_tag_trim``);
    chains shorter than four residues afterwards are dropped. Each kept chain
    gets ``'seq_chain_<L>'`` and ``'coords_chain_<L>'`` entries, where ``<L>``
    is the chain counter's symbol in a fixed 352-symbol alphabet (52 letters,
    then '0'..'299'). Samples with 352 or more chains are skipped.

    Args:
        data_loader: iterable of collated sample dicts.
        repeat: number of passes over ``data_loader``.
        max_length: samples whose concatenated sequence is longer are dropped.
        num_units: return as soon as this many samples have been collected.

    Returns:
        list of dicts with ``'name'``, ``'seq'``, ``'masked_list'``,
        ``'visible_list'``, ``'num_of_chains'`` and the per-chain keys above.
    """
    init_alphabet = list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
    extra_alphabet = [str(item) for item in range(300)]
    chain_alphabet = init_alphabet + extra_alphabet
    pdb_dict_list = []
    for _ in range(repeat):
        for t in data_loader:
            # Drop the leading batch dimension.
            t = {k: v[0] for k, v in t.items()}
            if 'label' not in t:
                continue
            chain_ids = np.unique(np.asarray(t['idx']))
            # Each chain counter needs its own symbol in chain_alphabet.
            if len(chain_ids) < len(chain_alphabet):
                my_dict = _build_chain_dict(t, chain_ids, chain_alphabet)
                if len(my_dict['seq']) <= max_length:
                    pdb_dict_list.append(my_dict)
            if len(pdb_dict_list) >= num_units:
                # Return rather than break, so later `repeat` passes cannot
                # push the result past num_units.
                return pdb_dict_list
    return pdb_dict_list


def _his_tag_trim(sequence, tag="HHHHHH", max_offset=4):
    """Return ``(n_trim, c_trim)``: residues to cut from each end for a His tag.

    A six-residue His run starting ``k`` residues (0 <= k <= 4) into the chain
    trims ``6 + k`` residues from the N-terminus; a run ending ``k`` residues
    before the C-terminus trims ``6 + k`` from that end. Only the largest
    matching trim per end is used, so a long tag such as His10 is removed once
    instead of every overlapping window adding its own trim (which previously
    could cut up to 40 residues from one end).
    """
    width = len(tag)
    n = len(sequence)
    n_trim = max((width + k for k in range(max_offset + 1)
                  if sequence[k:k + width] == tag), default=0)
    c_trim = max((width + k for k in range(max_offset + 1)
                  if n >= width + k and sequence[n - width - k:n - k] == tag),
                 default=0)
    return n_trim, c_trim


def _build_chain_dict(t, chain_ids, chain_alphabet):
    """Assemble the per-chain sequence/backbone dict for one un-batched sample."""
    seq_arr = np.array(list(t['seq']))
    idx_arr = np.asarray(t['idx'])
    xyz_arr = np.asarray(t['xyz'])
    masked_ids = set(np.asarray(t['masked']).reshape(-1).tolist())
    my_dict = {}
    concat_seq = ''
    mask_list = []
    visible_list = []
    for chain_id in chain_ids:
        letter = chain_alphabet[chain_id]
        # 1-D residue positions of this chain. np.argwhere returns shape
        # (n, 1) for a 1-D NumPy array, which (1, n)-style slicing mishandles;
        # flatnonzero gives a plain 1-D index regardless of input type.
        pos = np.flatnonzero(idx_arr == chain_id)
        n_trim, c_trim = _his_tag_trim("".join(seq_arr[pos]))
        pos = pos[n_trim:len(pos) - c_trim]
        if len(pos) < 4:
            continue  # too short to keep after trimming
        chain_seq = "".join(seq_arr[pos])
        my_dict['seq_chain_' + letter] = chain_seq
        concat_seq += chain_seq
        if int(chain_id) in masked_ids:
            mask_list.append(letter)
        else:
            visible_list.append(letter)
        all_atoms = xyz_arr[pos]  # [L, 14, 3]
        my_dict['coords_chain_' + letter] = {
            'N_chain_' + letter: all_atoms[:, 0, :].tolist(),
            'CA_chain_' + letter: all_atoms[:, 1, :].tolist(),
            'C_chain_' + letter: all_atoms[:, 2, :].tolist(),
            'O_chain_' + letter: all_atoms[:, 3, :].tolist(),
        }
    my_dict['name'] = t['label']
    my_dict['masked_list'] = mask_list
    my_dict['visible_list'] = visible_list
    my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
    my_dict['seq'] = concat_seq
    return my_dict
class PDB_dataset(torch.utils.data.Dataset):
    """Map-style dataset that draws one random member per cluster key.

    Args:
        IDs: sequence of cluster keys; ``len(IDs)`` is the dataset length.
        loader: callable ``loader(member, params)`` that builds the sample.
        train_dict: mapping from cluster key to a list of members.
        params: forwarded unchanged to ``loader``.
    """

    def __init__(self, IDs, loader, train_dict, params):
        self.IDs = IDs
        self.train_dict = train_dict
        self.loader = loader
        self.params = params

    def __len__(self):
        return len(self.IDs)

    def __getitem__(self, index):
        members = self.train_dict[self.IDs[index]]
        # Uniform pick using NumPy's global RNG.
        chosen = members[np.random.randint(0, len(members))]
        return self.loader(chosen, self.params)
def loader_pdb(item, params):
    """Load one chain together with a randomly chosen assembly containing it.

    ``item[0]`` is a ``'<pdbid>_<chain>'`` label. Metadata is read from
    ``<DIR>/pdb/<pdbid[1:3]>/<pdbid>.pt`` and per-chain data (``'seq'``,
    ``'xyz'``) from ``<same prefix>_<chain>.pt``.

    Returns:
        dict with ``'seq'`` (concatenated sequence of every placed chain copy),
        ``'xyz'`` (their coordinates concatenated along dim 0), ``'idx'`` (one
        counter per placed copy, repeated per residue), ``'masked'`` (counters
        of copies whose chain has ``meta['tm'][..., 1] > params['HOMO']``
        against the query chain) and ``'label'``. If the chain is in no
        assembly it is returned alone with ``idx`` all zero. A placeholder
        ``{'seq': np.zeros(5)}`` is returned when the metadata file is missing,
        a loaded chain has no ``'xyz'``, or the assembly is empty.
    """
    pdbid, chid = item[0].split('_')
    PREFIX = "%s/pdb/%s/%s" % (params['DIR'], pdbid[1:3], pdbid)

    # load metadata
    if not os.path.isfile(PREFIX + ".pt"):
        return {'seq': np.zeros(5)}
    meta = torch.load(PREFIX + ".pt")
    asmb_ids = meta['asmb_ids']
    asmb_chains = meta['asmb_chains']
    chids = np.array(meta['chains'])
    known_chains = set(meta['chains'])

    def chains_of(k):
        # Chains listed for transform k that exist in this entry, deduplicated
        # in listed order. The list is comma-separated and must be split:
        # iterating the raw string yields single characters, which loses
        # multi-character chain IDs.
        return [c for c in dict.fromkeys(asmb_chains[k].split(','))
                if c in known_chains]

    # find candidate assemblies which contain the chid chain
    asmb_candidates = set(a for a, b in zip(asmb_ids, asmb_chains)
                          if chid in b.split(','))

    # if the chain is missing from all the assemblies, return this chain alone
    if len(asmb_candidates) < 1:
        chain = torch.load("%s_%s.pt" % (PREFIX, chid))
        L = len(chain['seq'])
        return {'seq': chain['seq'],
                'xyz': chain['xyz'],
                'idx': torch.zeros(L).int(),
                'masked': torch.Tensor([0]).int(),
                'label': item[0]}

    # randomly pick one assembly; sorting first makes the pick depend only on
    # the `random` state, not on set iteration order
    asmb_i = random.sample(sorted(asmb_candidates), 1)[0]

    # indices of the transforms that belong to the selected assembly
    idx = np.where(np.array(asmb_ids) == asmb_i)[0]

    # load each relevant chain file exactly once
    chains = {}
    for k in idx:
        for c in chains_of(k):
            if c not in chains:
                chains[c] = torch.load("%s_%s.pt" % (PREFIX, c))

    # generate assembly: one entry per (chain, transform index, copy)
    asmb = {}
    for k in idx:
        xform = meta['asmb_xform%d' % k]
        u = xform[:, :3, :3]
        r = xform[:, :3, 3]
        for c in chains_of(k):
            try:
                chain_xyz = chains[c]['xyz']
            except KeyError:
                return {'seq': np.zeros(5)}
            xyz_ru = torch.einsum('bij,raj->brai', u, chain_xyz) + r[:, None, None, :]
            asmb.update({(c, k, i): xyz_i for i, xyz_i in enumerate(xyz_ru)})
    if not asmb:
        # nothing to stack; torch.cat would fail on an empty list
        return {'seq': np.zeros(5)}

    # select chains which share considerable similarity to chid
    seqid = meta['tm'][chids == chid][0, :, 1]
    homo = set(ch_j for seqid_j, ch_j in zip(seqid, chids)
               if seqid_j > params['HOMO'])

    # stack all chain copies of the assembly together
    seq, xyz, chain_idx, masked = "", [], [], []
    for counter, (k, v) in enumerate(asmb.items()):
        seq += chains[k[0]]['seq']
        xyz.append(v)
        chain_idx.append(torch.full((v.shape[0],), counter))
        if k[0] in homo:
            masked.append(counter)
    return {'seq': seq,
            'xyz': torch.cat(xyz, dim=0),
            'idx': torch.cat(chain_idx, dim=0),
            'masked': torch.Tensor(masked).int(),
            'label': item[0]}
def build_training_clusters(params, debug):
    """Split the chain list into train / validation / test cluster dicts.

    Reads the CSV at ``params['LIST']`` (header row skipped; columns used:
    0 chain label, 1 date, 2 resolution, 3 hash, 4 integer cluster id) and
    keeps rows with resolution <= ``params['RESCUT']`` and date <=
    ``params['DATCUT']``. Cluster ids listed one per line in ``params['VAL']``
    / ``params['TEST']`` go to ``valid`` / ``test``; all others go to
    ``train``. Each dict maps cluster id -> list of ``[label, hash]`` pairs.

    With ``debug`` set, the VAL/TEST files are not read, only the first 20
    filtered rows are used, every row lands in ``train``, and ``valid`` is the
    same dict object as ``train``.

    Returns:
        ``(train, valid, test)``
    """
    def read_ids(path):
        # One integer cluster id per line; blank lines are ignored.
        with open(path) as fh:
            return {int(line) for line in fh if line.strip()}

    if debug:
        val_ids, test_ids = set(), set()
    else:
        val_ids = read_ids(params['VAL'])
        test_ids = read_ids(params['TEST'])

    # read & clean list.csv
    date_cutoff = parser.parse(params['DATCUT'])
    with open(params['LIST'], 'r') as f:
        reader = csv.reader(f)
        next(reader)  # header
        rows = [[r[0], r[3], int(r[4])] for r in reader
                if float(r[2]) <= params['RESCUT'] and
                parser.parse(r[1]) <= date_cutoff]

    if debug:
        rows = rows[:20]

    # compile training, validation and test sets
    train, valid, test = {}, {}, {}
    for r in rows:
        if r[2] in val_ids:
            target = valid
        elif r[2] in test_ids:
            target = test
        else:
            target = train
        target.setdefault(r[2], []).append(r[:2])

    if debug:
        valid = train
    return train, valid, test