Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
from diffab.utils.protein.constants import ( | |
BBHeavyAtom, | |
backbone_atom_coordinates_tensor, | |
bb_oxygen_coordinate_tensor, | |
) | |
from .topology import get_terminus_flag | |
def safe_norm(x, dim=-1, keepdim=False, eps=1e-8, sqrt=True):
    """
    Description:
        Euclidean norm along `dim` whose squared value is floored at `eps`,
        so the result (and the square root applied to it) never sees an exact zero.
    Args:
        x:       Input tensor.
        dim:     Dimension reduced by the sum of squares.
        keepdim: Whether the reduced dimension is kept with size 1.
        eps:     Lower bound applied to the *squared* norm, i.e. before the square root.
        sqrt:    If falsy, the clamped squared norm is returned instead of the norm.
    Returns:
        The clamped norm (or clamped squared norm) of `x` along `dim`.
    """
    squared = x.square().sum(dim=dim, keepdim=keepdim).clamp(min=eps)
    if not sqrt:
        return squared
    return squared.sqrt()
def pairwise_distances(x, y=None, return_v=False):
    """
    Description:
        Distances between every point of `x` and every point of `y`,
        computed with `safe_norm` on the difference vectors.
    Args:
        x: (B, N, d)
        y: (B, M, d). When None, `x` is compared against itself.
        return_v: If true, the difference vectors are returned as well.
    Returns:
        d: (B, N, M) distances; or the pair (d, v) with v = x_i - y_j of
           shape (B, N, M, d) when `return_v` is true.
    """
    other = x if y is None else y
    diff = x[:, :, None] - other[:, None]    # (B, N, 1, d) - (B, 1, M, d) -> (B, N, M, d)
    dist = safe_norm(diff, dim=-1)
    return (dist, diff) if return_v else dist
def normalize_vector(v, dim, eps=1e-6):
    """
    Description:
        Scale `v` to (approximately) unit L2 length along `dim`.
        `eps` is added to the norm, so a zero vector maps to zero instead of NaN.
    Args:
        v:   Input tensor.
        dim: Dimension holding the vector components.
        eps: Constant added to the norm in the denominator.
    Returns:
        Tensor of the same shape as `v`.
    """
    length = torch.linalg.norm(v, ord=2, dim=dim, keepdim=True)
    return v / (length + eps)
def project_v2v(v, e, dim):
    """
    Description:
        Project vector `v` onto vector `e`, computed as (e . v) * e.
        No normalisation of `e` happens here, so this is the geometric
        projection only when `e` already has unit length along `dim`.
    Args:
        v:   (N, L, 3).
        e:   (N, L, 3).
        dim: Dimension holding the vector components.
    Returns:
        Tensor with the broadcast shape of `v` and `e`.
    """
    coefficient = torch.sum(v * e, dim=dim, keepdim=True)
    return coefficient * e
def construct_3d_basis(center, p1, p2):
    """
    Description:
        Build a local frame by Gram-Schmidt: the first axis points from
        `center` to `p1`, the second is the part of (`p2` - `center`)
        orthogonal to the first, and the third is their cross product.
    Args:
        center: (N, L, 3), usually the position of C_alpha.
        p1:     (N, L, 3), usually the position of C.
        p2:     (N, L, 3), usually the position of N.
    Returns:
        A batch of orthogonal basis matrices, (N, L, 3, 3cols_index).
        The matrix is composed of 3 column vectors: [e1, e2, e3].
    """
    first_axis = normalize_vector(p1 - center, dim=-1)    # (N, L, 3)

    to_p2 = p2 - center    # (N, L, 3)
    orthogonal_part = to_p2 - project_v2v(to_p2, first_axis, dim=-1)
    second_axis = normalize_vector(orthogonal_part, dim=-1)

    third_axis = torch.cross(first_axis, second_axis, dim=-1)    # (N, L, 3)

    # Stacking on a new trailing dimension makes each axis a column.
    return torch.stack([first_axis, second_axis, third_axis], dim=-1)    # (N, L, 3, 3_index)
