Spaces:
Sleeping
Sleeping
# From Nick Polizzi | |
import numpy as np | |
from collections import defaultdict | |
import prody as pr | |
import os | |
from datasets.constants import chi, atom_order, aa_long2short, aa_short2aa_idx, aa_idx2aa_short | |
def _residue_dihedral_indices(resname, chi_num):
    """Return the atom-slot indices of the atoms defining one chi dihedral.

    Named distinctly from the public ``get_dihedral_indices(onehot_sequence)``
    defined later in this module; sharing that name meant this helper was
    silently shadowed (and unreachable) once the module finished loading.

    Parameters
    ----------
    resname : str
        Residue-type key used by ``atom_order`` and ``chi`` (the one-letter
        code, as looked up in ``get_coords``).
    chi_num : int
        Which chi angle to look up (the table below asks for 1-4).

    Returns
    -------
    numpy.ndarray
        Position of each atom of ``chi[resname][chi_num]`` within
        ``atom_order[resname]``. Four NaNs when the residue type or this chi
        angle is not listed in ``chi`` (e.g. GLY/ALA, or chi4 of most residues).
    """
    if resname not in chi or chi_num not in chi[resname]:
        return np.array([np.nan] * 4)
    residue_atoms = atom_order[resname]
    return np.array([residue_atoms.index(atom) for atom in chi[resname][chi_num]])


# Per-residue-type lookup table: dihedral_indices[aa] stacks the index rows of
# chi1..chi4 (rows are NaN where that chi is undefined). Presumably (4, 4) when
# every chi entry lists four atoms -- TODO confirm against datasets.constants.
dihedral_indices = defaultdict(list)
for aa in atom_order.keys():
    dihedral_indices[aa] = np.array(
        [_residue_dihedral_indices(aa, chi_num) for chi_num in range(1, 5)]
    )
def vector_batch(a, b):
    """Return the element-wise difference ``a - b`` (vectors pointing from b to a)."""
    displacement = a - b
    return displacement
def unit_vector_batch(v):
    """Scale every row of ``v`` to unit Euclidean length (norm taken along axis 1)."""
    lengths = np.linalg.norm(v, axis=1, keepdims=True)
    return v / lengths
def dihedral_angle_batch(p):
    """Measure a batch of dihedral angles in degrees, mapped into 0-360.

    ``p`` is indexed as ``p[batch, point, xyz]``; points 0-3 of each row define
    one torsion about the bond between points 1 and 2. NaN coordinates
    propagate to a NaN angle.
    """
    first, second, third, fourth = (p[:, k] for k in range(4))

    bond_a = first - second
    bond_b = second - third
    bond_c = third - fourth

    # Normals of the two planes sharing the central bond.
    normal_a = np.cross(bond_a, bond_b)
    normal_b = np.cross(bond_b, bond_c)

    # Third axis of an orthogonal frame, so atan2 yields a signed angle.
    axis = bond_b / np.linalg.norm(bond_b, axis=1, keepdims=True)
    frame_y = np.cross(normal_a, axis)

    cos_part = np.sum(normal_a * normal_b, axis=1)
    sin_part = np.sum(frame_y * normal_b, axis=1)
    angles = np.degrees(np.arctan2(sin_part, cos_part))
    return np.where(angles < 0, angles + 360, angles)
def batch_compute_dihedral_angles(sidechains):
    """Convert ``sidechains`` (array-like of 4-point rows) to an ndarray and
    return the dihedral angle of every row via ``dihedral_angle_batch``."""
    return dihedral_angle_batch(np.array(sidechains))
def get_coords(prody_pdb):
    """Pack heavy-atom coordinates into a fixed (num_residues, 14, 3) array.

    Residues are those with a CA atom, ordered by ProDy residue index. Atom
    slots follow ``atom_order`` for the residue's one-letter code (unknown
    residue names use the ``'X'`` entry). Only the first atom matching each
    name is used; absent atoms leave their slot as NaN.
    """
    residue_ids = sorted(set(prody_pdb.ca.getResindices()))
    coords = np.full((len(residue_ids), 14, 3), np.nan)
    for row, resind in enumerate(residue_ids):
        residue = prody_pdb.select(f'resindex {resind}')
        long_name = residue.getResnames()[0]
        short_name = aa_long2short[long_name] if long_name in aa_long2short else 'X'
        for slot, atom_name in enumerate(atom_order[short_name]):
            atom = residue.select(f'name {atom_name}')
            if atom is None:
                continue  # slot was prefilled with NaN above
            coords[row, slot, :] = atom.getCoords()[0]
    return coords
def get_onehot_sequence(seq):
    """One-hot encode a sequence of one-letter amino-acid codes.

    Returns a float array of shape (len(seq), 20) holding a single 1 per row,
    in column ``aa_short2aa_idx[code]``. Codes missing from
    ``aa_short2aa_idx`` are placed in column 7, the glycine column.
    """
    gly_column = 7  # 7 is the index for GLY
    onehot = np.zeros((len(seq), 20))
    columns = [aa_short2aa_idx[code] if code in aa_short2aa_idx else gly_column
               for code in seq]
    onehot[np.arange(len(columns)), np.asarray(columns, dtype=int)] = 1
    return onehot
