# protpardelle/ProteinMPNN/training/parse_cif_noX.py
import pdbx
from pdbx.reader.PdbxReader import PdbxReader
from pdbx.reader.PdbxContainers import DataCategory
import gzip
import numpy as np
import torch
import os,sys
import glob
import re
from scipy.spatial import KDTree
from itertools import combinations,permutations
import tempfile
import subprocess
RES_NAMES = [
'ALA','ARG','ASN','ASP','CYS',
'GLN','GLU','GLY','HIS','ILE',
'LEU','LYS','MET','PHE','PRO',
'SER','THR','TRP','TYR','VAL'
]
RES_NAMES_1 = 'ARNDCQEGHILKMFPSTWYV'
to1letter = {aaa:a for a,aaa in zip(RES_NAMES_1,RES_NAMES)}
to3letter = {a:aaa for a,aaa in zip(RES_NAMES_1,RES_NAMES)}
ATOM_NAMES = [
("N", "CA", "C", "O", "CB"), # ala
("N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"), # arg
("N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"), # asn
("N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"), # asp
("N", "CA", "C", "O", "CB", "SG"), # cys
("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"), # gln
("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"), # glu
("N", "CA", "C", "O"), # gly
("N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"), # his
("N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"), # ile
("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"), # leu
("N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"), # lys
("N", "CA", "C", "O", "CB", "CG", "SD", "CE"), # met
("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"), # phe
("N", "CA", "C", "O", "CB", "CG", "CD"), # pro
("N", "CA", "C", "O", "CB", "OG"), # ser
("N", "CA", "C", "O", "CB", "OG1", "CG2"), # thr
("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"), # trp
("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"), # tyr
("N", "CA", "C", "O", "CB", "CG1", "CG2") # val
]
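# lookup tables derived from the constants above:
#   idx2ra : (1-letter residue, atom slot index) -> (3-letter residue, atom name)
#   aa2idx : (3-letter residue, atom name)       -> atom slot index (0..13)
# the terminal OXT oxygen is mapped onto the backbone O slot (index 3)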
idx2ra = {(RES_NAMES_1[i],j):(RES_NAMES[i],a) for i in range(20) for j,a in enumerate(ATOM_NAMES[i])}
aa2idx = {(r,a):i for r,atoms in zip(RES_NAMES,ATOM_NAMES)
for i,a in enumerate(atoms)}
aa2idx.update({(r,'OXT'):3 for r in RES_NAMES})
def writepdb(f, xyz, seq, bfac=None):
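    """Write the (L,14,3) coordinate array `xyz` to the open file handle `f` in PDB format.

    Only atoms with finite coordinates and a valid (residue, atom-slot) pair in
    idx2ra are written; all residues are placed in chain "A". Returns the indices
    of residues for which a CA atom was written, which is used below to map
    TMalign alignment columns back to residue positions."""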
#f = open(filename,"w")
f.seek(0)
ctr = 1
seq = str(seq)
L = len(seq)
if bfac is None:
        bfac = np.zeros((L,14))  # 2D so the bfac[i,j] lookup below also works for the default
idx = []
for i in range(L):
for j,xyz_ij in enumerate(xyz[i]):
key = (seq[i],j)
if key not in idx2ra.keys():
continue
if np.isnan(xyz_ij).sum()>0:
continue
r,a = idx2ra[key]
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, a, r,
"A", i+1, xyz_ij[0], xyz_ij[1], xyz_ij[2],
1.0, bfac[i,j] ) )
if a == 'CA':
idx.append(i)
ctr += 1
#f.close()
f.flush()
return np.array(idx)
def TMalign(chainA, chainB):
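    """Structurally align two parsed chains with TMalign via temporary PDB files in /dev/shm.

    Assumes a TMalign binary at /home/aivan/prog/TMalign (a machine-specific path;
    adjust it for your system). Returns a pair of dicts (A->B and B->A) holding
    rmsd, sequence identity, TM-score, the residue-index alignment, and the rigid
    transform (t, u) reported by TMalign (inverted for the B->A entry).
    Returns (None, None) if TMalign writes anything to stderr."""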
# temp files to save the two input protein chains
# and TMalign transformation
fA = tempfile.NamedTemporaryFile(mode='w+t', dir='/dev/shm')
fB = tempfile.NamedTemporaryFile(mode='w+t', dir='/dev/shm')
mtx = tempfile.NamedTemporaryFile(mode='w+t', dir='/dev/shm')
# create temp PDB files keep track of residue indices which were saved
idxA = writepdb(fA, chainA['xyz'], chainA['seq'], bfac=chainA['bfac'])
idxB = writepdb(fB, chainB['xyz'], chainB['seq'], bfac=chainB['bfac'])
# run TMalign
tm = subprocess.Popen('/home/aivan/prog/TMalign %s %s -m %s'%(fA.name, fB.name, mtx.name),
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding='utf-8')
stdout,stderr = tm.communicate()
lines = stdout.split('\n')
# if TMalign failed
if len(stderr) > 0:
return None,None
# parse transformation
mtx.seek(0)
tu = np.fromstring(''.join(l[2:] for l in mtx.readlines()[2:5]),
dtype=float, sep=' ').reshape((3,4))
t = tu[:,0]
u = tu[:,1:]
# parse rmsd, sequence identity, and two TM-scores
rmsd = float(lines[16].split()[4][:-1])
seqid = float(lines[16].split()[-1])
tm1 = float(lines[17].split()[1])
tm2 = float(lines[18].split()[1])
# parse alignment
seq1 = lines[-5]
seq2 = lines[-3]
ss1 = np.array(list(seq1.strip()))!='-'
ss2 = np.array(list(seq2.strip()))!='-'
#print(ss1)
#print(ss2)
mask = np.logical_and(ss1, ss2)
alnAB = np.stack((idxA[(np.cumsum(ss1)-1)[mask]],
idxB[(np.cumsum(ss2)-1)[mask]]))
alnBA = np.stack((alnAB[1],alnAB[0]))
# clean up
fA.close()
fB.close()
mtx.close()
resAB = {'rmsd':rmsd, 'seqid':seqid, 'tm':tm1, 'aln':alnAB, 't':t, 'u':u}
resBA = {'rmsd':rmsd, 'seqid':seqid, 'tm':tm2, 'aln':alnBA, 't':-u.T@t, 'u':u.T}
return resAB,resBA
def get_tm_pairs(chains):
"""run TM-align for all pairs of chains"""
tm_pairs = {}
for A,B in combinations(chains.keys(),r=2):
resAB,resBA = TMalign(chains[A],chains[B])
#if resAB is None:
# continue
tm_pairs.update({(A,B):resAB})
tm_pairs.update({(B,A):resBA})
# add self-alignments
for A in chains.keys():
L = chains[A]['xyz'].shape[0]
aln = np.arange(L)[chains[A]['mask'][:,1]]
aln = np.stack((aln,aln))
tm_pairs.update({(A,A):{'rmsd':0.0, 'seqid':1.0, 'tm':1.0, 'aln':aln}})
return tm_pairs
def parseOperationExpression(expression):
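    """Expand an mmCIF operation expression into a flat list of operation IDs.

    For example, "(1-3)" expands to ['1','2','3'] and "1,2,5" to ['1','2','5']."""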
expression = expression.strip('() ')
operations = []
for e in expression.split(','):
e = e.strip()
pos = e.find('-')
if pos>0:
start = int(e[0:pos])
stop = int(e[pos+1:])
operations.extend([str(i) for i in range(start,stop+1)])
else:
operations.append(e)
return operations
def parseAssemblies(data,chids):
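    """Extract biological-assembly transforms from the mmCIF data block.

    Builds a 4x4 matrix for every pdbx_struct_oper_list entry and, for each
    pdbx_struct_assembly_gen row, stacks the (optionally composed) transforms
    as an (n,4,4) array under the key 'asmb_xform<i>'. Also records the chains
    each transform applies to, plus assembly ids, details, and method.
    The `chids` argument is currently unused."""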
xforms = {'asmb_chains' : None,
'asmb_details' : None,
'asmb_method' : None,
'asmb_ids' : None}
assembly_data = data.getObj("pdbx_struct_assembly")
assembly_gen = data.getObj("pdbx_struct_assembly_gen")
oper_list = data.getObj("pdbx_struct_oper_list")
if (assembly_data is None) or (assembly_gen is None) or (oper_list is None):
return xforms
# save all basic transformations in a dictionary
opers = {}
for k in range(oper_list.getRowCount()):
key = oper_list.getValue("id", k)
val = np.eye(4)
for i in range(3):
val[i,3] = float(oper_list.getValue("vector[%d]"%(i+1), k))
for j in range(3):
val[i,j] = float(oper_list.getValue("matrix[%d][%d]"%(i+1,j+1), k))
opers.update({key:val})
chains,details,method,ids = [],[],[],[]
for index in range(assembly_gen.getRowCount()):
# Retrieve the assembly_id attribute value for this assembly
assemblyId = assembly_gen.getValue("assembly_id", index)
ids.append(assemblyId)
# Retrieve the operation expression for this assembly from the oper_expression attribute
oper_expression = assembly_gen.getValue("oper_expression", index)
        # use a raw string for the regex and a distinct name so the oper_list
        # category object fetched above is not shadowed
        oper_ids = [parseOperationExpression(expression)
                    for expression in re.split(r'\(|\)', oper_expression) if expression]
# chain IDs which the transform should be applied to
chains.append(assembly_gen.getValue("asym_id_list", index))
index_asmb = min(index,assembly_data.getRowCount()-1)
details.append(assembly_data.getValue("details", index_asmb))
method.append(assembly_data.getValue("method_details", index_asmb))
#
        if len(oper_ids)==1:
            xform = np.stack([opers[o] for o in oper_ids[0]])
        elif len(oper_ids)==2:
            xform = np.stack([opers[o1]@opers[o2]
                              for o1 in oper_ids[0]
                              for o2 in oper_ids[1]])
else:
print('Error in processing assembly')
return xforms
xforms.update({'asmb_xform%d'%(index):xform})
xforms['asmb_chains'] = chains
xforms['asmb_details'] = details
xforms['asmb_method'] = method
xforms['asmb_ids'] = ids
return xforms
def parse_mmcif(filename):
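    """Parse a gzipped mmCIF file into per-chain structures and metadata.

    Returns (chains, metadata), where chains maps a label_asym_id to a dict with
      'seq'  : canonical 1-letter sequence (str),
      'xyz'  : (L,14,3) float32 heavy-atom coordinates (NaN where missing),
      'mask' : (L,14) bool, True where an atom was observed,
      'bfac' : (L,14) float32 B-factors,
      'occ'  : (L,14) float32 occupancies,
    and metadata holds the experimental method, deposition date, resolution,
    chain ids, sequences, PDB id, and assembly transforms. Hydrogens and
    non-polypeptide chains are skipped, modified residues are remapped to their
    parent residues, and only model #1 is used."""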
#print(filename)
chains = {} # 'chain_id' -> chain_strucure
# read a gzipped .cif file
data = []
with gzip.open(filename,'rt') as cif:
reader = PdbxReader(cif)
reader.read(data)
data = data[0]
#
# get sequences
#
# map chain entity to chain ID
entity_poly = data.getObj('entity_poly')
if entity_poly is None:
return {},{}
pdbx_poly_seq_scheme = data.getObj('pdbx_poly_seq_scheme')
pdb2asym = dict({
(r[pdbx_poly_seq_scheme.getIndex('pdb_strand_id')],
r[pdbx_poly_seq_scheme.getIndex('asym_id')])
        for r in pdbx_poly_seq_scheme.getRowList()
})
chs2num = {pdb2asym[ch]:r[entity_poly.getIndex('entity_id')]
for r in entity_poly.getRowList()
for ch in r[entity_poly.getIndex('pdbx_strand_id')].split(',')
if r[entity_poly.getIndex('type')]=='polypeptide(L)'}
# get canonical sequences for polypeptide chains
num2seq = {r[entity_poly.getIndex('entity_id')]:r[entity_poly.getIndex('pdbx_seq_one_letter_code_can')].replace('\n','')
for r in entity_poly.getRowList()
if r[entity_poly.getIndex('type')]=='polypeptide(L)'}
# map chain entity to amino acid sequence
#entity_poly_seq = data.getObj('entity_poly_seq')
#num2seq = dict.fromkeys(set(chs2num.values()), "")
#for row in entity_poly_seq.getRowList():
# num = row[entity_poly_seq.getIndex('entity_id')]
# res = row[entity_poly_seq.getIndex('mon_id')]
# if num not in num2seq.keys():
# continue
# num2seq[num] += (to1letter[res] if res in to1letter.keys() else 'X')
# modified residues
pdbx_struct_mod_residue = data.getObj('pdbx_struct_mod_residue')
if pdbx_struct_mod_residue is None:
modres = {}
else:
modres = dict({(r[pdbx_struct_mod_residue.getIndex('label_comp_id')],
r[pdbx_struct_mod_residue.getIndex('parent_comp_id')])
for r in pdbx_struct_mod_residue.getRowList()})
for k,v in modres.items():
print("# non-standard residue: %s %s"%(k,v))
# initialize dict of chains
for c,n in chs2num.items():
seq = num2seq[n]
L = len(seq)
chains.update({c : {'seq' : seq,
'xyz' : np.full((L,14,3),np.nan,dtype=np.float32),
'mask' : np.zeros((L,14),dtype=bool),
'bfac' : np.full((L,14),np.nan,dtype=np.float32),
'occ' : np.zeros((L,14),dtype=np.float32) }})
#
# populate structures
#
# get indices of fields of interest
atom_site = data.getObj('atom_site')
i = {k:atom_site.getIndex(val) for k,val in [('atm', 'label_atom_id'), # atom name
('atype', 'type_symbol'), # atom chemical type
('res', 'label_comp_id'), # residue name (3-letter)
#('chid', 'auth_asym_id'), # chain ID
('chid', 'label_asym_id'), # chain ID
('num', 'label_seq_id'), # sequence number
('alt', 'label_alt_id'), # alternative location ID
('x', 'Cartn_x'), # xyz coords
('y', 'Cartn_y'),
('z', 'Cartn_z'),
('occ', 'occupancy'), # occupancy
('bfac', 'B_iso_or_equiv'), # B-factors
('model', 'pdbx_PDB_model_num') # model number (for multi-model PDBs, e.g. NMR)
]}
for a in atom_site.getRowList():
# skip HETATM
#if a[0] != 'ATOM':
# continue
# skip hydrogens
if a[i['atype']] == 'H':
continue
# skip if not a polypeptide
if a[i['chid']] not in chains.keys():
continue
# parse atom
atm, res, chid, num, alt, x, y, z, occ, Bfac, model = \
(t(a[i[k]]) for k,t in (('atm',str), ('res',str), ('chid',str),
('num',int), ('alt',str),
('x',float), ('y',float), ('z',float),
('occ',float), ('bfac',float), ('model',int)))
#print(atm, res, chid, num, alt, x, y, z, occ, Bfac, model)
c = chains[chid]
# remap residue to canonical
        aa1 = c['seq'][num-1]  # 1-letter code from the canonical sequence (avoid reusing the row variable a)
        if aa1 in to3letter.keys():
            res = to3letter[aa1]
else:
if res in modres.keys() and modres[res] in to1letter.keys():
res = modres[res]
c['seq'] = c['seq'][:num-1] + to1letter[res] + c['seq'][num:]
else:
res = 'GLY'
# skip if not a standard residue/atom
if (res,atm) not in aa2idx.keys():
continue
# skip everything except model #1
if model > 1:
continue
        # populate chains using the highest-occupancy copy of each atom
idx = (num-1, aa2idx[(res,atm)])
if occ > c['occ'][idx]:
c['xyz'][idx] = [x,y,z]
c['mask'][idx] = True
c['occ'][idx] = occ
c['bfac'][idx] = Bfac
#
# metadata
#
#if data.getObj('reflns') is not None:
# res = data.getObj('reflns').getValue('d_resolution_high',0)
res = None
if data.getObj('refine') is not None:
try:
res = float(data.getObj('refine').getValue('ls_d_res_high',0))
        except Exception:  # resolution missing or not numeric
res = None
if (data.getObj('em_3d_reconstruction') is not None) and (res is None):
try:
res = float(data.getObj('em_3d_reconstruction').getValue('resolution',0))
        except Exception:  # resolution missing or not numeric
res = None
chids = list(chains.keys())
seq = []
for ch in chids:
mask = chains[ch]['mask'][:,:3].sum(1)==3
ref_seq = chains[ch]['seq']
atom_seq = ''.join([a if m else '-' for a,m in zip(ref_seq,mask)])
seq.append([ref_seq,atom_seq])
metadata = {
'method' : data.getObj('exptl').getValue('method',0).replace(' ','_'),
'date' : data.getObj('pdbx_database_status').getValue('recvd_initial_deposition_date',0),
'resolution' : res,
'chains' : chids,
'seq' : seq,
'id' : data.getObj('entry').getValue('id',0)
}
#
# assemblies
#
asmbs = parseAssemblies(data,chains)
metadata.update(asmbs)
return chains, metadata
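#
# command-line entry point:
#   python parse_cif_noX.py <input.cif.gz> <output_prefix>
# writes one <output_prefix>_<chain>.pt file per chain plus <output_prefix>.pt with metadata
#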
IN = sys.argv[1]
OUT = sys.argv[2]
chains,metadata = parse_mmcif(IN)
ID = metadata['id']
tm_pairs = get_tm_pairs(chains)
if 'chains' in metadata.keys() and len(metadata['chains'])>0:
chids = metadata['chains']
tm = []
for a in chids:
tm_a = []
for b in chids:
tm_ab = tm_pairs[(a,b)]
if tm_ab is None:
tm_a.append([0.0,0.0,999.9])
else:
tm_a.append([tm_ab[k] for k in ['tm','seqid','rmsd']])
tm.append(tm_a)
metadata.update({'tm':tm})
for k,v in chains.items():
nres = (v['mask'][:,:3].sum(1)==3).sum()
print(">%s_%s %s %s %s %d %d\n%s"%(ID,k,metadata['date'],metadata['method'],
metadata['resolution'],len(v['seq']),nres,v['seq']))
torch.save({kc:torch.Tensor(vc) if kc!='seq' else str(vc)
for kc,vc in v.items()}, f"{OUT}_{k}.pt")
meta_pt = {}
for k,v in metadata.items():
if "asmb_xform" in k or k=="tm":
v = torch.Tensor(v)
meta_pt.update({k:v})
torch.save(meta_pt, f"{OUT}.pt")
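# Example of loading the outputs written above (hypothetical prefix "1abc"):
#   chain = torch.load("1abc_A.pt")   # {'seq': str, 'xyz': (L,14,3), 'mask': (L,14), 'bfac': (L,14), 'occ': (L,14)}
#   meta  = torch.load("1abc.pt")     # metadata dict; 'tm' and 'asmb_xform*' entries are tensors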