import torch
from torch.utils.data import DataLoader
import csv
from dateutil import parser
import numpy as np
import time
import random
import os
class StructureDataset():
    """Dataset of parsed PDB entries; discards entries containing residues outside
    `alphabet` or sequences longer than `max_length`."""
    def __init__(self, pdb_dict_list, verbose=True, truncate=None, max_length=100,
                 alphabet='ACDEFGHIKLMNPQRSTVWYX'):
        alphabet_set = set([a for a in alphabet])
        discard_count = {
            'bad_chars': 0,
            'too_long': 0,
            'bad_seq_length': 0
        }

        self.data = []

        start = time.time()
        for i, entry in enumerate(pdb_dict_list):
            seq = entry['seq']
            name = entry['name']

            bad_chars = set([s for s in seq]).difference(alphabet_set)
            if len(bad_chars) == 0:
                if len(entry['seq']) <= max_length:
                    self.data.append(entry)
                else:
                    discard_count['too_long'] += 1
            else:
                #print(name, bad_chars, entry['seq'])
                discard_count['bad_chars'] += 1

            # Truncate early
            if truncate is not None and len(self.data) == truncate:
                return

            if verbose and (i + 1) % 1000 == 0:
                elapsed = time.time() - start
                #print('{} entries ({} loaded) in {:.1f} s'.format(len(self.data), i+1, elapsed))
        #print('Discarded', discard_count)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
class StructureLoader():
    """Groups dataset entries of similar length into batches; `batch_size` is the
    maximum total number of residues per batch, not the number of entries."""
    def __init__(self, dataset, batch_size=100, shuffle=True,
                 collate_fn=lambda x: x, drop_last=False):
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[i]['seq']) for i in range(self.size)]
        self.batch_size = batch_size
        sorted_ix = np.argsort(self.lengths)

        # Cluster into batches of similar sizes
        clusters, batch = [], []
        batch_max = 0
        for ix in sorted_ix:
            size = self.lengths[ix]
            if size * (len(batch) + 1) <= self.batch_size:
                batch.append(ix)
                batch_max = size
            else:
                clusters.append(batch)
                # start the next batch with the entry that did not fit,
                # so that no entry is silently dropped
                batch, batch_max = [ix], size
        if len(batch) > 0:
            clusters.append(batch)
        self.clusters = clusters

    def __len__(self):
        return len(self.clusters)

    def __iter__(self):
        np.random.shuffle(self.clusters)
        for b_idx in self.clusters:
            batch = [self.dataset[i] for i in b_idx]
            yield batch
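
# Illustrative usage sketch (not part of the module's API): wrap parsed PDB
# dicts in StructureDataset and batch them by total residue count with
# StructureLoader. `pdb_dict_list` is assumed to come from get_pdbs() below;
# the max_length/batch_size values are placeholders, not recommended settings.
#
#   dataset = StructureDataset(pdb_dict_list, truncate=None, max_length=10000)
#   loader = StructureLoader(dataset, batch_size=10000)
#   for batch in loader:
#       # each `batch` is a list of entry dicts with similar sequence lengths
#       ...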
def worker_init_fn(worker_id):
    np.random.seed()
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer, step):
        self.optimizer = optimizer
        self._step = step
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Noam learning-rate schedule: linear warmup followed by inverse square-root decay."
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
             min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def zero_grad(self):
        self.optimizer.zero_grad()

def get_std_opt(parameters, d_model, step):
    return NoamOpt(
        d_model, 2, 4000, torch.optim.Adam(parameters, lr=0, betas=(0.9, 0.98), eps=1e-9), step
    )
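
# Illustrative usage sketch: driving a training step with the Noam schedule
# wrapper above. `model` and `compute_loss` are hypothetical placeholders for
# whatever model and objective the surrounding training script defines.
#
#   optimizer = get_std_opt(model.parameters(), d_model=128, step=0)
#   for batch in loader:
#       optimizer.zero_grad()
#       loss = compute_loss(model, batch)   # hypothetical helper
#       loss.backward()
#       optimizer.step()                    # sets the lr for this step, then updates parameters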
def get_pdbs(data_loader, repeat=1, max_length=10000, num_units=1000000):
    init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    extra_alphabet = [str(item) for item in list(np.arange(300))]
    chain_alphabet = init_alphabet + extra_alphabet
    c = 0
    c1 = 0
    pdb_dict_list = []
    t0 = time.time()
    for _ in range(repeat):
        for step, t in enumerate(data_loader):
            t = {k: v[0] for k, v in t.items()}
            c1 += 1
            if 'label' in list(t):
                my_dict = {}
                s = 0
                concat_seq = ''
                concat_N = []
                concat_CA = []
                concat_C = []
                concat_O = []
                concat_mask = []
                coords_dict = {}
                mask_list = []
                visible_list = []
                if len(list(np.unique(t['idx']))) < 352:
                    for idx in list(np.unique(t['idx'])):
                        letter = chain_alphabet[idx]
                        # res has shape [1, L_chain]; its columns are the residue indices of this chain
                        res = np.argwhere(t['idx'] == idx)
                        initial_sequence = "".join(list(np.array(list(t['seq']))[res][0,]))
                        # trim polyhistidine-tag-like runs ("HHHHHH") near either end of the chain
                        if initial_sequence[-6:] == "HHHHHH":
                            res = res[:, :-6]
                        if initial_sequence[0:6] == "HHHHHH":
                            res = res[:, 6:]
                        if initial_sequence[-7:-1] == "HHHHHH":
                            res = res[:, :-7]
                        if initial_sequence[-8:-2] == "HHHHHH":
                            res = res[:, :-8]
                        if initial_sequence[-9:-3] == "HHHHHH":
                            res = res[:, :-9]
                        if initial_sequence[-10:-4] == "HHHHHH":
                            res = res[:, :-10]
                        if initial_sequence[1:7] == "HHHHHH":
                            res = res[:, 7:]
                        if initial_sequence[2:8] == "HHHHHH":
                            res = res[:, 8:]
                        if initial_sequence[3:9] == "HHHHHH":
                            res = res[:, 9:]
                        if initial_sequence[4:10] == "HHHHHH":
                            res = res[:, 10:]
                        if res.shape[1] < 4:
                            pass
                        else:
                            my_dict['seq_chain_' + letter] = "".join(list(np.array(list(t['seq']))[res][0,]))
                            concat_seq += my_dict['seq_chain_' + letter]
                            if idx in t['masked']:
                                mask_list.append(letter)
                            else:
                                visible_list.append(letter)
                            coords_dict_chain = {}
                            all_atoms = np.array(t['xyz'][res,])[0,]  # [L, 14, 3]
                            coords_dict_chain['N_chain_' + letter] = all_atoms[:, 0, :].tolist()
                            coords_dict_chain['CA_chain_' + letter] = all_atoms[:, 1, :].tolist()
                            coords_dict_chain['C_chain_' + letter] = all_atoms[:, 2, :].tolist()
                            coords_dict_chain['O_chain_' + letter] = all_atoms[:, 3, :].tolist()
                            my_dict['coords_chain_' + letter] = coords_dict_chain
                    my_dict['name'] = t['label']
                    my_dict['masked_list'] = mask_list
                    my_dict['visible_list'] = visible_list
                    my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
                    my_dict['seq'] = concat_seq
                    if len(concat_seq) <= max_length:
                        pdb_dict_list.append(my_dict)
                    if len(pdb_dict_list) >= num_units:
                        break
    return pdb_dict_list
class PDB_dataset(torch.utils.data.Dataset):
    def __init__(self, IDs, loader, train_dict, params):
        self.IDs = IDs
        self.train_dict = train_dict
        self.loader = loader
        self.params = params

    def __len__(self):
        return len(self.IDs)

    def __getitem__(self, index):
        ID = self.IDs[index]
        # sample one chain at random from cluster ID
        sel_idx = np.random.randint(0, len(self.train_dict[ID]))
        out = self.loader(self.train_dict[ID][sel_idx], self.params)
        return out
def loader_pdb(item, params):
    pdbid, chid = item[0].split('_')
    PREFIX = "%s/pdb/%s/%s" % (params['DIR'], pdbid[1:3], pdbid)

    # load metadata
    if not os.path.isfile(PREFIX + ".pt"):
        return {'seq': np.zeros(5)}
    meta = torch.load(PREFIX + ".pt")
    asmb_ids = meta['asmb_ids']
    asmb_chains = meta['asmb_chains']
    chids = np.array(meta['chains'])

    # find candidate assemblies which contain the chid chain
    asmb_candidates = set([a for a, b in zip(asmb_ids, asmb_chains)
                           if chid in b.split(',')])

    # if the chain is missing from all the assemblies,
    # then return this chain alone
    if len(asmb_candidates) < 1:
        chain = torch.load("%s_%s.pt" % (PREFIX, chid))
        L = len(chain['seq'])
        return {'seq': chain['seq'],
                'xyz': chain['xyz'],
                'idx': torch.zeros(L).int(),
                'masked': torch.Tensor([0]).int(),
                'label': item[0]}

    # randomly pick one assembly from candidates
    asmb_i = random.sample(list(asmb_candidates), 1)

    # indices of selected transforms
    idx = np.where(np.array(asmb_ids) == asmb_i)[0]

    # load relevant chains
    chains = {c: torch.load("%s_%s.pt" % (PREFIX, c))
              for i in idx for c in asmb_chains[i]
              if c in meta['chains']}

    # generate assembly
    asmb = {}
    for k in idx:
        # pick k-th xform
        xform = meta['asmb_xform%d' % k]
        u = xform[:, :3, :3]
        r = xform[:, :3, 3]

        # select chains which the k-th xform should be applied to
        s1 = set(meta['chains'])
        s2 = set(asmb_chains[k].split(','))
        chains_k = s1 & s2

        # transform selected chains
        for c in chains_k:
            try:
                xyz = chains[c]['xyz']
                xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None, None, :]
                asmb.update({(c, k, i): xyz_i for i, xyz_i in enumerate(xyz_ru)})
            except KeyError:
                return {'seq': np.zeros(5)}

    # select chains which share considerable similarity to chid
    seqid = meta['tm'][chids == chid][0, :, 1]
    homo = set([ch_j for seqid_j, ch_j in zip(seqid, chids)
                if seqid_j > params['HOMO']])

    # stack all chains in the assembly together
    seq, xyz, idx, masked = "", [], [], []
    seq_list = []
    for counter, (k, v) in enumerate(asmb.items()):
        seq += chains[k[0]]['seq']
        seq_list.append(chains[k[0]]['seq'])
        xyz.append(v)
        idx.append(torch.full((v.shape[0],), counter))
        if k[0] in homo:
            masked.append(counter)

    return {'seq': seq,
            'xyz': torch.cat(xyz, dim=0),
            'idx': torch.cat(idx, dim=0),
            'masked': torch.Tensor(masked).int(),
            'label': item[0]}
def build_training_clusters(params, debug):
    val_ids = set([int(l) for l in open(params['VAL']).readlines()])
    test_ids = set([int(l) for l in open(params['TEST']).readlines()])
    if debug:
        val_ids = []
        test_ids = []

    # read & clean list.csv: keep [chain_id, r[3], cluster_id] for rows passing
    # the resolution (RESCUT) and date (DATCUT) cutoffs
    with open(params['LIST'], 'r') as f:
        reader = csv.reader(f)
        next(reader)
        rows = [[r[0], r[3], int(r[4])] for r in reader
                if float(r[2]) <= params['RESCUT'] and
                parser.parse(r[1]) <= parser.parse(params['DATCUT'])]

    # compile training, validation and test sets, keyed by cluster id
    train = {}
    valid = {}
    test = {}
    if debug:
        rows = rows[:20]
    for r in rows:
        if r[2] in val_ids:
            if r[2] in valid.keys():
                valid[r[2]].append(r[:2])
            else:
                valid[r[2]] = [r[:2]]
        elif r[2] in test_ids:
            if r[2] in test.keys():
                test[r[2]].append(r[:2])
            else:
                test[r[2]] = [r[:2]]
        else:
            if r[2] in train.keys():
                train[r[2]].append(r[:2])
            else:
                train[r[2]] = [r[:2]]
    if debug:
        valid = train
    return train, valid, test
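
# Illustrative end-to-end wiring sketch. Assumptions: the params keys below follow
# their usage in build_training_clusters() and loader_pdb(); all paths, cutoffs,
# and sizes are placeholders only.
#
#   params = {
#       'LIST'  : '/path/to/list.csv',            # chain/cluster metadata
#       'VAL'   : '/path/to/valid_clusters.txt',  # validation cluster ids
#       'TEST'  : '/path/to/test_clusters.txt',   # test cluster ids
#       'DIR'   : '/path/to/pdb_dataset',         # root of the parsed *.pt files
#       'DATCUT': '2030-01-01',                   # date cutoff (parsed with dateutil)
#       'RESCUT': 3.5,                            # resolution cutoff
#       'HOMO'  : 0.70,                           # sequence-identity threshold for masking homologous chains
#   }
#   train, valid, test = build_training_clusters(params, debug=False)
#   train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
#   train_loader = DataLoader(train_set, batch_size=1, worker_init_fn=worker_init_fn)
#   pdb_dict_list = get_pdbs(train_loader, max_length=10000, num_units=1000)
#   dataset = StructureDataset(pdb_dict_list, max_length=10000)
#   loader = StructureLoader(dataset, batch_size=10000)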