Spaces:
Sleeping
Sleeping
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from typing import Dict, Union | |
from dockformerpp.utils import protein | |
import dockformerpp.utils.residue_constants as rc | |
from dockformerpp.utils.geometry import rigid_matrix_vector, rotation_matrix, vector | |
from dockformerpp.utils.rigid_utils import Rotation, Rigid | |
from dockformerpp.utils.tensor_utils import ( | |
batched_gather, | |
one_hot, | |
tree_map, | |
tensor_tree_map, | |
) | |
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    """Pick one representative ("pseudo-beta") atom per residue.

    CB is used for every residue type except glycine
    (``rc.restype_order["G"]``) and ``rc.restype_order["Z"]``, which fall
    back to CA.

    Args:
        aatype: integer residue-type tensor compared against
            ``rc.restype_order`` entries.
        all_atom_positions: coordinates with the atom axis second-to-last
            (indexed by ``rc.atom_order``) and 3 coordinates in the last axis.
        all_atom_masks: per-atom mask with the atom axis last, or ``None``.

    Returns:
        ``(pseudo_beta, pseudo_beta_mask)`` when ``all_atom_masks`` is not
        ``None``; otherwise ``pseudo_beta`` alone.
    """
    # NOTE(review): an earlier comment said rc.restype_order["X"] marks a
    # ligand, but the test below uses "Z"; presumably "Z" is the ligand
    # residue type in this project's residue_constants -- confirm.
    use_ca = (
        (aatype == rc.restype_order["G"])
        | (aatype == rc.restype_order["Z"])
    )
    ca_idx, cb_idx = rc.atom_order["CA"], rc.atom_order["CB"]

    # Append an axis of size 3 so the condition lines up with the xyz axis.
    coord_cond = use_ca[..., None].expand(*use_ca.shape, 3)
    pseudo_beta = torch.where(
        coord_cond,
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_masks is None:
        return pseudo_beta

    pseudo_beta_mask = torch.where(
        use_ca,
        all_atom_masks[..., ca_idx],
        all_atom_masks[..., cb_idx],
    )
    return pseudo_beta, pseudo_beta_mask
def atom14_to_atom37(atom14, batch):
    """Re-index per-residue atom14 data into the atom37 layout.

    Args:
        atom14: tensor whose second-to-last axis holds the atom14 slots;
            every axis before the last two is treated as a batch axis by
            ``batched_gather``.
        batch: feature dict that must provide ``"residx_atom37_to_atom14"``
            (gather indices along the atom axis) and ``"atom37_atom_exists"``
            (a mask that is multiplied into the result).

    Returns:
        The gathered data multiplied by the atom37 existence mask, so slots
        where the mask is 0 come out as 0.
    """
    leading_dims = len(atom14.shape[:-2])
    gathered = batched_gather(
        atom14,
        batch["residx_atom37_to_atom14"],
        dim=-2,
        no_batch_dims=leading_dims,
    )
    # Trailing singleton axis lets the mask broadcast over the feature axis.
    exists = batch["atom37_atom_exists"][..., None]
    return gathered * exists
def torsion_angles_to_frames(
    r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    rrgdf: torch.Tensor,
):
    """Build the 8 rigid-group frames of each residue in the global frame.

    Args:
        r: per-residue backbone frames. Its concrete type (``Rigid`` or
            ``Rigid3Array``) is reused for every frame built here.
        alpha: torsion angles with a 2-vector in the last axis; after a
            backbone entry is prepended the second-to-last axis must match
            the 8 default frames (presumably [*, N, 7, 2] -- TODO confirm).
        aatype: residue-type indices used to look up default frames.
        rrgdf: default rigid-group frames as 4x4 matrices, indexed by
            ``aatype`` to give [*, N, 8, 4, 4].

    Returns:
        The 8 per-residue group frames composed onto ``r``.
    """
    frame_type = type(r)

    # [*, N, 8] default frames: one rotation + translation per rigid group.
    default_r = frame_type.from_tensor_4x4(rrgdf[aatype, ...])

    # Prepend an identity angle (components [0, 1]) for the backbone group so
    # alpha lines up with the default frames: [*, N, 8, 2].
    identity_angle = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    identity_angle[..., 1] = 1
    alpha = torch.cat(
        [identity_angle.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
    )

    # Per-group rotation about the x axis, stored in a 4x4 block:
    #   [[1, 0,        0       ],
    #    [0, alpha[1], -alpha[0]],
    #    [0, alpha[0],  alpha[1]]]
    # This follows the original code rather than the supplement, which uses
    # different indices. All translation entries remain zero.
    rot_4x4 = alpha.new_zeros(default_r.shape + (4, 4))
    rot_4x4[..., 0, 0] = 1
    rot_4x4[..., 1, 1] = alpha[..., 1]
    rot_4x4[..., 1, 2] = -alpha[..., 0]
    rot_4x4[..., 2, 1:3] = alpha
    torsion_rots = frame_type.from_tensor_4x4(rot_4x4)

    all_frames = default_r.compose(torsion_rots)

    # Groups 0-4 are used as-is relative to the backbone. Group 4 (chi1) is
    # relative to the backbone, and groups 5-7 (chi2..chi4) are each
    # relative to the previous chi frame, so compose them cumulatively.
    chi_chain = [all_frames[..., 4]]
    for group in (5, 6, 7):
        chi_chain.append(chi_chain[-1].compose(all_frames[..., group]))

    all_frames_to_bb = frame_type.cat(
        [all_frames[..., :5]] + [f.unsqueeze(-1) for f in chi_chain[1:]],
        dim=-1,
    )

    return r[..., None].compose(all_frames_to_bb)
def frames_and_literature_positions_to_atom14_pos(
    r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
    aatype: torch.Tensor,
    default_frames,
    group_idx,
    atom_mask,
    lit_positions,
):
    """Place each residue's atom14 atoms using its rigid-group frames.

    Each atom14 slot belongs to one rigid group (looked up via ``group_idx``).
    That group's frame is selected from ``r`` and applied to the slot's
    idealized ("literature") local position. The result is then masked.

    Args:
        r: per-residue rigid-group frames with the group axis last
            (presumably [*, N, 8] -- TODO confirm).
        aatype: residue-type indices used to look up the per-type tables.
        default_frames: per-residue-type default frames. Only
            ``shape[-3]`` (the number of rigid groups) is used here.
        group_idx: per-residue-type table mapping each atom14 slot to a
            rigid-group index.
        atom_mask: per-residue-type atom14 mask.
        lit_positions: per-residue-type local atom14 coordinates.

    Returns:
        Global atom14 positions of shape [*, N, 14, 3], multiplied by the
        atom mask.
    """
    # Note: the original gathered default_frames[aatype, ...] here but never
    # used it. That was a wasted [*, N, 8, 4, 4] allocation, so it is removed.

    # [*, N, 14] rigid-group index of every atom slot.
    group_mask = group_idx[aatype, ...]

    # [*, N, 14, G] one-hot over the G rigid groups.
    group_mask = nn.functional.one_hot(
        group_mask,
        num_classes=default_frames.shape[-3],
    )

    # [*, N, 14, G]: broadcast each residue's frames to every atom slot and
    # multiply by the one-hot mask. Summing over the group axis below then
    # picks the frame of the atom's own group (relies on the rigid type's
    # elementwise `*` and `map_tensor_fn`).
    t_atoms_to_global = r[..., None, :] * group_mask

    # [*, N, 14]
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
        lambda x: torch.sum(x, dim=-1)
    )

    # [*, N, 14, 1] so the mask broadcasts over the xyz axis.
    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

    # [*, N, 14, 3]
    lit_positions = lit_positions[aatype, ...]
    pred_positions = t_atoms_to_global.apply(lit_positions)
    return pred_positions * atom_mask