|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Utilities for calculating all atom representations. |
|
Code adapted from OpenFold. |
|
""" |
|
|
|
import torch |
|
from openfold.data import data_transforms |
|
from openfold.np import residue_constants |
|
from openfold.utils import rigid_utils as ru |
|
from data import utils as du |
|
|
|
# Short aliases for the OpenFold rigid-transform classes used in this module.
Rigid = ru.Rigid
Rotation = ru.Rotation

# Per-residue-type lookup tables taken from OpenFold's residue_constants and
# converted to tensors once at import time. The functions in this module index
# them with `aatype`.
# NOTE(review): sizes follow residue_constants (presumably 8 rigid groups and
# 14 / 37 atom slots per residue type) -- confirm against OpenFold.

# atom14 layout.
IDEALIZED_POS = torch.tensor(
    residue_constants.restype_atom14_rigid_group_positions
)
DEFAULT_FRAMES = torch.tensor(
    residue_constants.restype_rigid_group_default_frame
)
ATOM_MASK = torch.tensor(
    residue_constants.restype_atom14_mask
)
GROUP_IDX = torch.tensor(
    residue_constants.restype_atom14_to_rigid_group
)

# atom37 layout.
IDEALIZED_POS_37 = torch.tensor(
    residue_constants.restype_atom37_rigid_group_positions
)
ATOM_MASK_37 = torch.tensor(
    residue_constants.restype_atom37_mask
)
GROUP_IDX_37 = torch.tensor(
    residue_constants.restype_atom37_to_rigid_group
)
|
|
|
def to_atom37(trans, rots, aatype=None, torsions_with_CB=None, get_mask=False):
    """Build atom37 coordinates from backbone translations and rotations.

    Args:
        trans: Translations. Must be a 3-D tensor; its first two sizes are
            read as (num_batch, num_res). Presumably [B, L, 3] -- confirm.
        rots: Rotations, handed unchanged to ``du.create_rigid``.
        aatype: Optional residue types forwarded to ``compute_atom37_pos``.
        torsions_with_CB: Optional torsion pairs forwarded to
            ``compute_atom37_pos``. When omitted, a [B, L, 8, 2] tensor whose
            pairs are all (0, 1) is used.
        get_mask: When true, also return the atom37 mask.

    Returns:
        The atom37 positions, or ``(positions, mask)`` when ``get_mask`` is
        set.
    """
    # The unpacking also enforces that `trans` has exactly three dimensions.
    num_batch, num_res, _ = trans.shape

    if torsions_with_CB is None:
        # Default torsions: channel 0 is 0 and channel 1 is 1 for all 8 slots.
        torsions_with_CB = torch.zeros(
            (num_batch, num_res, 8, 2), device=trans.device
        )
        torsions_with_CB[..., 1] = 1

    outputs = compute_atom37_pos(
        du.create_rigid(rots, trans),
        torsions_with_CB,
        aatype=aatype,
    )
    # Only the first two outputs (positions and mask) are needed here.
    positions, mask = outputs[:2]

    return (positions, mask) if get_mask else positions
|
|
|
def torsion_angles_to_frames(
    r: Rigid,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    bb_rot=None,
):
    """Convert torsion angles into rigid-group frames, given the backbone.

    Args:
        r: Backbone rigid groups.
        alpha: Torsion angles as (sin, cos) pairs. (B,L,7,2)
        aatype: Residue types, used to look up the default frames.
        bb_rot: Optional (sin, cos) pair for the first frame slot. It is
            concatenated in front of ``alpha`` along dim -2, so it must be
            shaped accordingly (e.g. (B,L,1,2)). When omitted, (0, 1) is
            used, which leaves that slot unrotated.

    Returns:
        All 8 frames of each residue, composed with ``r``.

    Note:
        The original author flagged that the omega and phi ("fai") angles
        may need to be set to zero.
    """
    # Constant table lookup of each residue's default 4x4 frames.
    with torch.no_grad():
        default_frames_4x4 = DEFAULT_FRAMES.to(aatype.device)[aatype, ...]

    default_rigid = r.from_tensor_4x4(default_frames_4x4)

    if bb_rot is None:
        # Build a single (0, 1) pair and broadcast it over alpha's leading dims.
        singleton_dims = (1,) * len(alpha.shape[:-1])
        bb_rot = alpha.new_zeros(singleton_dims + (2,))
        bb_rot[..., 1] = 1
        bb_rot = bb_rot.expand(*alpha.shape[:-2], -1, -1)

    # Put the extra pair in front of the supplied torsions.
    alpha = torch.cat([bb_rot, alpha], dim=-2)
    sin_alpha = alpha[..., 0]
    cos_alpha = alpha[..., 1]

    # One rotation about the x axis per frame:
    #   [[1,   0,    0],
    #    [0, cos, -sin],
    #    [0, sin,  cos]]
    rot_mats_shape = default_rigid.get_rots().get_rot_mats().shape
    torsion_rot_mats = alpha.new_zeros(rot_mats_shape)
    torsion_rot_mats[..., 0, 0] = 1
    torsion_rot_mats[..., 1, 1] = cos_alpha
    torsion_rot_mats[..., 1, 2] = -sin_alpha
    torsion_rot_mats[..., 2, 1:] = alpha

    torsion_rigid = Rigid(Rotation(rot_mats=torsion_rot_mats), None)
    all_frames = default_rigid.compose(torsion_rigid)

    # Slots 5-7 (chi2-chi4 in the original naming) are each composed onto the
    # result for the preceding slot, starting from slot 4 (chi1), so that every
    # frame ends up expressed relative to the backbone.
    frame_to_bb = all_frames[..., 4]
    chained_frames = []
    for slot in (5, 6, 7):
        frame_to_bb = frame_to_bb.compose(all_frames[..., slot])
        chained_frames.append(frame_to_bb.unsqueeze(-1))

    all_frames_to_bb = Rigid.cat(
        [all_frames[..., :5], *chained_frames],
        dim=-1,
    )

    # Finally move everything from the backbone frame to the global frame.
    return r[..., None].compose(all_frames_to_bb)
|
|
|
|
|
def prot_to_torsion_angles(aatype, atom37, atom37_mask):
    """Compute torsion-angle features from atom37 protein features.

    Args:
        aatype: Residue types, passed as the ``"aatype"`` feature.
        atom37: Atom positions, passed as ``"all_atom_positions"``.
        atom37_mask: Atom mask, passed as ``"all_atom_mask"``.

    Returns:
        Tuple ``(torsion_angles, torsion_mask)`` read from the
        ``"torsion_angles_sin_cos"`` and ``"torsion_angles_mask"`` entries
        produced by ``data_transforms.atom37_to_torsion_angles``.
    """
    to_torsions = data_transforms.atom37_to_torsion_angles()
    feats = to_torsions(
        {
            "aatype": aatype,
            "all_atom_positions": atom37,
            "all_atom_mask": atom37_mask,
        }
    )
    return feats["torsion_angles_sin_cos"], feats["torsion_angles_mask"]
|
|
|
|
|
def frames_to_atom14_pos(
    r: Rigid,
    aatype: torch.Tensor,
):
    """Place idealized atom14 coordinates using each residue's rigid frames.

    Args:
        r: All rigid groups. [B,L,8]
        aatype: Residue types. [B,L]

    Returns:
        Atom positions in the frame ``r`` maps into, multiplied by the
        per-atom mask for each residue type. Presumably [B,L,14,3] --
        confirm against residue_constants.
    """
    device = aatype.device
    num_groups = DEFAULT_FRAMES.shape[-3]

    # Constant table lookups keyed by residue type.
    with torch.no_grad():
        atom_group_idx = GROUP_IDX.to(device)[aatype, ...]
        atom_group_onehot = torch.nn.functional.one_hot(
            atom_group_idx,
            num_classes=num_groups,
        )
        atom_exists = ATOM_MASK.to(device)[aatype, ...].unsqueeze(-1)
        ideal_local_pos = IDEALIZED_POS.to(device)[aatype, ...]

    # Weight every frame by the one-hot group assignment of every atom slot,
    # then sum over the group axis so each atom keeps only its own frame.
    per_atom_frames = r[..., None, :] * atom_group_onehot
    per_atom_frames = per_atom_frames.map_tensor_fn(lambda t: t.sum(dim=-1))

    positions = per_atom_frames.apply(ideal_local_pos)
    return positions * atom_exists
|
|
|
def frames_to_atom37_pos(
    r: Rigid,
    aatype: torch.Tensor,
):
    """Place idealized atom37 coordinates using each residue's rigid frames.

    Args:
        r: All rigid groups. [B,L,8]
        aatype: Residue types. [B,L]

    Returns:
        Tuple ``(positions, mask)``: the atom positions multiplied by the
        per-atom mask, and that mask with its trailing singleton axis
        removed. Presumably [B,L,37,3] and [B,L,37] -- confirm against
        residue_constants.
    """
    target_device = aatype.device

    # Constant table lookups keyed by residue type.
    with torch.no_grad():
        group_of_atom = GROUP_IDX_37.to(target_device)[aatype, ...]
        group_onehot = torch.nn.functional.one_hot(
            group_of_atom,
            num_classes=DEFAULT_FRAMES.shape[-3],
        )
        atom_mask = ATOM_MASK_37.to(target_device)[aatype, ...].unsqueeze(-1)
        local_positions = IDEALIZED_POS_37.to(target_device)[aatype, ...]

    # Select, for every atom slot, the frame of the rigid group it belongs to:
    # multiply by the one-hot assignment and sum over the group axis.
    frames_per_atom = (r[..., None, :] * group_onehot).map_tensor_fn(
        lambda t: t.sum(dim=-1)
    )

    masked_positions = frames_per_atom.apply(local_positions) * atom_mask
    return masked_positions, atom_mask[..., 0]
|
|
|
|
|
def compute_backbone(bb_rigids, psi_torsions, aatype=None):
    """Compute backbone atom positions from backbone frames and psi torsions.

    Args:
        bb_rigids: Backbone rigid frames, shape (B,L).
        psi_torsions: Psi torsion pairs, shape (B,L,2). The same pair is
            replicated into all 7 torsion slots.
        aatype: Optional residue types. Defaults to all zeros with the shape
            of ``bb_rigids``.

    Returns:
        Tuple ``(atom37_bb_pos, atom37_mask, aatype, atom14_pos)``.
        ``atom37_bb_pos`` holds only the first three atom14 atoms; every
        other slot is zero. ``atom37_mask`` is true where an atom has any
        non-zero coordinate.
    """
    # psi_torsions[..., None, :] is (B,L,1,2); tiling yields (B,L,7,2).
    leading_repeats = (1,) * len(bb_rigids.shape)
    torsion_angles = torch.tile(
        psi_torsions[..., None, :],
        leading_repeats + (7, 1),
    )

    if aatype is None:
        aatype = torch.zeros(
            bb_rigids.shape, dtype=torch.long, device=bb_rigids.device
        )

    all_frames = torsion_angles_to_frames(bb_rigids, torsion_angles, aatype)
    atom14_pos = frames_to_atom14_pos(all_frames, aatype)

    atom37_bb_pos = torch.zeros(
        bb_rigids.shape + (37, 3), device=bb_rigids.device
    )
    # Copy the first three atom slots; the remaining slots stay zero.
    atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :]
    # The mask is derived from the coordinates themselves.
    atom37_mask = torch.any(atom37_bb_pos, dim=-1)

    return atom37_bb_pos, atom37_mask, aatype, atom14_pos
|
|
|
def compute_atom37_pos(bb_rigids, torsions_with_CB, aatype=None):
    """Compute atom37 and atom14 positions from frames and torsions.

    Args:
        bb_rigids: Backbone rigid frames, shape (B,L).
        torsions_with_CB: Torsion pairs, shape (B,L,8,2). Slot 0 is passed
            to ``torsion_angles_to_frames`` as ``bb_rot``; slots 1-7 are
            passed as the torsion angles.
        aatype: Optional residue types. Defaults to all zeros with the shape
            of ``bb_rigids``.

    Returns:
        Tuple ``(atom37_pos, atom37_mask, aatype, atom14_pos)``.
    """
    if aatype is None:
        aatype = torch.zeros(
            bb_rigids.shape, dtype=torch.long, device=bb_rigids.device
        )

    # Split the first slot from the remaining seven.
    first_slot = torsions_with_CB[:, :, 0:1, :]
    remaining_slots = torsions_with_CB[:, :, 1:, :]

    all_frames = torsion_angles_to_frames(
        bb_rigids,
        remaining_slots,
        aatype,
        bb_rot=first_slot,
    )

    atom14_pos = frames_to_atom14_pos(all_frames, aatype)
    atom37_pos, atom37_mask = frames_to_atom37_pos(all_frames, aatype)

    return atom37_pos, atom37_mask, aatype, atom14_pos
|
|
|
def calculate_neighbor_angles(R_ac, R_ab):
    """Calculate angles between atoms c <- a -> b.

    Parameters
    ----------
        R_ac: Tensor, shape = (N,3)
            Vector from atom a to c. Extra leading batch dimensions are
            also accepted.
        R_ab: Tensor, shape = (N,3)
            Vector from atom a to b. Same shape as ``R_ac``.

    Returns
    -------
        angle_cab: Tensor, shape = (N,)
            Angle between atoms c <- a -> b, in radians within [0, pi].
    """
    # Dot product: |R_ac| |R_ab| cos(angle).
    x = torch.sum(R_ac * R_ab, dim=-1)

    # Cross-product norm: |R_ac| |R_ab| sin(angle). `dim` must be explicit:
    # without it torch.cross picks the first dimension of size 3, which is the
    # wrong axis whenever N == 3.
    y = torch.cross(R_ac, R_ab, dim=-1).norm(dim=-1)

    # Keep y strictly positive so atan2 is never evaluated at (0, 0), where
    # its gradient is undefined. clamp works on y's own device and dtype.
    y = y.clamp(min=1e-9)

    return torch.atan2(y, x)
|
|
|
|
|
def vector_projection(R_ab, P_n):
    """
    Project the vector R_ab onto a plane with normal vector P_n.

    Parameters
    ----------
        R_ab: Tensor, shape = (N,3)
            Vector from atom a to b. Extra leading batch dimensions are
            also accepted.
        P_n: Tensor, shape = (N,3)
            Normal vector of a plane onto which to project R_ab. It does not
            need to be normalized, but a zero-length normal yields non-finite
            values because its squared norm is used as a divisor.

    Returns
    -------
        R_ab_proj: Tensor, shape = (N,3)
            Projected vector (orthogonal to P_n).
    """
    # keepdim=True keeps a trailing axis of size 1 so the coefficient
    # broadcasts against P_n for any number of leading dimensions.
    dot_ab_n = torch.sum(R_ab * P_n, dim=-1, keepdim=True)
    dot_n_n = torch.sum(P_n * P_n, dim=-1, keepdim=True)

    # Remove the component of R_ab that lies along the normal.
    return R_ab - (dot_ab_n / dot_n_n) * P_n
|
|
|
|
|
def transrot_to_atom37(transrot_traj, res_mask, aatype=None, torsions_with_CB=None):
    """Convert a trajectory of (translation, rotation) pairs to atom37 tensors.

    Args:
        transrot_traj: Iterable of ``(trans, rots)`` pairs, one per step.
        res_mask: Unused. Kept only so existing callers keep working; it is
            no longer copied to the CPU.
        aatype: Optional residue types, forwarded to ``to_atom37`` for every
            step.
        torsions_with_CB: Optional torsions, forwarded to ``to_atom37`` for
            every step.

    Returns:
        List with one detached CPU atom37 tensor per trajectory step, in
        trajectory order.
    """
    return [
        to_atom37(
            trans,
            rots,
            aatype=aatype,
            torsions_with_CB=torsions_with_CB,
            get_mask=False,
        ).detach().cpu()
        for trans, rots in transrot_traj
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|