"""Convert one gzipped mmCIF entry into per-chain tensors and entry metadata.

Usage (as a script)::

    python <script> <input.cif.gz> <output_prefix>

For every L-polypeptide chain a file ``<output_prefix>_<chain>.pt`` is
written, and the entry-level metadata (including the pairwise TM-align
summary and biological assembly transforms) goes to ``<output_prefix>.pt``.
"""

import glob
import gzip
import os
import re
import subprocess
import sys
import tempfile
from itertools import combinations, permutations

import numpy as np
import torch
from scipy.spatial import KDTree

import pdbx
from pdbx.reader.PdbxContainers import DataCategory
from pdbx.reader.PdbxReader import PdbxReader


# Three-letter residue names; the order matches RES_NAMES_1 and ATOM_NAMES.
RES_NAMES = [
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS',
    'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
    'LEU', 'LYS', 'MET', 'PHE', 'PRO',
    'SER', 'THR', 'TRP', 'TYR', 'VAL'
]

# One-letter residue codes, same order as RES_NAMES.
RES_NAMES_1 = 'ARNDCQEGHILKMFPSTWYV'

to1letter = {aaa: a for a, aaa in zip(RES_NAMES_1, RES_NAMES)}
to3letter = {a: aaa for a, aaa in zip(RES_NAMES_1, RES_NAMES)}

# Heavy-atom names per residue type; the position inside each tuple is the
# atom slot used along the second axis of the per-chain arrays.
ATOM_NAMES = [
    ("N", "CA", "C", "O", "CB"),  # ala
    ("N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"),  # arg
    ("N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"),  # asn
    ("N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"),  # asp
    ("N", "CA", "C", "O", "CB", "SG"),  # cys
    ("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"),  # gln
    ("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"),  # glu
    ("N", "CA", "C", "O"),  # gly
    ("N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"),  # his
    ("N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"),  # ile
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"),  # leu
    ("N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"),  # lys
    ("N", "CA", "C", "O", "CB", "CG", "SD", "CE"),  # met
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"),  # phe
    ("N", "CA", "C", "O", "CB", "CG", "CD"),  # pro
    ("N", "CA", "C", "O", "CB", "OG"),  # ser
    ("N", "CA", "C", "O", "CB", "OG1", "CG2"),  # thr
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"),  # trp
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"),  # tyr
    ("N", "CA", "C", "O", "CB", "CG1", "CG2")  # val
]

# (one-letter code, atom slot) -> (three-letter name, atom name)
idx2ra = {(RES_NAMES_1[i], j): (RES_NAMES[i], a)
          for i in range(20) for j, a in enumerate(ATOM_NAMES[i])}

# (three-letter name, atom name) -> atom slot
aa2idx = {(r, a): i
          for r, atoms in zip(RES_NAMES, ATOM_NAMES)
          for i, a in enumerate(atoms)}
# The C-terminal OXT shares slot 3 with the carbonyl O.
aa2idx.update({(r, 'OXT'): 3 for r in RES_NAMES})

# Largest number of atom slots any residue needs (14, for TRP).
_MAX_ATOMS = max(len(names) for names in ATOM_NAMES)

# TM-align executable. The default is the path this script was written
# against; set the TMALIGN_BIN environment variable to point elsewhere.
TMALIGN_BIN = os.environ.get('TMALIGN_BIN', '/home/aivan/prog/TMalign')

# Preferred directory for the scratch PDB files handed to TM-align.
_SHM_DIR = '/dev/shm'

# PDB ATOM record. The four blanks after the residue number cover the
# insertion-code column and the three unused columns, so that x/y/z land in
# columns 31-54 as the fixed-column PDB format requires.
_PDB_ATOM_FMT = "%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n"