def local_to_global(R, t, p):
    """
    Description:
        Convert local (internal) coordinates to global (external) coordinates q.
        q <- Rp + t
    Args:
        R:  (N, L, 3, 3).
        t:  (N, L, 3).
        p:  Local coordinates, (N, L, ..., 3). May be non-contiguous.
    Returns:
        q:  Global coordinates, (N, L, ..., 3).
    """
    assert p.size(-1) == 3
    p_size = p.size()
    N, L = p_size[0], p_size[1]

    # `reshape` rather than `view`: a non-contiguous `p` (e.g. a transposed
    # tensor) cannot always be viewed as (N, L, -1, 3), and `global_to_local`
    # already flattens its input this way.
    p = p.reshape(N, L, -1, 3).transpose(-1, -2)    # (N, L, *, 3) -> (N, L, 3, *)
    q = torch.matmul(R, p) + t.unsqueeze(-1)        # (N, L, 3, *)
    q = q.transpose(-1, -2).reshape(p_size)         # (N, L, 3, *) -> (N, L, *, 3) -> (N, L, ..., 3)
    return q
def global_to_local(R, t, q):
    """
    Description:
        Convert global (external) coordinates q to local (internal) coordinates p.
        p <- R^{T}(q - t)
    Args:
        R:  (N, L, 3, 3).
        t:  (N, L, 3).
        q:  Global coordinates, (N, L, ..., 3).
    Returns:
        p:  Local coordinates, (N, L, ..., 3).
    """
    assert q.size(-1) == 3
    original_shape = q.size()
    batch, length = original_shape[0], original_shape[1]

    # Collapse every dimension between L and the xyz axis, then put xyz
    # second-to-last so the points become the columns of a matrix.
    columns = q.reshape(batch, length, -1, 3).transpose(-1, -2)    # (N, L, 3, *)
    centered = columns - t.unsqueeze(-1)                           # (N, L, 3, *)
    local = torch.matmul(R.transpose(-1, -2), centered)            # (N, L, 3, *)

    # Undo the transpose and restore the caller's original shape.
    return local.transpose(-1, -2).reshape(original_shape)
def apply_rotation_to_vector(R, p):
    """
    Description:
        Rotate vectors `p` by `R` without any translation, by calling
        `local_to_global` with a zero translation.
    Args:
        R:  (N, L, 3, 3).
        p:  Vectors, (N, L, ..., 3).
    Returns:
        Rotated vectors, same shape as `p`.
    """
    # `local_to_global` expects the translation to be (N, L, 3). Using
    # `zeros_like(p)` only has that shape when `p` itself is (N, L, 3) and
    # breaks broadcasting for inputs with extra dimensions, so the zero
    # translation is built explicitly (same dtype and device as `p`).
    zero_translation = p.new_zeros(p.size(0), p.size(1), 3)
    return local_to_global(R, zero_translation, p)
def compose_rotation_and_translation(R1, t1, R2, t2):
    """
    Description:
        Compose two rigid transforms, applying (R2, t2) inside the frame (R1, t1).
    Args:
        R1,t1:  Frame basis and coordinate, (N, L, 3, 3), (N, L, 3).
        R2,t2:  Rotation and translation to be applied to (R1, t1), (N, L, 3, 3), (N, L, 3).
    Returns
        R_new <- R1R2
        t_new <- R1t2 + t1
    """
    rotation = R1 @ R2    # (N, L, 3, 3)

    # Treat t2 as a column vector for the matrix product, then drop that axis again.
    rotated_offset = (R1 @ t2.unsqueeze(-1)).squeeze(-1)    # (N, L, 3)
    translation = rotated_offset + t1

    return rotation, translation
