P2DFlow / data /all_atom.py
Holmes
test
ca7299e
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for calculating all atom representations.
Code adapted from OpenFold.
"""
import torch
from openfold.data import data_transforms
from openfold.np import residue_constants
from openfold.utils import rigid_utils as ru
from data import utils as du
Rigid = ru.Rigid
Rotation = ru.Rotation
# Residue Constants from OpenFold/AlphaFold2.
IDEALIZED_POS = torch.tensor(residue_constants.restype_atom14_rigid_group_positions)
DEFAULT_FRAMES = torch.tensor(residue_constants.restype_rigid_group_default_frame)
ATOM_MASK = torch.tensor(residue_constants.restype_atom14_mask)
GROUP_IDX = torch.tensor(residue_constants.restype_atom14_to_rigid_group)
IDEALIZED_POS_37 = torch.tensor(residue_constants.restype_atom37_rigid_group_positions)
ATOM_MASK_37 = torch.tensor(residue_constants.restype_atom37_mask)
GROUP_IDX_37 = torch.tensor(residue_constants.restype_atom37_to_rigid_group)
def to_atom37(trans, rots, aatype=None, torsions_with_CB=None, get_mask=False):
num_batch, num_res, _ = trans.shape
if torsions_with_CB is None: # (B,L,)
torsions_with_CB = torch.concat(
[torch.zeros((num_batch,num_res,8,1),device=trans.device),
torch.ones((num_batch,num_res,8,1),device=trans.device)],
dim=-1
) # (B,L,8,2)
# final_atom37 = compute_backbone(
# du.create_rigid(rots, trans),
# torch.zeros(num_batch, num_res, 2, device=trans.device)
# )[0]
final_atom37, atom37_mask = compute_atom37_pos(
du.create_rigid(rots, trans),
torsions_with_CB,
aatype=aatype,
)[:2] # (B,L,37,3)
if get_mask:
return final_atom37, atom37_mask
else:
return final_atom37
def torsion_angles_to_frames(
r: Rigid, # type: ignore [valid-type]
alpha: torch.Tensor,
aatype: torch.Tensor,
bb_rot = None
):
"""Conversion method of torsion angles to frames provided the backbone.
Args:
r: Backbone rigid groups.
alpha: Torsion angles. (B,L,7,2)
aatype: residue types.
Returns:
All 8 frames corresponding to each torsion frame.
!!! May need to set omega and fai angle to be zero !!!
"""
# [*, N, 8, 4, 4]
with torch.no_grad():
default_4x4 = DEFAULT_FRAMES.to(aatype.device)[aatype, ...] # type: ignore [attr-defined]
# [*, N, 8] transformations, i.e.
# One [*, N, 8, 3, 3] rotation matrix and
# One [*, N, 8, 3] translation matrix
default_r = r.from_tensor_4x4(default_4x4) # (B,L,8)
if bb_rot is None:
bb_rot = alpha.new_zeros(((1,) * len(alpha.shape[:-1]))+(2,)) # (1,1,1,2)
bb_rot[..., 1] = 1
bb_rot = bb_rot.expand(*alpha.shape[:-2], -1, -1) # (B,L,1,2)
alpha = torch.cat([bb_rot, alpha], dim=-2) # (B,L,8,2)
# [*, N, 8, 3, 3]
# Produces rotation matrices of the form:
# [
# [1, 0 , 0 ],
# [0, a_2,-a_1],
# [0, a_1, a_2]
# ]
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) # (B,L,8,3,3)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
chi4_frame_to_frame = all_frames[..., 7]
chi1_frame_to_bb = all_frames[..., 4]
chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = Rigid.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
chi3_frame_to_bb.unsqueeze(-1),
chi4_frame_to_bb.unsqueeze(-1),
],
dim=-1,
)
all_frames_to_global = r[..., None].compose(all_frames_to_bb) # type: ignore [index]
return all_frames_to_global
def prot_to_torsion_angles(aatype, atom37, atom37_mask):
"""Calculate torsion angle features from protein features."""
prot_feats = {
"aatype": aatype,
"all_atom_positions": atom37,
"all_atom_mask": atom37_mask,
}
torsion_angles_feats = data_transforms.atom37_to_torsion_angles()(prot_feats)
torsion_angles = torsion_angles_feats["torsion_angles_sin_cos"]
torsion_mask = torsion_angles_feats["torsion_angles_mask"]
return torsion_angles, torsion_mask
def frames_to_atom14_pos(
r: Rigid, # type: ignore [valid-type]
aatype: torch.Tensor,
):
"""Convert frames to their idealized all atom representation.
Args:
r: All rigid groups. [B,L,8]
aatype: Residue types. [B,L]
Returns:
"""
with torch.no_grad():
group_mask = GROUP_IDX.to(aatype.device)[aatype, ...] # (B,L,14)
group_mask = torch.nn.functional.one_hot(
group_mask,
num_classes=DEFAULT_FRAMES.shape[-3], # (21,8,4,4)
) # (B,L,14,8)
frame_atom_mask = ATOM_MASK.to(aatype.device)[aatype, ...].unsqueeze(-1) # (B,L,14,1)
frame_null_pos = IDEALIZED_POS.to(aatype.device)[aatype, ...] # (B,L,14,3)
t_atoms_to_global = r[..., None, :] * group_mask # (B,L,14,8)
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) # (B,L,14)
pred_positions = t_atoms_to_global.apply(frame_null_pos) # (B,L,14,3)
pred_positions = pred_positions * frame_atom_mask
return pred_positions
def frames_to_atom37_pos(
r: Rigid, # type: ignore [valid-type]
aatype: torch.Tensor,
):
"""Convert frames to their idealized all atom representation.
Args:
r: All rigid groups. [B,L]
aatype: Residue types. [B,L]
Returns:
"""
with torch.no_grad():
group_mask = GROUP_IDX_37.to(aatype.device)[aatype, ...] # (B,L,37)
group_mask = torch.nn.functional.one_hot(
group_mask,
num_classes=DEFAULT_FRAMES.shape[-3], # (21,8,4,4)
) # (B,L,37,8)
frame_atom_mask = ATOM_MASK_37.to(aatype.device)[aatype, ...].unsqueeze(-1) # (B,L,37,1)
frame_null_pos = IDEALIZED_POS_37.to(aatype.device)[aatype, ...] # (B,L,37,3)
t_atoms_to_global = r[..., None, :] * group_mask # (B,L,37,8)
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) # (B,L,37)
pred_positions = t_atoms_to_global.apply(frame_null_pos) # (B,L,37,3)
pred_positions = pred_positions * frame_atom_mask
return pred_positions, frame_atom_mask[...,0]
#! may need to remove it
def compute_backbone(bb_rigids, psi_torsions, aatype=None):
torsion_angles = torch.tile(
psi_torsions[..., None, :], tuple([1 for _ in range(len(bb_rigids.shape))]) + (7, 1)
)
'''
psi_torsions[..., None, :].shape: (B,L,1,2)
torsion_angles.shape: (B,L,7,2)
bb_rigids.shape: (B,L)
'''
if aatype is None:
aatype = torch.zeros(bb_rigids.shape, device=bb_rigids.device).long()
all_frames = torsion_angles_to_frames(
bb_rigids,
torsion_angles,
aatype,
)
atom14_pos = frames_to_atom14_pos(all_frames, aatype) # (B,L,14,3)
atom37_bb_pos = torch.zeros(bb_rigids.shape + (37, 3), device=bb_rigids.device) # (B,L,37,3)
# atom14 bb order = ['N', 'CA', 'C', 'O', 'CB']
# atom37 bb order = ['N', 'CA', 'C', 'CB', 'O']
atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :]
atom37_mask = torch.any(atom37_bb_pos, axis=-1) # mask atom with all 0 xyz
return atom37_bb_pos, atom37_mask, aatype, atom14_pos
def compute_atom37_pos(bb_rigids, torsions_with_CB, aatype=None):
'''
torsions_with_CB.shape: (B,L,8,2)
bb_rigids.shape: (B,L)
'''
if aatype is None:
aatype = torch.zeros(bb_rigids.shape, device=bb_rigids.device).long()
all_frames = torsion_angles_to_frames(
bb_rigids,
torsions_with_CB[:,:,1:,:],
aatype,
bb_rot = torsions_with_CB[:,:,0:1,:],
)
atom14_pos = frames_to_atom14_pos(all_frames, aatype) # (B,L,14,3)
atom37_pos,atom37_mask = frames_to_atom37_pos(all_frames, aatype) # (B,L,37,3)
return atom37_pos, atom37_mask, aatype, atom14_pos
def calculate_neighbor_angles(R_ac, R_ab):
"""Calculate angles between atoms c <- a -> b.
Parameters
----------
R_ac: Tensor, shape = (N,3)
Vector from atom a to c.
R_ab: Tensor, shape = (N,3)
Vector from atom a to b.
Returns
-------
angle_cab: Tensor, shape = (N,)
Angle between atoms c <- a -> b.
"""
# cos(alpha) = (u * v) / (|u|*|v|)
x = torch.sum(R_ac * R_ab, dim=1) # shape = (N,)
# sin(alpha) = |u x v| / (|u|*|v|)
y = torch.cross(R_ac, R_ab).norm(dim=-1) # shape = (N,)
# avoid that for y == (0,0,0) the gradient wrt. y becomes NaN
y = torch.max(y, torch.tensor(1e-9))
angle = torch.atan2(y, x)
return angle
def vector_projection(R_ab, P_n):
"""
Project the vector R_ab onto a plane with normal vector P_n.
Parameters
----------
R_ab: Tensor, shape = (N,3)
Vector from atom a to b.
P_n: Tensor, shape = (N,3)
Normal vector of a plane onto which to project R_ab.
Returns
-------
R_ab_proj: Tensor, shape = (N,3)
Projected vector (orthogonal to P_n).
"""
a_x_b = torch.sum(R_ab * P_n, dim=-1)
b_x_b = torch.sum(P_n * P_n, dim=-1)
return R_ab - (a_x_b / b_x_b)[:, None] * P_n
def transrot_to_atom37(transrot_traj, res_mask, aatype=None, torsions_with_CB=None):
atom37_traj = []
res_mask = res_mask.detach().cpu()
num_batch = res_mask.shape[0]
for trans, rots in transrot_traj:
atom37 = to_atom37(trans, rots, aatype=aatype, torsions_with_CB=torsions_with_CB,get_mask=False)
atom37 = atom37.detach().cpu()
# batch_atom37 = []
# for i in range(num_batch):
# batch_atom37.append(
# du.adjust_oxygen_pos(atom37[i], res_mask[i])
# )
# atom37_traj.append(torch.stack(batch_atom37))
atom37_traj.append(atom37)
return atom37_traj
# def atom37_from_trans_rot(trans, rots, res_mask):
# rigids = du.create_rigid(rots, trans)
# atom37 = compute_backbone(
# rigids,
# torch.zeros(
# trans.shape[0],
# trans.shape[1],
# 2,
# device=trans.device
# )
# )[0]
# atom37 = atom37.detach().cpu()
# batch_atom37 = []
# num_batch = res_mask.shape[0]
# for i in range(num_batch):
# batch_atom37.append(
# du.adjust_oxygen_pos(atom37[i], res_mask[i])
# )
# return torch.stack(batch_atom37)
# def process_trans_rot_traj(trans_traj, rots_traj, res_mask):
# res_mask = res_mask.detach().cpu()
# atom37_traj = [
# atom37_from_trans_rot(trans, rots, res_mask)
# for trans, rots in zip(trans_traj, rots_traj)
# ]
# atom37_traj = torch.stack(atom37_traj).swapaxes(0, 1)
# return atom37_traj