def writepdb(f, xyz, seq, bfac=None):
    """Write a chain to an open file as PDB ATOM records.

    The file is rewound with ``f.seek(0)`` before writing and flushed
    afterwards; it is neither truncated nor closed.  Every record uses chain
    ID "A", residue number ``i + 1`` and occupancy 1.0.  Atoms whose
    (residue letter, slot) pair is not in ``idx2ra`` or whose coordinates
    contain NaN are skipped.

    Args:
        f: seekable text file object opened for writing.
        xyz: per-residue, per-atom-slot coordinates, indexed ``xyz[i][j]``.
        seq: one-letter sequence (converted with ``str``).
        bfac: per-atom B-factors indexed ``bfac[i, j]``; zeros when None.

    Returns:
        Integer array with the index of every residue whose CA atom was
        written, in output order.
    """
    f.seek(0)

    seq = str(seq)
    L = len(seq)

    if bfac is None:
        # Must be two-dimensional: it is indexed per residue AND per atom.
        bfac = np.zeros((L, _MAX_ATOMS))

    serial = 1
    ca_residues = []
    for i in range(L):
        for j, xyz_ij in enumerate(xyz[i]):
            names = idx2ra.get((seq[i], j))
            if names is None:
                continue
            if np.isnan(xyz_ij).any():
                continue
            resname, atomname = names
            f.write(_PDB_ATOM_FMT % (
                "ATOM", serial, atomname, resname,
                "A", i + 1, xyz_ij[0], xyz_ij[1], xyz_ij[2],
                1.0, bfac[i, j]))
            if atomname == 'CA':
                ca_residues.append(i)
            serial += 1

    f.flush()

    # Explicit dtype so an empty result is still usable as an index array.
    return np.array(ca_residues, dtype=int)


def TMalign(chainA, chainB):
    """Align two chains with the external TM-align program.

    Both chains are written to temporary PDB files, TM-align is run with
    ``-m`` to dump its transformation matrix, and the program's stdout is
    parsed at fixed line positions.

    NOTE(review): the fixed stdout line indices (16, 17, 18, -5, -3) assume
    the output layout of the TM-align build this script was developed with;
    confirm when changing the executable.

    Args:
        chainA, chainB: chain dicts with at least 'xyz', 'seq' and 'bfac'.

    Returns:
        ``(resAB, resBA)``: two dicts with keys 'rmsd', 'seqid', 'tm', 'aln',
        't' and 'u', one per direction, or ``(None, None)`` when TM-align
        cannot be started or writes anything to stderr.
    """
    tmp_dir = _SHM_DIR if os.path.isdir(_SHM_DIR) else None

    # The context managers remove the scratch files on every exit path,
    # including the early failure returns and parse errors.
    with tempfile.NamedTemporaryFile(mode='w+t', dir=tmp_dir) as fA, \
            tempfile.NamedTemporaryFile(mode='w+t', dir=tmp_dir) as fB, \
            tempfile.NamedTemporaryFile(mode='w+t', dir=tmp_dir) as mtx:

        # create temp PDB files; keep track of residue indices which were saved
        idxA = writepdb(fA, chainA['xyz'], chainA['seq'], bfac=chainA['bfac'])
        idxB = writepdb(fB, chainB['xyz'], chainB['seq'], bfac=chainB['bfac'])

        # run TMalign (argument list, no shell involved)
        try:
            proc = subprocess.run(
                [TMALIGN_BIN, fA.name, fB.name, '-m', mtx.name],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                encoding='utf-8')
        except OSError:
            # Executable missing or not runnable: report it like any other
            # TM-align failure (the shell used to print to stderr here).
            return None, None

        # if TMalign failed
        if len(proc.stderr) > 0:
            return None, None

        lines = proc.stdout.split('\n')

        # parse transformation: rows 2-4 of the matrix file, minus the
        # leading row index, hold [t, u0, u1, u2]
        mtx.seek(0)
        rows = mtx.readlines()[2:5]

    tu = np.array([row[2:].split() for row in rows], dtype=float).reshape((3, 4))
    t = tu[:, 0]
    u = tu[:, 1:]

    # parse rmsd, sequence identity, and two TM-scores
    summary = lines[16].split()
    rmsd = float(summary[4][:-1])
    seqid = float(summary[-1])
    tm1 = float(lines[17].split()[1])
    tm2 = float(lines[18].split()[1])

    # parse alignment: True where the aligned row holds a residue, not a gap
    ss1 = np.array(list(lines[-5].strip())) != '-'
    ss2 = np.array(list(lines[-3].strip())) != '-'
    mask = np.logical_and(ss1, ss2)

    # cumsum-1 turns alignment columns into positions in the written PDB,
    # idxA/idxB turn those into residue indices of the original chains
    alnAB = np.stack((idxA[(np.cumsum(ss1) - 1)[mask]],
                      idxB[(np.cumsum(ss2) - 1)[mask]]))
    alnBA = np.stack((alnAB[1], alnAB[0]))

    resAB = {'rmsd': rmsd, 'seqid': seqid, 'tm': tm1, 'aln': alnAB, 't': t, 'u': u}
    resBA = {'rmsd': rmsd, 'seqid': seqid, 'tm': tm2, 'aln': alnBA, 't': -u.T @ t, 'u': u.T}

    return resAB, resBA