def compose_chain(Ts):
    """
    Description:
        Compose a chain of rigid transforms [(R_0, t_0), ..., (R_k, t_k)] into
        a single (R, t), folding from the right:
        T_0 * (T_1 * (... * T_k)), with each step done by
        `compose_rotation_and_translation`.
    Args:
        Ts: Non-empty sequence (list, tuple, ...) of (R, t) pairs.
    Returns:
        The composed (R, t) pair. A single-element chain is returned as is.
    Raises:
        IndexError: If `Ts` is empty.
    """
    # Copy into a list so that any sequence type is accepted and the caller's
    # container is never touched; one pass avoids rebuilding the list per step.
    Ts = list(Ts)
    composed = Ts[-1]
    for R_outer, t_outer in reversed(Ts[:-1]):
        composed = compose_rotation_and_translation(R_outer, t_outer, *composed)
    return composed
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
def quaternion_to_rotation_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.
    The input is L2-normalised along the last dimension first, so
    non-unit quaternions are accepted.
    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    quaternions = F.normalize(quaternions, dim=-1)
    w, x, y, z = quaternions.unbind(dim=-1)
    scale = 2.0 / quaternions.square().sum(dim=-1)

    # Nine entries in row-major order; reshaped to (..., 3, 3) below.
    entries = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )
    flat = torch.stack(entries, dim=-1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
BSD License | |
For PyTorch3D software | |
Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. | |
Redistribution and use in source and binary forms, with or without modification, | |
are permitted provided that the following conditions are met: | |
* Redistributions of source code must retain the above copyright notice, this | |
list of conditions and the following disclaimer. | |
* Redistributions in binary form must reproduce the above copyright notice, | |
this list of conditions and the following disclaimer in the documentation | |
and/or other materials provided with the distribution. | |
* Neither the name Meta nor the names of its contributors may be used to | |
endorse or promote products derived from this software without specific | |
prior written permission. | |
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR | |
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | |
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | |
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON | |
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | |
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
""" | |
def quaternion_1ijk_to_rotation_matrix(q):
    """
    Description:
        Convert the quaternion (1 + bi + cj + dk), given only by its three
        imaginary components, to a rotation matrix. The real part is fixed
        to 1 and the whole quaternion is normalised to unit length first.
    Args:
        q:  Imaginary components (b, c, d), (..., 3).
    Returns:
        Rotation matrices, (..., 3, 3).
    """
    b, c, d = q.unbind(dim=-1)
    norm = torch.sqrt(1 + b**2 + c**2 + d**2)
    a, b, c, d = 1/norm, b/norm, c/norm, d/norm

    first_row = torch.stack(
        (a**2 + b**2 - c**2 - d**2, 2*b*c - 2*a*d, 2*b*d + 2*a*c), dim=-1,
    )
    second_row = torch.stack(
        (2*b*c + 2*a*d, a**2 - b**2 + c**2 - d**2, 2*c*d - 2*a*b), dim=-1,
    )
    third_row = torch.stack(
        (2*b*d - 2*a*c, 2*c*d + 2*a*b, a**2 - b**2 - c**2 + d**2), dim=-1,
    )
    # Rows go on the second-to-last axis, giving (..., 3, 3).
    return torch.stack((first_row, second_row, third_row), dim=-2)
def repr_6d_to_rotation_matrix(x):
    """
    Description:
        Turn a 6D rotation representation (two stacked 3-vectors) into a
        matrix by Gram-Schmidt: normalise the first vector, orthogonalise
        and normalise the second against it, and take their cross product
        as the third column.
    Args:
        x:  6D representations, (..., 6).
    Returns:
        Rotation matrices, (..., 3, 3_index).
    """
    first_raw = x[..., 0:3]
    second_raw = x[..., 3:6]

    col1 = normalize_vector(first_raw, dim=-1)
    residual = second_raw - project_v2v(second_raw, col1, dim=-1)
    col2 = normalize_vector(residual, dim=-1)
    col3 = torch.cross(col1, col2, dim=-1)

    # A new trailing dimension makes the three vectors the matrix columns.
    return torch.stack([col1, col2, col3], dim=-1)    # (..., 3, 3_index)
def dihedral_from_four_points(p0, p1, p2, p3):
    """
    Description:
        Signed dihedral angle defined by four points, about the p1 -> p2 axis.
        The cosine is clamped to [-0.999999, 0.999999] before `acos`, and any
        NaN (e.g. from coincident points giving a zero-length normal) is
        replaced via `nan_to_num`. The sign factor is 0 when the triple product
        is exactly zero, so such exactly planar inputs yield 0.
    Args:
        p0-3:   (*, 3).
    Returns:
        Dihedral angles in radian, (*, ).
    """
    axis = p2 - p1
    arm_start = p0 - p1
    arm_end = p3 - p2

    # Unit normals of the two planes that share the central axis.
    normal_start = torch.cross(axis, arm_start, dim=-1)
    normal_start = normal_start / torch.linalg.norm(normal_start, dim=-1, keepdim=True)
    normal_end = torch.cross(axis, arm_end, dim=-1)
    normal_end = normal_end / torch.linalg.norm(normal_end, dim=-1, keepdim=True)

    cosine = torch.sum(normal_start * normal_end, dim=-1).clamp(min=-0.999999, max=0.999999)

    # The sign of the scalar triple product gives the angle's sign.
    triple_product = torch.sum(torch.cross(arm_start, arm_end, dim=-1) * axis, dim=-1)
    handedness = torch.sign(triple_product)

    return torch.nan_to_num(handedness * torch.acos(cosine))
def knn_gather(idx, value):
    """
    Description:
        For every query row, pick the rows of `value` selected by `idx`.
    Args:
        idx:    (B, N, K), indices into the M axis of `value`.
        value:  (B, M, d)
    Returns:
        (B, N, K, d)
    """
    N, d = idx.size(1), value.size(-1)
    # `expand` creates broadcast views instead of the (B, N, K, d) and
    # (B, N, M, d) copies that `repeat` would allocate; `torch.gather`
    # only reads from both, so the result is identical.
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, d)        # (B, N, K, d)
    value = value.unsqueeze(1).expand(-1, N, -1, -1)     # (B, N, M, d)
    return torch.gather(value, dim=2, index=idx)
def knn_points(q, p, K):
    """
    Description:
        For each query point in `q`, find its nearest points in `p`.
        The neighbour count is capped at the number of points in `p`.
    Args:
        q:  Query points, (B, M, d)
        p:  Candidate points, (B, N, d)
        K:  Requested number of neighbours.
    Returns:
        (B, M, K), (B, M, K), (B, M, K, d): distances, indices into `p`,
        and the gathered neighbour coordinates (K is min(N, K)).
    """
    _, num_candidates, _ = p.size()
    num_neighbors = min(num_candidates, K)

    all_dist = pairwise_distances(q, p)    # (B, M, N)
    # `largest=False` selects the smallest distances.
    nearest_dist, nearest_idx = all_dist.topk(num_neighbors, dim=-1, largest=False)    # (B, M, K), (B, M, K)
    return nearest_dist, nearest_idx, knn_gather(nearest_idx, p)
def angstrom_to_nm(x):
    """Convert a length from Angstrom to nanometre (divide by 10)."""
    angstrom_per_nm = 10
    return x / angstrom_per_nm
def nm_to_angstrom(x):
    """Convert a length from nanometre to Angstrom (multiply by 10)."""
    angstrom_per_nm = 10
    return x * angstrom_per_nm
def get_backbone_dihedral_angles(pos_atoms, chain_nb, res_nb, mask):
    """
    Description:
        Compute omega, phi and psi for every residue from consecutive
        residues along the L axis, and mask out the angles that do not
        exist at the termini reported by `get_terminus_flag`.
    Args:
        pos_atoms:  (N, L, A, 3).
        chain_nb:   (N, L).
        res_nb:     (N, L).
        mask:       (N, L).
    Returns:
        bb_dihedral:    Omega, Phi, and Psi angles in radian, (N, L, 3).
        mask_bb_dihed:  Masks of dihedral angles, (N, L, 3).
    """
    pos_N = pos_atoms[:, :, BBHeavyAtom.N]    # (N, L, 3)
    pos_CA = pos_atoms[:, :, BBHeavyAtom.CA]
    pos_C = pos_atoms[:, :, BBHeavyAtom.C]

    # Atoms of residue i (prev_*) paired with atoms of residue i+1 (next_*).
    prev_N, prev_CA, prev_C = pos_N[:, :-1], pos_CA[:, :-1], pos_C[:, :-1]
    next_N, next_CA, next_C = pos_N[:, 1:], pos_CA[:, 1:], pos_C[:, 1:]

    # Omega and phi belong to the later residue of each pair, so the first
    # position is zero-padded: N-termini don't have omega and phi.
    omega = F.pad(
        dihedral_from_four_points(prev_CA, prev_C, next_N, next_CA),
        pad=(1, 0), value=0,
    )
    phi = F.pad(
        dihedral_from_four_points(prev_C, next_N, next_CA, next_C),
        pad=(1, 0), value=0,
    )
    # Psi belongs to the earlier residue of each pair, so the last position
    # is zero-padded: C-termini don't have psi.
    psi = F.pad(
        dihedral_from_four_points(prev_N, prev_CA, prev_C, next_N),
        pad=(0, 1), value=0,
    )

    N_term_flag, C_term_flag = get_terminus_flag(chain_nb, res_nb, mask)    # (N, L)
    not_N_term = torch.logical_not(N_term_flag)
    not_C_term = torch.logical_not(C_term_flag)

    # Omega and phi share the "not an N-terminus" mask.
    mask_bb_dihed = torch.stack([not_N_term, not_N_term, not_C_term], dim=-1)
    bb_dihedral = torch.stack([omega, phi, psi], dim=-1) * mask_bb_dihed
    return bb_dihedral, mask_bb_dihed
def pairwise_dihedrals(pos_atoms):
    """
    Description:
        Dihedral angles between every ordered pair of residues (i, j):
        phi[i, j] uses (C_i, N_j, CA_j, C_j) and psi[i, j] uses
        (N_i, CA_i, C_i, N_j).
    Args:
        pos_atoms:  (N, L, A, 3).
    Returns:
        Inter-residue Phi and Psi angles, (N, L, L, 2).
    """
    N, L = pos_atoms.shape[:2]
    pos_N = pos_atoms[:, :, BBHeavyAtom.N]    # (N, L, 3)
    pos_CA = pos_atoms[:, :, BBHeavyAtom.CA]
    pos_C = pos_atoms[:, :, BBHeavyAtom.C]

    def of_residue_i(x):
        # Broadcast along the second residue axis: entry [n, i, j] = x[n, i].
        return x[:, :, None].expand(N, L, L, 3)

    def of_residue_j(x):
        # Broadcast along the first residue axis: entry [n, i, j] = x[n, j].
        return x[:, None, :].expand(N, L, L, 3)

    ir_phi = dihedral_from_four_points(
        of_residue_i(pos_C),
        of_residue_j(pos_N),
        of_residue_j(pos_CA),
        of_residue_j(pos_C),
    )
    ir_psi = dihedral_from_four_points(
        of_residue_i(pos_N),
        of_residue_i(pos_CA),
        of_residue_i(pos_C),
        of_residue_j(pos_N),
    )
    return torch.stack([ir_phi, ir_psi], dim=-1)
def apply_rotation_matrix_to_rot6d(R, O):
    """
    Description:
        Rotate a 6D rotation representation by applying `R` separately to
        its first three and its remaining components.
    Args:
        R:  (..., 3, 3)
        O:  (..., 6)
    Returns:
        Rotated 6D representation, (..., 6).
    """
    halves = (O[..., :3], O[..., 3:])    # two (..., 3) vectors
    rotated = [
        # Column-vector product: (..., 3, 3) @ (..., 3, 1) -> (..., 3)
        torch.matmul(R, half.unsqueeze(-1)).squeeze(-1)
        for half in halves
    ]
    return torch.cat(rotated, dim=-1)
def normalize_rot6d(O):
    """
    Description:
        Orthonormalise a 6D rotation representation: the first 3-vector is
        scaled to unit length, the second has its component along the first
        removed and is then scaled to unit length.
    Args:
        O:  (..., 6)
    Returns:
        Normalised 6D representation, (..., 6).
    """
    u1, u2 = O[..., :3], O[..., 3:]    # (..., 3)
    v1 = F.normalize(u1, p=2, dim=-1)  # (..., 3)
    # `project_v2v` has no default for `dim`; the vector components live on
    # the last axis, so it must be passed explicitly.
    v2 = F.normalize(u2 - project_v2v(u2, v1, dim=-1), p=2, dim=-1)
    return torch.cat([v1, v2], dim=-1)
def reconstruct_backbone(R, t, aa, chain_nb, res_nb, mask):
    """
    Description:
        Place the template backbone atoms of each residue type into the
        global frame (R, t), then position the oxygen using a frame rotated
        about the local x-axis by the psi angle measured on that backbone.
    Args:
        R:  (N, L, 3, 3)
        t:  (N, L, 3)
        aa: (N, L), residue type indices; clamped to [0, 20] before lookup.
        chain_nb:   (N, L)
        res_nb:     (N, L)
        mask:       (N, L)
    Returns:
        Reconstructed backbone atoms, (N, L, 4, 3).
    """
    N, L = aa.size()
    aa_index = aa.clamp(min=0, max=20).flatten()    # 20 for UNK

    # Per-residue-type template coordinates, moved to the dtype/device of `t`.
    template_bb = backbone_atom_coordinates_tensor.clone().to(t)    # (21, 3, 3)
    template_O = bb_oxygen_coordinate_tensor.clone().to(t)          # (21, 3)
    local_bb = template_bb[aa_index].reshape(N, L, -1, 3)           # (N, L, 3, 3)
    local_O = template_O[aa_index].reshape(N, L, -1)                # (N, L, 3)

    # Global coordinates of N, CA, C. (N, L, 3, 3).
    bb_pos = local_to_global(R, t, local_bb)

    # Psi angle of the placed backbone (third channel of the dihedral output).
    bb_dihedral, _ = get_backbone_dihedral_angles(bb_pos, chain_nb, res_nb, mask)
    psi = bb_dihedral[..., 2]    # (N, L)

    # Rotation about the local x-axis by psi, assembled row by row.
    sin_psi, cos_psi = torch.sin(psi), torch.cos(psi)
    zero, one = torch.zeros_like(psi), torch.ones_like(psi)
    R_psi = torch.stack([
        one, zero, zero,
        zero, cos_psi, -sin_psi,
        zero, sin_psi, cos_psi,
    ], dim=-1).reshape(N, L, 3, 3)

    # Rotation and translation of the PSI frame, and the position of O in it.
    R_psi, t_psi = compose_chain([
        (R, t),                           # Backbone
        (R_psi, torch.zeros_like(t)),     # PSI angle
    ])
    O_pos = local_to_global(R_psi, t_psi, local_O.reshape(N, L, 1, 3))

    return torch.cat([bb_pos, O_pos], dim=2)    # (N, L, 4, 3)
def reconstruct_backbone_partially(pos_ctx, R_new, t_new, aa, chain_nb, res_nb, mask_atoms, mask_recons):
    """
    Description:
        Replace the atoms of the residues selected by `mask_recons` with a
        backbone rebuilt from (R_new, t_new); all other residues keep their
        positions and atom masks from the context.
    Args:
        pos_ctx:    Context atom positions, (N, L, A, 3).
        R_new:      (N, L, 3, 3).
        t_new:      (N, L, 3).
        aa, chain_nb, res_nb: Forwarded unchanged to `reconstruct_backbone`.
        mask_atoms: (N, L, A).
        mask_recons:(N, L), true where the residue is to be rebuilt.
    Returns:
        pos_new:    (N, L, A, 3).
        mask_new:   (N, L, A).
    """
    _, _, A = mask_atoms.size()

    # A residue counts as present when its CA atom is present.
    mask_res = mask_atoms[:, :, BBHeavyAtom.CA]
    rebuilt = reconstruct_backbone(R_new, t_new, aa, chain_nb, res_nb, mask_res)    # (N, L, 4, 3)
    # Zero-pad the atom axis from 4 up to A so it lines up with `pos_ctx`.
    rebuilt = F.pad(rebuilt, pad=(0, 0, 0, A-4), value=0)    # (N, L, A, 3)

    # Rebuilt residues expose only their first four atom slots.
    backbone_only = torch.zeros_like(mask_atoms)
    backbone_only[:, :, :4] = True

    select_pos = mask_recons[:, :, None, None].expand_as(pos_ctx)
    select_mask = mask_recons[:, :, None].expand_as(mask_atoms)

    pos_new = torch.where(select_pos, rebuilt, pos_ctx)    # (N, L, A, 3)
    mask_new = torch.where(select_mask, backbone_only, mask_atoms)
    return pos_new, mask_new