import torch
from torch.utils.data import DataLoader
import csv
from dateutil import parser
import numpy as np
import time
import random
import os


class StructureDataset():
    def __init__(self, pdb_dict_list, verbose=True, truncate=None, max_length=100,
                 alphabet='ACDEFGHIKLMNPQRSTVWYX'):
        alphabet_set = set([a for a in alphabet])
        discard_count = {
            'bad_chars': 0,
            'too_long': 0,
            'bad_seq_length': 0
        }

        self.data = []

        start = time.time()
        for i, entry in enumerate(pdb_dict_list):
            seq = entry['seq']
            name = entry['name']

            # Keep only entries whose sequence uses the allowed alphabet and
            # fits within max_length residues.
            bad_chars = set([s for s in seq]).difference(alphabet_set)
            if len(bad_chars) == 0:
                if len(entry['seq']) <= max_length:
                    self.data.append(entry)
                else:
                    discard_count['too_long'] += 1
            else:
                #print(name, bad_chars, entry['seq'])
                discard_count['bad_chars'] += 1

            # Truncate early
            if truncate is not None and len(self.data) == truncate:
                return

            if verbose and (i + 1) % 1000 == 0:
                elapsed = time.time() - start
                #print('{} entries ({} loaded) in {:.1f} s'.format(len(self.data), i+1, elapsed))

        #print('Discarded', discard_count)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class StructureLoader():
    def __init__(self, dataset, batch_size=100, shuffle=True,
                 collate_fn=lambda x: x, drop_last=False):
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[i]['seq']) for i in range(self.size)]
        self.batch_size = batch_size
        sorted_ix = np.argsort(self.lengths)

        # Cluster into batches of similar sizes; batch_size is a budget in
        # residues (tokens), not in number of structures.
        clusters, batch = [], []
        batch_max = 0
        for ix in sorted_ix:
            size = self.lengths[ix]
            if size * (len(batch) + 1) <= self.batch_size:
                batch.append(ix)
                batch_max = size
            else:
                clusters.append(batch)
                # Start the next batch with the current entry so it is not silently dropped.
                batch, batch_max = [ix], size
        if len(batch) > 0:
            clusters.append(batch)
        self.clusters = clusters

    def __len__(self):
        return len(self.clusters)

    def __iter__(self):
        np.random.shuffle(self.clusters)
        for b_idx in self.clusters:
            batch = [self.dataset[i] for i in b_idx]
            yield batch


def worker_init_fn(worker_id):
    np.random.seed()
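
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline): how StructureDataset
# and StructureLoader fit together. The two toy entries and the function name
# `_example_structure_loader` are hypothetical; real entries come from
# get_pdbs() further down in this file.
def _example_structure_loader():
    pdb_dict_list = [
        {'seq': 'ACDEFGHIK', 'name': 'toy_1'},
        {'seq': 'MNPQRSTVWY', 'name': 'toy_2'},
    ]
    dataset = StructureDataset(pdb_dict_list, verbose=False, max_length=100)
    # batch_size is a residue budget, so both toy entries land in one batch.
    loader = StructureLoader(dataset, batch_size=2000)
    for batch in loader:
        print(len(batch), [entry['name'] for entry in batch])
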
class NoamOpt:
    "Optim wrapper that implements the Noam learning-rate schedule."

    def __init__(self, model_size, factor, warmup, optimizer, step):
        self.optimizer = optimizer
        self._step = step
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Linear warmup for `warmup` steps, then decay proportional to step**-0.5"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def zero_grad(self):
        self.optimizer.zero_grad()


def get_std_opt(parameters, d_model, step):
    return NoamOpt(
        d_model, 2, 4000,
        torch.optim.Adam(parameters, lr=0, betas=(0.9, 0.98), eps=1e-9),
        step
    )


def get_pdbs(data_loader, repeat=1, max_length=10000, num_units=1000000):
    init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
                     'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
                     'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
                     'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    extra_alphabet = [str(item) for item in list(np.arange(300))]
    chain_alphabet = init_alphabet + extra_alphabet
    c = 0
    c1 = 0
    pdb_dict_list = []
    t0 = time.time()
    for _ in range(repeat):
        for step, t in enumerate(data_loader):
            t = {k: v[0] for k, v in t.items()}
            c1 += 1
            if 'label' in list(t):
                my_dict = {}
                s = 0
                concat_seq = ''
                concat_N = []
                concat_CA = []
                concat_C = []
                concat_O = []
                concat_mask = []
                coords_dict = {}
                mask_list = []
                visible_list = []
                if len(list(np.unique(t['idx']))) < 352:
                    for idx in list(np.unique(t['idx'])):
                        letter = chain_alphabet[idx]
                        res = np.argwhere(t['idx'] == idx)
                        initial_sequence = "".join(list(np.array(list(t['seq']))[res][0,]))
                        # Trim His-tags (HHHHHH) found near either end of the chain.
                        if initial_sequence[-6:] == "HHHHHH":
                            res = res[:, :-6]
                        if initial_sequence[0:6] == "HHHHHH":
                            res = res[:, 6:]
                        if initial_sequence[-7:-1] == "HHHHHH":
                            res = res[:, :-7]
                        if initial_sequence[-8:-2] == "HHHHHH":
                            res = res[:, :-8]
                        if initial_sequence[-9:-3] == "HHHHHH":
                            res = res[:, :-9]
                        if initial_sequence[-10:-4] == "HHHHHH":
                            res = res[:, :-10]
                        if initial_sequence[1:7] == "HHHHHH":
                            res = res[:, 7:]
                        if initial_sequence[2:8] == "HHHHHH":
                            res = res[:, 8:]
                        if initial_sequence[3:9] == "HHHHHH":
                            res = res[:, 9:]
                        if initial_sequence[4:10] == "HHHHHH":
                            res = res[:, 10:]
                        if res.shape[1] < 4:
                            # Skip chains left with fewer than 4 residues after trimming.
                            pass
                        else:
                            my_dict['seq_chain_' + letter] = "".join(list(np.array(list(t['seq']))[res][0,]))
                            concat_seq += my_dict['seq_chain_' + letter]
                            if idx in t['masked']:
                                mask_list.append(letter)
                            else:
                                visible_list.append(letter)
                            coords_dict_chain = {}
                            all_atoms = np.array(t['xyz'][res,])[0,]  # [L, 14, 3]
                            coords_dict_chain['N_chain_' + letter] = all_atoms[:, 0, :].tolist()
                            coords_dict_chain['CA_chain_' + letter] = all_atoms[:, 1, :].tolist()
                            coords_dict_chain['C_chain_' + letter] = all_atoms[:, 2, :].tolist()
                            coords_dict_chain['O_chain_' + letter] = all_atoms[:, 3, :].tolist()
                            my_dict['coords_chain_' + letter] = coords_dict_chain
                    my_dict['name'] = t['label']
                    my_dict['masked_list'] = mask_list
                    my_dict['visible_list'] = visible_list
                    my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
                    my_dict['seq'] = concat_seq
                    if len(concat_seq) <= max_length:
                        pdb_dict_list.append(my_dict)
                    if len(pdb_dict_list) >= num_units:
                        break
    return pdb_dict_list
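
# ---------------------------------------------------------------------------
# Minimal sketch (an assumption, not from the original training script) of how
# get_std_opt() and the NoamOpt schedule above are used. The Linear module and
# `_example_noam_schedule` are hypothetical stand-ins for the real model.
def _example_noam_schedule():
    model = torch.nn.Linear(128, 128)  # hypothetical placeholder model
    optimizer = get_std_opt(model.parameters(), 128, 0)
    # rate() grows roughly linearly over the first 4000 warmup steps,
    # then decays proportionally to step**-0.5.
    return [optimizer.rate(s) for s in (1, 4000, 40000)]
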
class PDB_dataset(torch.utils.data.Dataset):
    def __init__(self, IDs, loader, train_dict, params):
        self.IDs = IDs
        self.train_dict = train_dict
        self.loader = loader
        self.params = params

    def __len__(self):
        return len(self.IDs)

    def __getitem__(self, index):
        ID = self.IDs[index]
        # Sample one chain entry at random from this sequence cluster.
        sel_idx = np.random.randint(0, len(self.train_dict[ID]))
        out = self.loader(self.train_dict[ID][sel_idx], self.params)
        return out


def loader_pdb(item, params):
    pdbid, chid = item[0].split('_')
    PREFIX = "%s/pdb/%s/%s" % (params['DIR'], pdbid[1:3], pdbid)

    # load metadata
    if not os.path.isfile(PREFIX + ".pt"):
        return {'seq': np.zeros(5)}
    meta = torch.load(PREFIX + ".pt")
    asmb_ids = meta['asmb_ids']
    asmb_chains = meta['asmb_chains']
    chids = np.array(meta['chains'])

    # find candidate assemblies which contain chid chain
    asmb_candidates = set([a for a, b in zip(asmb_ids, asmb_chains)
                           if chid in b.split(',')])

    # if the chain is missing from all assemblies, return this chain alone
    if len(asmb_candidates) < 1:
        chain = torch.load("%s_%s.pt" % (PREFIX, chid))
        L = len(chain['seq'])
        return {'seq': chain['seq'],
                'xyz': chain['xyz'],
                'idx': torch.zeros(L).int(),
                'masked': torch.Tensor([0]).int(),
                'label': item[0]}

    # randomly pick one assembly from candidates
    asmb_i = random.sample(list(asmb_candidates), 1)

    # indices of selected transforms
    idx = np.where(np.array(asmb_ids) == asmb_i)[0]

    # load relevant chains
    chains = {c: torch.load("%s_%s.pt" % (PREFIX, c))
              for i in idx for c in asmb_chains[i]
              if c in meta['chains']}

    # generate assembly
    asmb = {}
    for k in idx:
        # pick k-th xform: rotations u and translations r
        xform = meta['asmb_xform%d' % k]
        u = xform[:, :3, :3]
        r = xform[:, :3, 3]

        # select chains which k-th xform should be applied to
        s1 = set(meta['chains'])
        s2 = set(asmb_chains[k].split(','))
        chains_k = s1 & s2

        # transform selected chains
        for c in chains_k:
            try:
                xyz = chains[c]['xyz']
                # apply every transform in the batch to all residues/atoms -> [B, L, 14, 3]
                xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None, None, :]
                asmb.update({(c, k, i): xyz_i for i, xyz_i in enumerate(xyz_ru)})
            except KeyError:
                return {'seq': np.zeros(5)}

    # select chains which share considerable similarity to chid
    seqid = meta['tm'][chids == chid][0, :, 1]
    homo = set([ch_j for seqid_j, ch_j in zip(seqid, chids)
                if seqid_j > params['HOMO']])

    # stack all chains in the assembly together
    seq, xyz, idx, masked = "", [], [], []
    seq_list = []
    for counter, (k, v) in enumerate(asmb.items()):
        seq += chains[k[0]]['seq']
        seq_list.append(chains[k[0]]['seq'])
        xyz.append(v)
        idx.append(torch.full((v.shape[0],), counter))
        if k[0] in homo:
            masked.append(counter)

    return {'seq': seq,
            'xyz': torch.cat(xyz, dim=0),
            'idx': torch.cat(idx, dim=0),
            'masked': torch.Tensor(masked).int(),
            'label': item[0]}


def build_training_clusters(params, debug):
    val_ids = set([int(l) for l in open(params['VAL']).readlines()])
    test_ids = set([int(l) for l in open(params['TEST']).readlines()])
    if debug:
        val_ids = []
        test_ids = []

    # read & clean list.csv
    with open(params['LIST'], 'r') as f:
        reader = csv.reader(f)
        next(reader)
        rows = [[r[0], r[3], int(r[4])] for r in reader
                if float(r[2]) <= params['RESCUT'] and
                parser.parse(r[1]) <= parser.parse(params['DATCUT'])]

    # compile training and validation sets
    train = {}
    valid = {}
    test = {}

    if debug:
        rows = rows[:20]
    for r in rows:
        if r[2] in val_ids:
            if r[2] in valid.keys():
                valid[r[2]].append(r[:2])
            else:
                valid[r[2]] = [r[:2]]
        elif r[2] in test_ids:
            if r[2] in test.keys():
                test[r[2]].append(r[:2])
            else:
                test[r[2]] = [r[:2]]
        else:
            if r[2] in train.keys():
                train[r[2]].append(r[:2])
            else:
                train[r[2]] = [r[:2]]
    if debug:
        valid = train
    return train, valid, test
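

# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch (an assumption, not the original training
# script): the file paths and cutoff values below are hypothetical
# placeholders, and running this requires the pre-processed PDB dataset laid
# out under params['DIR']/pdb as expected by loader_pdb().
if __name__ == "__main__":
    params = {
        'LIST': 'data/list.csv',           # hypothetical path to the chain/cluster list
        'VAL': 'data/valid_clusters.txt',  # hypothetical validation cluster ids
        'TEST': 'data/test_clusters.txt',  # hypothetical test cluster ids
        'DIR': 'data',                     # root containing the pdb/ subfolders
        'DATCUT': '2030-Jan-01',           # deposition date cutoff (parsed by dateutil)
        'RESCUT': 3.5,                     # resolution cutoff
        'HOMO': 0.70,                      # sequence-identity threshold for masking homomers
    }
    train, valid, test = build_training_clusters(params, debug=False)
    train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True,
                              worker_init_fn=worker_init_fn, num_workers=0)
    pdb_dict_list = get_pdbs(train_loader, max_length=10000, num_units=100)
    dataset = StructureDataset(pdb_dict_list, max_length=10000)
    batches = StructureLoader(dataset, batch_size=10000)
    for batch in batches:
        pass  # each batch is a list of parsed PDB dicts ready for featurization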