def get_tm_pairs(chains):
    """Run TM-align for all pairs of chains.

    Args:
        chains: dict chain ID -> chain dict (needs 'xyz', 'seq', 'bfac',
            'mask').

    Returns:
        Dict keyed by ``(A, B)`` for every ordered pair of chain IDs.  Values
        are the result dicts from ``TMalign`` (or None where it failed); the
        ``(A, A)`` entries are synthetic self-alignments over the residues
        whose atom slot 1 (CA) is present, and carry no 't'/'u' keys.
    """
    tm_pairs = {}
    for A, B in combinations(chains.keys(), r=2):
        resAB, resBA = TMalign(chains[A], chains[B])
        # Failed pairs are stored as None on purpose; the caller substitutes
        # placeholder scores for them.
        tm_pairs[(A, B)] = resAB
        tm_pairs[(B, A)] = resBA

    # add self-alignments
    for A, chain in chains.items():
        L = chain['xyz'].shape[0]
        aln = np.arange(L)[chain['mask'][:, 1]]
        tm_pairs[(A, A)] = {'rmsd': 0.0, 'seqid': 1.0, 'tm': 1.0,
                            'aln': np.stack((aln, aln))}

    return tm_pairs


def parseOperationExpression(expression):
    """Expand one operation expression into a list of operation IDs.

    Surrounding parentheses and blanks are stripped, the text is split on
    commas, and an item of the form ``start-stop`` is expanded to every
    integer in that inclusive range.

    Example: ``"(1-3,5)"`` -> ``['1', '2', '3', '5']``.

    Args:
        expression: operation expression string.

    Returns:
        List of operation IDs as strings.
    """
    expression = expression.strip('() ')
    operations = []
    for item in expression.split(','):
        item = item.strip()
        dash = item.find('-')
        if dash > 0:
            start = int(item[0:dash])
            stop = int(item[dash + 1:])
            operations.extend(str(k) for k in range(start, stop + 1))
        else:
            operations.append(item)
    return operations


