GenFBDD / datasets /parse_chi.py
libokj's picture
Minify
c17cba8
# From Nick Polizzi
import numpy as np
from collections import defaultdict
import prody as pr
import os
from datasets.constants import chi, atom_order, aa_long2short, aa_short2aa_idx, aa_idx2aa_short
def get_dihedral_indices(resname, chi_num):
"""Return the atom indices for the specified dihedral angle.
"""
if resname not in chi:
return np.array([np.nan]*4)
if chi_num not in chi[resname]:
return np.array([np.nan]*4)
return np.array([atom_order[resname].index(x) for x in chi[resname][chi_num]])
dihedral_indices = defaultdict(list)
for aa in atom_order.keys():
for i in range(1, 5):
inds = get_dihedral_indices(aa, i)
dihedral_indices[aa].append(inds)
dihedral_indices[aa] = np.array(dihedral_indices[aa])
def vector_batch(a, b):
return a - b
def unit_vector_batch(v):
return v / np.linalg.norm(v, axis=1, keepdims=True)
def dihedral_angle_batch(p):
b0 = vector_batch(p[:, 0], p[:, 1])
b1 = vector_batch(p[:, 1], p[:, 2])
b2 = vector_batch(p[:, 2], p[:, 3])
n1 = np.cross(b0, b1)
n2 = np.cross(b1, b2)
m1 = np.cross(n1, b1 / np.linalg.norm(b1, axis=1, keepdims=True))
x = np.sum(n1 * n2, axis=1)
y = np.sum(m1 * n2, axis=1)
deg = np.degrees(np.arctan2(y, x))
deg[deg < 0] += 360
return deg
def batch_compute_dihedral_angles(sidechains):
sidechains_np = np.array(sidechains)
dihedral_angles = dihedral_angle_batch(sidechains_np)
return dihedral_angles
def get_coords(prody_pdb):
resindices = sorted(set(prody_pdb.ca.getResindices()))
coords = np.full((len(resindices), 14, 3), np.nan)
for i, resind in enumerate(resindices):
sel = prody_pdb.select(f'resindex {resind}')
resname = sel.getResnames()[0]
for j, name in enumerate(atom_order[aa_long2short[resname] if resname in aa_long2short else 'X']):
sel_resnum_name = sel.select(f'name {name}')
if sel_resnum_name is not None:
coords[i, j, :] = sel_resnum_name.getCoords()[0]
else:
coords[i, j, :] = [np.nan, np.nan, np.nan]
return coords
def get_onehot_sequence(seq):
onehot = np.zeros((len(seq), 20))
for i, aa in enumerate(seq):
idx = aa_short2aa_idx[aa] if aa in aa_short2aa_idx else 7 # 7 is the index for GLY
onehot[i, idx] = 1
return onehot
def get_dihedral_indices(onehot_sequence):
return np.array([dihedral_indices[aa_idx2aa_short[aa_idx]] for aa_idx in np.where(onehot_sequence)[1]])
def _get_chi_angles(coords, indices):
X = coords
Y = indices.astype(int)
N = coords.shape[0]
mask = np.isnan(indices)
Y[mask] = 0
Z = X[np.arange(N)[:, None, None], Y, :]
Z[mask] = np.nan
chi_angles = batch_compute_dihedral_angles(Z.reshape(-1, 4, 3)).reshape(N, 4)
return chi_angles
def get_chi_angles(coords, seq, return_onehot=False):
"""
Parameters
----------
prody_pdb : prody.AtomGroup
prody pdb object or selection
return_coords : bool, optional
return coordinates of prody_pdb in (N, 14, 3) array format, by default False
return_onehot : bool, optional
return one-hot sequence of prody_pdb, by default False
Returns
-------
numpy array of shape (N, 4)
Array contains chi angles of sidechains in row-order of residue indices in prody_pdb.
If a chi angle is not defined for a residue, due to missing atoms or GLY / ALA, it is set to np.nan.
"""
onehot = get_onehot_sequence(seq)
dihedral_indices = get_dihedral_indices(onehot)
if return_onehot:
return _get_chi_angles(coords, dihedral_indices), onehot
return _get_chi_angles(coords, dihedral_indices)
def test_get_chi_angles(print_chi_angles=False):
# need internet connection of '6w70.pdb' in working directory
pdb = pr.parsePDB('6w70')
prody_pdb = pdb.select('chain A')
chi_angles = get_chi_angles(prody_pdb)
assert chi_angles.shape == (prody_pdb.ca.numAtoms(), 4)
assert chi_angles[0,0] < 56.0 and chi_angles[0,0] > 55.0
print('test_get_chi_angles passed')
try:
os.remove('6w70.pdb.gz')
except:
pass
if print_chi_angles:
print(chi_angles)
return True
if __name__ == '__main__':
test_get_chi_angles(print_chi_angles=True)