| from __future__ import annotations |
|
|
| from typing import Tuple, TypeVar |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor |
| from torch.amp import autocast |
|
|
| from src.data.esm.utils import residue_constants |
| from src.data.esm.utils.misc import unbinpack |
| from src.data.esm.utils.structure.affine3d import Affine3D |
|
|
| ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor) |
|
|
|
|
| def index_by_atom_name( |
| atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2 |
| ) -> ArrayOrTensor: |
| squeeze = False |
| if isinstance(atom_names, str): |
| atom_names = [atom_names] |
| squeeze = True |
| indices = [residue_constants.atom_order[atom_name] for atom_name in atom_names] |
| dim = dim % atom37.ndim |
| index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim)) |
| result = atom37[index] |
| if squeeze: |
| result = result.squeeze(dim) |
| return result |
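

# Illustrative usage sketch for index_by_atom_name (hypothetical shapes; not
# part of the original API). A single atom name squeezes the atom dimension,
# while a list of names keeps it.
def _example_index_by_atom_name() -> None:
    coords = np.zeros((8, 37, 3))  # one chain of 8 residues in atom37 layout
    ca = index_by_atom_name(coords, "CA")  # -> (8, 3), atom dim squeezed
    backbone = index_by_atom_name(coords, ["N", "CA", "C"])  # -> (8, 3, 3)
    assert ca.shape == (8, 3)
    assert backbone.shape == (8, 3, 3)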
|
|
|
|
def infer_cbeta_from_atom37(
    atom37: ArrayOrTensor, L: float = 1.522, A: float = 1.927, D: float = -2.143
) -> ArrayOrTensor:
    """
    Infer a virtual C-beta coordinate from the backbone N, CA, and C atoms.

    Inspired by a util in trDesign:
    https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92

    input: atom37, bond (L)ength, bond (A)ngle, and (D)ihedral angle (radians)
    output: coordinates of the virtual C-beta (the 4th coord)
    """
| N = index_by_atom_name(atom37, "N", dim=-2) |
| CA = index_by_atom_name(atom37, "CA", dim=-2) |
| C = index_by_atom_name(atom37, "C", dim=-2) |
|
|
| if isinstance(atom37, np.ndarray): |
|
|
| def normalize(x: ArrayOrTensor): |
| return x / np.linalg.norm(x, axis=-1, keepdims=True) |
|
|
| cross = np.cross |
| else: |
        normalize = F.normalize
        # torch.cross infers `dim` when it is unset, which is deprecated and
        # error-prone for batched inputs; torch.linalg.cross pins the cross
        # product to the coordinate axis, matching np.cross.
        cross = torch.linalg.cross
|
|
| with np.errstate(invalid="ignore"): |
| vec_nca = N - CA |
| vec_nc = N - C |
| nca = normalize(vec_nca) |
| n = normalize(cross(vec_nc, nca)) |
| m = [nca, cross(n, nca), n] |
| d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)] |
| return CA + sum([m * d for m, d in zip(m, d)]) |
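

# Illustrative sketch for infer_cbeta_from_atom37 (hypothetical coordinates):
# place a virtual C-beta for one residue from its backbone N, CA, and C atoms.
def _example_infer_cbeta() -> None:
    atom37 = np.zeros((1, 37, 3))
    atom37[0, residue_constants.atom_order["N"]] = [1.458, 0.0, 0.0]
    atom37[0, residue_constants.atom_order["C"]] = [-0.550, 1.420, 0.0]
    cb = infer_cbeta_from_atom37(atom37)  # CA stays at the origin
    assert cb.shape == (1, 3)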
|
|
|
|
| @torch.no_grad() |
| @autocast("cuda", enabled=False) |
| def compute_alignment_tensors( |
| mobile: torch.Tensor, |
| target: torch.Tensor, |
| atom_exists_mask: torch.Tensor | None = None, |
| sequence_id: torch.Tensor | None = None, |
| ): |
| """ |
| Align two batches of structures with support for masking invalid atoms using PyTorch. |
| |
| Args: |
| - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3) |
| - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3) |
| - atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N) |
| - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking. |
| |
| Returns: |
| - centered_mobile (torch.Tensor): Batch of coordinates of structure centered mobile (B, N, 3) |
| - centroid_mobile (torch.Tensor): Batch of coordinates of mobile centeroid (B, 3) |
| - centered_target (torch.Tensor): Batch of coordinates of structure centered target (B, N, 3) |
| - centroid_target (torch.Tensor): Batch of coordinates of target centeroid (B, 3) |
| - rotation_matrix (torch.Tensor): Batch of coordinates of rotation matrix (B, 3, 3) |
| - num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B,) |
| """ |
|
|
    # Unbinpack: separate binpacked sequences into individually padded samples
| if sequence_id is not None: |
| mobile = unbinpack(mobile, sequence_id, pad_value=torch.nan) |
| target = unbinpack(target, sequence_id, pad_value=torch.nan) |
| if atom_exists_mask is not None: |
| atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=0) |
| else: |
| atom_exists_mask = torch.isfinite(target).all(-1) |
|
|
| assert mobile.shape == target.shape, "Batch structure shapes do not match!" |
|
|
    # Number of structures in the batch
| batch_size = mobile.shape[0] |
|
|
    # Flatten atom37-style (B, L, 37, 3) inputs down to (B, N, 3)
| if mobile.dim() == 4: |
| mobile = mobile.view(batch_size, -1, 3) |
| if target.dim() == 4: |
| target = target.view(batch_size, -1, 3) |
| if atom_exists_mask is not None and atom_exists_mask.dim() == 3: |
| atom_exists_mask = atom_exists_mask.view(batch_size, -1) |
|
|
    # Number of atoms per structure
| num_atoms = mobile.shape[1] |
|
|
    # Zero out missing atoms, or build an all-True mask if none was given
| if atom_exists_mask is not None: |
| mobile = mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) |
| target = target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) |
| else: |
| atom_exists_mask = torch.ones( |
| batch_size, num_atoms, dtype=torch.bool, device=mobile.device |
| ) |
|
|
| num_valid_atoms = atom_exists_mask.sum(dim=-1, keepdim=True) |
    # Centroids are computed over valid atoms only
| centroid_mobile = mobile.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1) |
| centroid_target = target.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1) |
|
|
    # Replace NaN centroids (samples with no valid atoms) with zeros
| centroid_mobile[num_valid_atoms == 0] = 0 |
| centroid_target[num_valid_atoms == 0] = 0 |
|
|
    # Center both structures by subtracting their centroids
| centered_mobile = mobile - centroid_mobile |
| centered_target = target - centroid_target |
|
|
| centered_mobile = centered_mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) |
| centered_target = centered_target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0) |
|
|
    # Cross-covariance between the centered coordinate sets
| covariance_matrix = torch.matmul(centered_mobile.transpose(1, 2), centered_target) |
|
|
    # Singular value decomposition of the covariance matrix
    # (torch.linalg.svd returns V^T directly, unlike the deprecated torch.svd)
    u, _, vh = torch.linalg.svd(covariance_matrix)

    # Kabsch rotation aligning mobile onto target. No determinant correction is
    # applied, so the result can be an improper rotation (a reflection) in
    # degenerate cases.
    rotation_matrix = torch.matmul(u, vh)
|
|
| return ( |
| centered_mobile, |
| centroid_mobile, |
| centered_target, |
| centroid_target, |
| rotation_matrix, |
| num_valid_atoms, |
| ) |
|
|
|
|
| @torch.no_grad() |
| @autocast("cuda", enabled=False) |
| def compute_rmsd_no_alignment( |
| aligned: torch.Tensor, |
| target: torch.Tensor, |
| num_valid_atoms: torch.Tensor, |
| reduction: str = "batch", |
| ) -> torch.Tensor: |
| """ |
| Compute RMSD between two batches of structures without alignment. |
| |
| Args: |
| - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3) |
| - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3) |
| - num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B,) |
| - reduction (str): One of "batch", "per_sample", "per_residue". |
| |
| Returns: |
| |
| If reduction == "batch": |
| (torch.Tensor): 0-dim, Average Root Mean Square Deviation between the structures for each batch |
| If reduction == "per_sample": |
| (torch.Tensor): (B,)-dim, Root Mean Square Deviation between the structures for each batch |
| If reduction == "per_residue": |
| (torch.Tensor): (B, N)-dim, Root Mean Square Deviation between the structures for residue in the batch |
| """ |
    if reduction not in ("per_residue", "per_sample", "batch"):
        raise ValueError(f"Unrecognized reduction: '{reduction}'")

    diff = aligned - target
    if reduction == "per_residue":
        # Regroup flattened coordinates by residue: 3 backbone atoms x 3 coords = 9
        mean_squared_error = diff.square().view(diff.size(0), -1, 9).mean(dim=-1)
    else:
        # Average the squared error over all valid atom coordinates
        mean_squared_error = diff.square().sum(dim=(1, 2)) / (
            num_valid_atoms.squeeze(-1) * 3
        )
|
|
| rmsd = torch.sqrt(mean_squared_error) |
| if reduction in ("per_sample", "per_residue"): |
| return rmsd |
| elif reduction == "batch": |
| avg_rmsd = rmsd.masked_fill(num_valid_atoms.squeeze(-1) == 0, 0).sum() / ( |
| (num_valid_atoms > 0).sum() + 1e-8 |
| ) |
| return avg_rmsd |
| else: |
| raise ValueError(reduction) |
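

# Illustrative sketch for compute_rmsd_no_alignment: identical structures give
# an RMSD of zero. num_valid_atoms has shape (B, 1), matching the output of
# compute_alignment_tensors; values are hypothetical.
def _example_compute_rmsd_no_alignment() -> None:
    coords = torch.randn(2, 12, 3)
    num_valid_atoms = torch.full((2, 1), 12.0)
    rmsd = compute_rmsd_no_alignment(
        coords, coords, num_valid_atoms, reduction="per_sample"
    )
    assert torch.allclose(rmsd, torch.zeros(2))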
|
|
|
|
| @torch.no_grad() |
| @autocast("cuda", enabled=False) |
| def compute_affine_and_rmsd( |
| mobile: torch.Tensor, |
| target: torch.Tensor, |
| atom_exists_mask: torch.Tensor | None = None, |
| sequence_id: torch.Tensor | None = None, |
| ) -> Tuple[Affine3D, torch.Tensor]: |
| """ |
| Compute RMSD between two batches of structures with support for masking invalid atoms using PyTorch. |
| |
| Args: |
| - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3) |
| - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3) |
| - atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N) |
| - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking. |
| |
| Returns: |
| - affine (Affine3D): Transformation between mobile and target structure |
| - avg_rmsd (torch.Tensor): Average Root Mean Square Deviation between the structures for each batch |
| """ |
|
|
| ( |
| centered_mobile, |
| centroid_mobile, |
| centered_target, |
| centroid_target, |
| rotation_matrix, |
| num_valid_atoms, |
| ) = compute_alignment_tensors( |
| mobile=mobile, |
| target=target, |
| atom_exists_mask=atom_exists_mask, |
| sequence_id=sequence_id, |
| ) |
|
|
    # Compose the translation: rotate the mobile centroid, then move it onto the target centroid
| translation = torch.matmul(-centroid_mobile, rotation_matrix) + centroid_target |
| affine = Affine3D.from_tensor_pair( |
| translation, rotation_matrix.unsqueeze(dim=-3).transpose(-2, -1) |
| ) |
|
|
    # Rotate the centered mobile structure before scoring the RMSD
| rotated_mobile = torch.matmul(centered_mobile, rotation_matrix) |
| avg_rmsd = compute_rmsd_no_alignment( |
| rotated_mobile, centered_target, num_valid_atoms, reduction="batch" |
| ) |
|
|
| return affine, avg_rmsd |
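

# Illustrative sketch for compute_affine_and_rmsd: apply a known rigid
# transform to a random structure and realign it; the recovered RMSD should be
# near zero. All values are hypothetical.
def _example_compute_affine_and_rmsd() -> None:
    target = torch.randn(1, 32, 3)
    c, s = np.cos(0.3), np.sin(0.3)  # rotation of 0.3 rad about the z-axis
    rot = torch.tensor(
        [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float32
    )
    mobile = target @ rot.T + torch.tensor([1.0, -2.0, 0.5])
    _affine, avg_rmsd = compute_affine_and_rmsd(mobile, target)
    assert avg_rmsd < 1e-3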
|
|
|
|
| def compute_gdt_ts_no_alignment( |
| aligned: torch.Tensor, |
| target: torch.Tensor, |
    atom_exists_mask: torch.Tensor | None = None,
| reduction: str = "batch", |
| ) -> torch.Tensor: |
| """ |
| Compute GDT_TS between two batches of structures without alignment. |
| |
| Args: |
| - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3) |
| - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3) |
| - atom_exists_mask (torch.Tensor): Mask for Whether an atom exists of shape (B, N). noo |
| - reduction (str): One of "batch", "per_sample". |
| |
| Returns: |
| If reduction == "batch": |
| (torch.Tensor): 0-dim, GDT_TS between the structures for each batch |
| If reduction == "per_sample": |
| (torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch |
| """ |
    if reduction not in ("per_sample", "batch"):
        raise ValueError(f"Unrecognized reduction: '{reduction}'")
|
|
| if atom_exists_mask is None: |
| atom_exists_mask = torch.isfinite(target).all(dim=-1) |
|
|
| deviation = torch.linalg.vector_norm(aligned - target, dim=-1) |
| num_valid_atoms = atom_exists_mask.sum(dim=-1) |
|
|
    # GDT_TS: mean fraction of valid atoms within 1, 2, 4, and 8 Angstroms of the target
| score = ( |
| ((deviation < 1) * atom_exists_mask).sum(dim=-1) / num_valid_atoms |
| + ((deviation < 2) * atom_exists_mask).sum(dim=-1) / num_valid_atoms |
| + ((deviation < 4) * atom_exists_mask).sum(dim=-1) / num_valid_atoms |
| + ((deviation < 8) * atom_exists_mask).sum(dim=-1) / num_valid_atoms |
| ) * 0.25 |
|
|
| if reduction == "batch": |
| return score.mean() |
| elif reduction == "per_sample": |
| return score |
| else: |
| raise ValueError("Unrecognized reduction: '{reduction}'") |
|
|