def parseAssemblies(data, chids):
    """Collect biological-assembly transforms from a parsed mmCIF block.

    Args:
        data: data container exposing ``getObj(category_name)``.
        chids: unused; kept so existing callers keep working.

    Returns:
        Dict with keys 'asmb_chains', 'asmb_details', 'asmb_method' and
        'asmb_ids' (lists with one item per ``pdbx_struct_assembly_gen`` row)
        plus one 'asmb_xform<row>' entry per row holding a stack of 4x4
        transformation matrices.  The four list keys stay None when any of
        the three assembly categories is missing or when an operation
        expression has more than two parenthesised groups.
    """
    xforms = {'asmb_chains': None,
              'asmb_details': None,
              'asmb_method': None,
              'asmb_ids': None}

    assembly_data = data.getObj("pdbx_struct_assembly")
    assembly_gen = data.getObj("pdbx_struct_assembly_gen")
    oper_list = data.getObj("pdbx_struct_oper_list")

    if (assembly_data is None) or (assembly_gen is None) or (oper_list is None):
        return xforms

    # save all basic transformations in a dictionary: id -> 4x4 matrix
    opers = {}
    for k in range(oper_list.getRowCount()):
        key = oper_list.getValue("id", k)
        val = np.eye(4)
        for i in range(3):
            val[i, 3] = float(oper_list.getValue("vector[%d]" % (i + 1), k))
            for j in range(3):
                val[i, j] = float(oper_list.getValue("matrix[%d][%d]" % (i + 1, j + 1), k))
        opers[key] = val

    chains, details, method, ids = [], [], [], []

    for index in range(assembly_gen.getRowCount()):

        # Retrieve the assembly_id attribute value for this assembly
        ids.append(assembly_gen.getValue("assembly_id", index))

        # Split the operation expression into its parenthesised groups and
        # expand each group into a list of operation IDs
        oper_expression = assembly_gen.getValue("oper_expression", index)
        oper_groups = [parseOperationExpression(expression)
                       for expression in re.split(r'\(|\)', oper_expression)
                       if expression]

        # chain IDs which the transform should be applied to
        chains.append(assembly_gen.getValue("asym_id_list", index))

        # NOTE(review): rows of pdbx_struct_assembly are matched by position
        # (clamped to the last row), not by assembly_id - confirm intended.
        index_asmb = min(index, assembly_data.getRowCount() - 1)
        details.append(assembly_data.getValue("details", index_asmb))
        method.append(assembly_data.getValue("method_details", index_asmb))

        if len(oper_groups) == 1:
            xform = np.stack([opers[o] for o in oper_groups[0]])
        elif len(oper_groups) == 2:
            # two groups: every product of an operation from the first
            # group with an operation from the second
            xform = np.stack([opers[o1] @ opers[o2]
                              for o1 in oper_groups[0]
                              for o2 in oper_groups[1]])
        else:
            print('Error in processing assembly')
            return xforms

        xforms['asmb_xform%d' % (index)] = xform

    xforms['asmb_chains'] = chains
    xforms['asmb_details'] = details
    xforms['asmb_method'] = method
    xforms['asmb_ids'] = ids

    return xforms


def _first_float(data, category, attribute):
    """Return row 0 of ``category.attribute`` as a float, or None."""
    obj = data.getObj(category)
    if obj is None:
        return None
    try:
        return float(obj.getValue(attribute, 0))
    except Exception:
        # The reader's exception types for a missing or unparsable value are
        # not visible from this file, so stay broad - but unlike a bare
        # ``except`` this lets KeyboardInterrupt/SystemExit through.
        return None


