import numpy as np import torch # ============================================================ def get_pair_dist(a, b): """calculate pair distances between two sets of points Parameters ---------- a,b : pytorch tensors of shape [batch,nres,3] store Cartesian coordinates of two sets of atoms Returns ------- dist : pytorch tensor of shape [batch,nres,nres] stores paitwise distances between atoms in a and b """ dist = torch.cdist(a, b, p=2) return dist # ============================================================ def get_ang(a, b, c): """calculate planar angles for all consecutive triples (a[i],b[i],c[i]) from Cartesian coordinates of three sets of atoms a,b,c Parameters ---------- a,b,c : pytorch tensors of shape [batch,nres,3] store Cartesian coordinates of three sets of atoms Returns ------- ang : pytorch tensor of shape [batch,nres] stores resulting planar angles """ v = a - b w = c - b v = v / torch.norm(v, dim=-1, keepdim=True) w = w / torch.norm(w, dim=-1, keepdim=True) # this is not stable at the poles #vw = torch.sum(v*w, dim=-1) #ang = torch.acos(vw) # this is better # https://math.stackexchange.com/questions/1143354/numerically-stable-method-for-angle-between-3d-vectors/1782769 y = torch.norm(v-w,dim=-1) x = torch.norm(v+w,dim=-1) ang = 2*torch.atan2(y, x) return ang # ============================================================ def get_dih(a, b, c, d): """calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i]) given Cartesian coordinates of four sets of atoms a,b,c,d Parameters ---------- a,b,c,d : pytorch tensors of shape [batch,nres,3] store Cartesian coordinates of four sets of atoms Returns ------- dih : pytorch tensor of shape [batch,nres] stores resulting dihedrals """ b0 = a - b b1r = c - b b2 = d - c b1 = b1r/torch.norm(b1r, dim=-1, keepdim=True) v = b0 - torch.sum(b0*b1, dim=-1, keepdim=True)*b1 w = b2 - torch.sum(b2*b1, dim=-1, keepdim=True)*b1 x = torch.sum(v*w, dim=-1) y = torch.sum(torch.cross(b1,v,dim=-1)*w, dim=-1) ang = torch.atan2(y, x) return ang # ============================================================ def xyz_to_c6d(xyz, params): """convert cartesian coordinates into 2d distance and orientation maps Parameters ---------- xyz : pytorch tensor of shape [batch,3,nres,3] stores Cartesian coordinates of backbone N,Ca,C atoms Returns ------- c6d : pytorch tensor of shape [batch,nres,nres,4] stores stacked dist,omega,theta,phi 2D maps """ batch = xyz.shape[0] nres = xyz.shape[2] # three anchor atoms N = xyz[:,0] Ca = xyz[:,1] C = xyz[:,2] # recreate Cb given N,Ca,C b = Ca - N c = C - Ca a = torch.cross(b, c, dim=-1) Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca # 6d coordinates order: (dist,omega,theta,phi) c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device) dist = get_pair_dist(Cb,Cb) dist[torch.isnan(dist)] = 999.9 c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...] b,i,j = torch.where(c6d[...,0]=params['DMAX']] = 999.9 return c6d # ============================================================ def c6d_to_bins(c6d,params): """bin 2d distance and orientation maps """ dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] astep = 2.0*np.pi / params['ABINS'] dbins = torch.linspace(params['DMIN']+dstep, params['DMAX'], params['DBINS'],dtype=c6d.dtype,device=c6d.device) ab360 = torch.linspace(-np.pi+astep, np.pi, params['ABINS'],dtype=c6d.dtype,device=c6d.device) ab180 = torch.linspace(astep, np.pi, params['ABINS']//2,dtype=c6d.dtype,device=c6d.device) db = torch.bucketize(c6d[...,0].contiguous(),dbins) ob = torch.bucketize(c6d[...,1].contiguous(),ab360) tb = torch.bucketize(c6d[...,2].contiguous(),ab360) pb = torch.bucketize(c6d[...,3].contiguous(),ab180) ob[db==params['DBINS']] = params['ABINS'] tb[db==params['DBINS']] = params['ABINS'] pb[db==params['DBINS']] = params['ABINS']//2 return torch.stack([db,ob,tb,pb],axis=-1).to(torch.uint8) # ============================================================ def dist_to_bins(dist,params): """bin 2d distance maps """ dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] db = torch.round((dist-params['DMIN']-dstep/2)/dstep) db[db<0] = 0 db[db>params['DBINS']] = params['DBINS'] return db.long() # ============================================================ def c6d_to_bins2(c6d,params): """bin 2d distance and orientation maps (alternative slightly simpler version) """ dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] astep = 2.0*np.pi / params['ABINS'] db = torch.round((c6d[...,0]-params['DMIN']-dstep/2)/dstep) ob = torch.round((c6d[...,1]+np.pi-astep/2)/astep) tb = torch.round((c6d[...,2]+np.pi-astep/2)/astep) pb = torch.round((c6d[...,3]-astep/2)/astep) # put all dparams['DBINS']] = params['DBINS'] ob[db==params['DBINS']] = params['ABINS'] tb[db==params['DBINS']] = params['ABINS'] pb[db==params['DBINS']] = params['ABINS']//2 return torch.stack([db,ob,tb,pb],axis=-1).long() # ============================================================ def get_cb(N,Ca,C): """recreate Cb given N,Ca,C""" b = Ca - N c = C - Ca a = torch.cross(b, c, dim=-1) Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca return Cb