def get_dihedral_indices(onehot_sequence):
    """Stack the chi atom-index table of every residue in a one-hot matrix.

    Each nonzero entry of ``onehot_sequence`` (scanned row by row) contributes
    ``dihedral_indices[aa_idx2aa_short[column]]``; a row with no 1 contributes
    nothing. Presumably returns shape (N, 4, 4) -- TODO confirm.
    """
    _, residue_columns = np.nonzero(onehot_sequence)
    per_residue = [dihedral_indices[aa_idx2aa_short[column]] for column in residue_columns]
    return np.array(per_residue)
def _get_chi_angles(coords, indices):
    """Gather the four atoms of every chi dihedral and measure the angles.

    Parameters
    ----------
    coords : numpy.ndarray
        Coordinates indexed as ``coords[residue, atom_slot, xyz]``.
    indices : numpy.ndarray
        Atom-slot indices indexed as ``indices[residue, chi, atom]``; NaN marks
        an undefined chi angle.

    Returns
    -------
    numpy.ndarray of shape (N, 4)
        Angles in degrees, NaN wherever an index is NaN or a coordinate is NaN.
    """
    n_residues = coords.shape[0]
    missing = np.isnan(indices)
    # Substitute a valid dummy slot BEFORE the integer cast: casting NaN to int
    # is undefined (NumPy warns "invalid value encountered in cast").
    safe_indices = np.where(missing, 0, indices).astype(int)
    # Fancy indexing returns a copy, so the NaN masking below never touches coords.
    quads = coords[np.arange(n_residues)[:, None, None], safe_indices, :]
    quads[missing] = np.nan
    chi_angles = batch_compute_dihedral_angles(quads.reshape(-1, 4, 3)).reshape(n_residues, 4)
    return chi_angles
def get_chi_angles(coords, seq, return_onehot=False):
    """Compute the chi1-chi4 side-chain dihedral angles of every residue.

    Parameters
    ----------
    coords : numpy.ndarray
        Atom coordinates indexed as ``coords[residue, atom_slot, xyz]``, with
        slots ordered as in ``atom_order`` (e.g. the (N, 14, 3) array returned
        by ``get_coords``).
    seq : sequence of str
        One-letter amino-acid codes, one per residue row of ``coords``. Codes
        missing from ``aa_short2aa_idx`` are treated as glycine.
    return_onehot : bool, optional
        Also return the (N, 20) one-hot encoding of ``seq``, by default False.

    Returns
    -------
    numpy.ndarray of shape (N, 4)
        Chi angles in degrees (mapped into 0-360), in residue order. An angle is
        np.nan when it is undefined for the residue type (e.g. GLY / ALA) or
        when any of its atoms has NaN coordinates (missing atoms).
    numpy.ndarray of shape (N, 20), optional
        The one-hot sequence, returned second only when ``return_onehot`` is True.
    """
    onehot = get_onehot_sequence(seq)
    # Local name differs from the module-level ``dihedral_indices`` table.
    residue_dihedral_indices = get_dihedral_indices(onehot)
    chi_angles = _get_chi_angles(coords, residue_dihedral_indices)
    if return_onehot:
        return chi_angles, onehot
    return chi_angles
def test_get_chi_angles(print_chi_angles=False):
    """Smoke-test ``get_chi_angles`` on chain A of PDB entry 6W70.

    Needs an internet connection (ProDy fetches the entry) or a local copy of
    the file in the working directory. Removes the downloaded
    ``6w70.pdb.gz`` afterwards on a best-effort basis.

    Parameters
    ----------
    print_chi_angles : bool, optional
        Print the computed (N, 4) angle array, by default False.

    Returns
    -------
    bool
        True once all assertions pass.
    """
    pdb = pr.parsePDB('6w70')
    prody_pdb = pdb.select('chain A')
    # get_chi_angles takes coordinates plus a one-letter sequence, not the
    # ProDy selection itself; build both the way get_coords maps residue names.
    coords = get_coords(prody_pdb)
    seq = [aa_long2short[name] if name in aa_long2short else 'X'
           for name in prody_pdb.ca.getResnames()]
    chi_angles = get_chi_angles(coords, seq)
    assert chi_angles.shape == (prody_pdb.ca.numAtoms(), 4)
    assert 55.0 < chi_angles[0, 0] < 56.0
    print('test_get_chi_angles passed')
    try:
        os.remove('6w70.pdb.gz')
    except OSError:
        # Best-effort cleanup: the file may be absent (e.g. a local copy was
        # used) or not removable; neither should fail the test.
        pass
    if print_chi_angles:
        print(chi_angles)
    return True
if __name__ == '__main__':
    # Script entry point: run the self-test and dump the computed angles.
    test_get_chi_angles(print_chi_angles=True)