def parse_mmcif(filename):
    """Parse a gzipped mmCIF file into per-chain arrays and metadata.

    Only 'polypeptide(L)' entities are kept, chains are keyed by
    ``label_asym_id``, hydrogens are skipped and only model 1 is read.  When
    an atom appears several times (alternate locations) the copy with the
    highest occupancy wins.

    Args:
        filename: path to a gzip-compressed mmCIF file.

    Returns:
        ``(chains, metadata)``.  ``chains`` maps chain ID to a dict with
        'seq' (str), 'xyz' (L,14,3 float32, NaN where missing), 'mask'
        (L,14 bool), 'bfac' (L,14 float32, NaN where missing) and 'occ'
        (L,14 float32).  ``metadata`` holds 'method', 'date', 'resolution',
        'chains', 'seq', 'id' and the entries from ``parseAssemblies``.
        Both are empty dicts when the file has no ``entity_poly`` or no
        ``pdbx_poly_seq_scheme`` category.
    """
    chains = {}  # 'chain_id' -> chain structure

    # read a gzipped .cif file
    data = []
    with gzip.open(filename, 'rt') as cif:
        reader = PdbxReader(cif)
        reader.read(data)
    data = data[0]

    #
    # get sequences
    #

    entity_poly = data.getObj('entity_poly')
    pdbx_poly_seq_scheme = data.getObj('pdbx_poly_seq_scheme')
    if entity_poly is None or pdbx_poly_seq_scheme is None:
        return {}, {}

    # map author chain ID (pdb_strand_id) to label chain ID (asym_id)
    strand_col = pdbx_poly_seq_scheme.getIndex('pdb_strand_id')
    asym_col = pdbx_poly_seq_scheme.getIndex('asym_id')
    pdb2asym = {r[strand_col]: r[asym_col]
                for r in pdbx_poly_seq_scheme.getRowList()}

    entity_col = entity_poly.getIndex('entity_id')
    type_col = entity_poly.getIndex('type')
    strands_col = entity_poly.getIndex('pdbx_strand_id')
    canon_col = entity_poly.getIndex('pdbx_seq_one_letter_code_can')
    peptide_rows = [r for r in entity_poly.getRowList()
                    if r[type_col] == 'polypeptide(L)']

    # map chain ID to chain entity
    chs2num = {pdb2asym[ch]: r[entity_col]
               for r in peptide_rows
               for ch in r[strands_col].split(',')}

    # get canonical sequences for polypeptide chains
    num2seq = {r[entity_col]: r[canon_col].replace('\n', '')
               for r in peptide_rows}

    # modified residues: residue name -> parent residue name
    pdbx_struct_mod_residue = data.getObj('pdbx_struct_mod_residue')
    if pdbx_struct_mod_residue is None:
        modres = {}
    else:
        comp_col = pdbx_struct_mod_residue.getIndex('label_comp_id')
        parent_col = pdbx_struct_mod_residue.getIndex('parent_comp_id')
        modres = {r[comp_col]: r[parent_col]
                  for r in pdbx_struct_mod_residue.getRowList()}
        for k, v in modres.items():
            print("# non-standard residue: %s %s" % (k, v))

    # initialize dict of chains
    for c, n in chs2num.items():
        seq = num2seq[n]
        L = len(seq)
        chains[c] = {'seq': seq,
                     'xyz': np.full((L, _MAX_ATOMS, 3), np.nan, dtype=np.float32),
                     'mask': np.zeros((L, _MAX_ATOMS), dtype=bool),
                     'bfac': np.full((L, _MAX_ATOMS), np.nan, dtype=np.float32),
                     'occ': np.zeros((L, _MAX_ATOMS), dtype=np.float32)}

    #
    # populate structures
    #

    # get indices of fields of interest
    atom_site = data.getObj('atom_site')
    col = {k: atom_site.getIndex(name) for k, name in (
        ('atm', 'label_atom_id'),       # atom name
        ('atype', 'type_symbol'),       # atom chemical type
        ('res', 'label_comp_id'),       # residue name (3-letter)
        ('chid', 'label_asym_id'),      # chain ID (label, not auth_asym_id)
        ('num', 'label_seq_id'),        # sequence number
        ('alt', 'label_alt_id'),        # alternative location ID
        ('x', 'Cartn_x'),               # xyz coords
        ('y', 'Cartn_y'),
        ('z', 'Cartn_z'),
        ('occ', 'occupancy'),           # occupancy
        ('bfac', 'B_iso_or_equiv'),     # B-factors
        ('model', 'pdbx_PDB_model_num'),  # model number (multi-model, e.g. NMR)
    )}

    for row in atom_site.getRowList():

        # skip hydrogens
        if row[col['atype']] == 'H':
            continue

        # skip if not a polypeptide
        chid = str(row[col['chid']])
        if chid not in chains:
            continue

        # parse atom
        atm = str(row[col['atm']])
        res = str(row[col['res']])
        num = int(row[col['num']])
        x = float(row[col['x']])
        y = float(row[col['y']])
        z = float(row[col['z']])
        occ = float(row[col['occ']])
        Bfac = float(row[col['bfac']])
        model = int(row[col['model']])

        # skip everything except model #1 (checked before the sequence is
        # touched so that later models can never rewrite it)
        if model > 1:
            continue

        c = chains[chid]

        # remap residue to canonical
        aa = c['seq'][num - 1]
        if aa in to3letter:
            res = to3letter[aa]
        elif res in modres and modres[res] in to1letter:
            # known modified residue: use its parent and patch the sequence
            res = modres[res]
            c['seq'] = c['seq'][:num - 1] + to1letter[res] + c['seq'][num:]
        else:
            # unknown residue: keep only the backbone atoms
            res = 'GLY'

        # skip if not a standard residue/atom
        slot = aa2idx.get((res, atm))
        if slot is None:
            continue

        # populate chains using max occupancy atoms
        # NOTE(review): 'occ' starts at 0 and the test is strict, so atoms
        # with zero occupancy are never stored - confirm this is intended.
        idx = (num - 1, slot)
        if occ > c['occ'][idx]:
            c['xyz'][idx] = [x, y, z]
            c['mask'][idx] = True
            c['occ'][idx] = occ
            c['bfac'][idx] = Bfac

    #
    # metadata
    #

    # resolution: crystallographic refinement first, EM reconstruction second
    res = _first_float(data, 'refine', 'ls_d_res_high')
    if res is None:
        res = _first_float(data, 'em_3d_reconstruction', 'resolution')

    chids = list(chains.keys())
    seq = []
    for ch in chids:
        # a residue counts as observed when N, CA and C are all present
        mask = chains[ch]['mask'][:, :3].sum(1) == 3
        ref_seq = chains[ch]['seq']
        atom_seq = ''.join([a if m else '-' for a, m in zip(ref_seq, mask)])
        seq.append([ref_seq, atom_seq])

    metadata = {
        'method': data.getObj('exptl').getValue('method', 0).replace(' ', '_'),
        'date': data.getObj('pdbx_database_status').getValue('recvd_initial_deposition_date', 0),
        'resolution': res,
        'chains': chids,
        'seq': seq,
        'id': data.getObj('entry').getValue('id', 0)
    }

    #
    # assemblies
    #

    asmbs = parseAssemblies(data, chains)
    metadata.update(asmbs)

    return chains, metadata


def main(argv=None):
    """Script entry point: parse ``argv[1]`` and save tensors under ``argv[2]``.

    Args:
        argv: argument vector; defaults to ``sys.argv``.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        sys.exit("usage: %s <input.cif.gz> <output_prefix>"
                 % (argv[0] if argv else 'script'))

    IN = argv[1]
    OUT = argv[2]

    chains, metadata = parse_mmcif(IN)
    if not metadata:
        # parse_mmcif found no polymer description; there is nothing to save.
        sys.exit("%s: no polymer entities found, nothing written" % IN)
    ID = metadata['id']

    tm_pairs = get_tm_pairs(chains)
    if 'chains' in metadata.keys() and len(metadata['chains']) > 0:
        chids = metadata['chains']
        tm = []
        for a in chids:
            tm_a = []
            for b in chids:
                tm_ab = tm_pairs[(a, b)]
                if tm_ab is None:
                    # placeholder for a failed TM-align run
                    tm_a.append([0.0, 0.0, 999.9])
                else:
                    tm_a.append([tm_ab[k] for k in ['tm', 'seqid', 'rmsd']])
            tm.append(tm_a)
        metadata.update({'tm': tm})

    for k, v in chains.items():
        nres = (v['mask'][:, :3].sum(1) == 3).sum()
        print(">%s_%s %s %s %s %d %d\n%s" % (ID, k, metadata['date'], metadata['method'],
                                             metadata['resolution'], len(v['seq']), nres, v['seq']))

        torch.save({kc: torch.Tensor(vc) if kc != 'seq' else str(vc)
                    for kc, vc in v.items()}, f"{OUT}_{k}.pt")

    meta_pt = {}
    for k, v in metadata.items():
        if "asmb_xform" in k or k == "tm":
            v = torch.Tensor(v)
        meta_pt.update({k: v})
    torch.save(meta_pt, f"{OUT}.pt")


if __name__ == "__main__":
    main()