Holmes
test
ca7299e
from typing import Optional
import torch
from src.models.score import so3, r3
from src.common.rigid_utils import Rigid, Rotation, quat_multiply
from src.common import rotation3d
def assemble_rigid(rotvec: torch.Tensor, trans: torch.Tensor):
    """Build a Rigid from axis-angle rotation vectors and translations.

    Args:
        rotvec: rotation vectors; all dimensions except the last are treated
            as batch dimensions.
        trans: translations, passed to ``Rigid`` unchanged.

    Returns:
        A ``Rigid`` whose rotation is given by matrices of shape
        ``rotvec.shape[:-1] + (3, 3)``.
    """
    batch_dims = rotvec.shape[:-1]
    # Convert axis-angle -> rotation matrices and restore the batch layout.
    matrices = rotation3d.axis_angle_to_matrix(rotvec)
    matrices = matrices.view(batch_dims + (3, 3))
    rotation = Rotation(rot_mats=matrices)
    return Rigid(rots=rotation, trans=trans)
def apply_mask(x_tgt, x_src, tgt_mask):
    """Blend two values element-wise according to a mask.

    Computes ``tgt_mask * x_tgt + (1 - tgt_mask) * x_src``, i.e. positions
    where the mask is 1 take ``x_tgt`` and positions where it is 0 take
    ``x_src``. The mask must be broadcastable against both inputs.

    Args:
        x_tgt: value used where the mask is 1.
        x_src: value used where the mask is 0.
        tgt_mask: numeric or boolean mask.

    Returns:
        The blended value.
    """
    if isinstance(tgt_mask, torch.Tensor) and tgt_mask.dtype == torch.bool:
        # `1 - mask` is not defined for bool tensors in PyTorch, so convert
        # boolean masks to a numeric dtype before doing arithmetic.
        if isinstance(x_tgt, torch.Tensor):
            numeric_dtype = x_tgt.dtype
        else:
            numeric_dtype = torch.get_default_dtype()
        tgt_mask = tgt_mask.to(numeric_dtype)
    return tgt_mask * x_tgt + (1 - tgt_mask) * x_src
class FrameDiffuser:
    """
    Wrapper class for diffusion of rigid body transformations,
    including rotations and translations.

    Either component diffuser may be ``None``; that component is then passed
    through unchanged, its score is zero and its score scaling is one.
    """

    # Project types are annotated as strings so that the annotations are not
    # evaluated when the class body is executed.

    def __init__(
        self,
        trans_diffuser: Optional["r3.R3Diffuser"] = None,
        rot_diffuser: Optional["so3.SO3Diffuser"] = None,
        min_t: float = 0.001,
    ):
        """Store the component diffusers.

        Args:
            trans_diffuser: translation diffuser, or None to disable
                translation diffusion.
            rot_diffuser: rotation diffuser, or None to disable rotation
                diffusion.
            min_t: stored on the instance; not used inside this class.
        """
        # if None, then no diffusion for this component
        self.trans_diffuser = trans_diffuser
        self.rot_diffuser = rot_diffuser
        self.min_t = min_t

    @staticmethod
    def _prepare_mask(mask, like: torch.Tensor) -> torch.Tensor:
        """Cast `mask` to `like`'s device/dtype and append a trailing axis."""
        # Casting makes boolean masks usable in `1 - mask` arithmetic and the
        # extra axis lets a [..., N] mask broadcast against [..., N, 3] data.
        return torch.as_tensor(mask, device=like.device, dtype=like.dtype)[..., None]

    def forward_marginal(
        self,
        rigids_0: "Rigid",
        t: torch.Tensor,
        diffuse_mask: Optional[torch.Tensor] = None,
        as_tensor_7: bool = True,
    ):
        """Noise clean frames to time `t` and return the matching scores.

        Args:
            rigids_0: [..., N] openfold Rigid objects
            t: continuous time in [0, 1].
            diffuse_mask: optional [..., N] mask. Entries equal to 1 are
                noised; entries equal to 0 keep the clean frame and get a
                zero score.
            as_tensor_7: if true, `rigids_t` is returned via `to_tensor_7()`.

        Returns:
            Dict contains:
                rigids_t: [..., N] noised rigid. [..., N, 7] if as_tensor_7 is true.
                trans_score: [..., N, 3] translation score
                rot_score: [..., N, 3] rotation score
                trans_score_scaling: translation score scaling at time t
                rot_score_scaling: rotation score scaling at time t
        """
        # Work on rotations in axis-angle form and translations as tensors.
        rot_0 = rotation3d.matrix_to_axis_angle(rigids_0.get_rots().get_rot_mats())
        trans_0 = rigids_0.get_trans()

        if self.rot_diffuser is None:
            rot_t = rot_0
            rot_score = torch.zeros_like(rot_0)
            # Neutral scaling, mirroring the translation branch below.
            rot_score_scaling = torch.ones_like(t)
        else:
            rot_t, rot_score = self.rot_diffuser.forward_marginal(rot_0, t)
            rot_score_scaling = self.rot_diffuser.score_scaling(t)

        if self.trans_diffuser is None:
            trans_t = trans_0
            trans_score = torch.zeros_like(trans_0)
            trans_score_scaling = torch.ones_like(t)
        else:
            trans_t, trans_score = self.trans_diffuser.forward_marginal(trans_0, t)
            trans_score_scaling = self.trans_diffuser.score_scaling(t)

        # Perturb only a subset of residues
        if diffuse_mask is not None:
            diffuse_mask = self._prepare_mask(diffuse_mask, trans_t)
            rot_t = apply_mask(rot_t, rot_0, diffuse_mask)
            trans_t = apply_mask(trans_t, trans_0, diffuse_mask)
            # Residues that were not perturbed carry no score.
            trans_score = apply_mask(
                trans_score,
                torch.zeros_like(trans_score),
                diffuse_mask,
            )
            rot_score = apply_mask(
                rot_score,
                torch.zeros_like(rot_score),
                diffuse_mask,
            )

        rigids_t = assemble_rigid(rot_t, trans_t)
        if as_tensor_7:
            rigids_t = rigids_t.to_tensor_7()

        return {
            'rigids_t': rigids_t,
            'trans_score': trans_score,
            'rot_score': rot_score,
            'trans_score_scaling': trans_score_scaling,
            'rot_score_scaling': rot_score_scaling,
        }

    def score(
        self,
        rigids_0: "Rigid",
        rigids_t: "Rigid",
        t: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ):
        """Compute rotation and translation scores of `rigids_t` given `rigids_0`.

        Args:
            rigids_0: clean frames.
            rigids_t: noised frames at time `t`.
            t: continuous time, forwarded to the component diffusers.
            mask: optional [..., N] mask multiplied into both scores.

        Returns:
            Dict with `trans_score` and `rot_score`. A component without a
            diffuser gets an all-zero score shaped like the translations.
        """
        rot_0, trans_0 = rigids_0.get_rots(), rigids_0.get_trans()
        rot_t, trans_t = rigids_t.get_rots(), rigids_t.get_trans()

        if self.rot_diffuser is None:
            # `rot_0` is a rotation object rather than a tensor, so the zero
            # score is shaped after the translations ([..., N, 3]) instead.
            rot_score = torch.zeros_like(trans_0)
        else:
            rot_0_inv = rot_0.invert()
            quat_0_inv = rotation3d.matrix_to_quaternion(rot_0_inv.get_rot_mats())
            quat_t = rotation3d.matrix_to_quaternion(rot_t.get_rot_mats())
            # get relative rotation
            quat_0t = quat_multiply(quat_0_inv, quat_t)
            rotvec_0t = rotation3d.quaternion_to_axis_angle(quat_0t)
            # calculate score
            rot_score = self.rot_diffuser.score(rotvec_0t, t)

        if self.trans_diffuser is None:
            trans_score = torch.zeros_like(trans_0)
        else:
            trans_score = self.trans_diffuser.score(trans_t, trans_0, t, scale=True)

        if mask is not None:
            trans_score = trans_score * mask[..., None]
            rot_score = rot_score * mask[..., None]

        return {
            'trans_score': trans_score,
            'rot_score': rot_score,
        }

    def score_scaling(self, t):
        """Return the per-component score scaling at time `t`.

        A component without a diffuser gets a scaling of ones shaped like
        `t`, matching what `forward_marginal` reports for it.
        """
        if self.rot_diffuser is None:
            rot_score_scaling = torch.ones_like(t)
        else:
            rot_score_scaling = self.rot_diffuser.score_scaling(t)

        if self.trans_diffuser is None:
            trans_score_scaling = torch.ones_like(t)
        else:
            trans_score_scaling = self.trans_diffuser.score_scaling(t)

        return {
            'trans_score_scaling': trans_score_scaling,
            'rot_score_scaling': rot_score_scaling,
        }

    def reverse(
        self,
        rigids_t: "Rigid",
        rot_score: torch.Tensor,
        trans_score: torch.Tensor,
        t: torch.Tensor,
        dt: float,
        diffuse_mask: Optional[torch.Tensor] = None,
        center_trans: bool = True,
        noise_scale: float = 1.0,
        probability_flow: bool = True,
    ):
        """Reverse sampling function from (t) to (t-1).

        Args:
            rigids_t: [..., N] protein rigid objects at time t.
            rot_score: [..., N, 3] rotation score.
            trans_score: [..., N, 3] translation score.
            t: continuous time in [0, 1].
            dt: continuous step size in [0, 1].
            diffuse_mask: [..., N] which residues to update; residues with
                mask 0 keep their frame from `rigids_t`.
            center_trans: true to set center of mass to zero after step
                (forwarded to the translation diffuser as `center`).
            noise_scale: forwarded to both component diffusers.
            probability_flow: whether to use probability flow ODE.

        Returns:
            rigids_t_1: [..., N] protein rigid objects at time t-1.
        """
        # extract rot and trans as tensors
        rot_t = rotation3d.matrix_to_axis_angle(rigids_t.get_rots().get_rot_mats())
        trans_t = rigids_t.get_trans()

        # reverse rot
        if self.rot_diffuser is not None:
            rot_t_1 = self.rot_diffuser.reverse(
                rot_t=rot_t,
                score_t=rot_score,
                t=t,
                dt=dt,
                noise_scale=noise_scale,
                probability_flow=probability_flow,
            )
        else:
            # if no diffusion module, return as-is
            rot_t_1 = rot_t

        # reverse trans
        if self.trans_diffuser is not None:
            trans_t_1 = self.trans_diffuser.reverse(
                x_t=trans_t,
                score_t=trans_score,
                t=t,
                dt=dt,
                center=center_trans,
                noise_scale=noise_scale,
                probability_flow=probability_flow,
            )
        else:
            trans_t_1 = trans_t

        # apply mask
        if diffuse_mask is not None:
            # Same cast as in forward_marginal so boolean masks work here too.
            mask = self._prepare_mask(diffuse_mask, trans_t)
            trans_t_1 = apply_mask(trans_t_1, trans_t, mask)
            rot_t_1 = apply_mask(rot_t_1, rot_t, mask)

        return assemble_rigid(rot_t_1, trans_t_1)

    def sample_prior(
        self,
        shape: torch.Size,
        device: torch.device,
        reference_rigids: Optional["Rigid"] = None,
        diffuse_mask: Optional[torch.Tensor] = None,
        as_tensor_7: bool = False,
    ):
        """Samples rigids from reference distribution.

        Args:
            shape: leading shape of the sample; a trailing dimension of 3 is
                appended for both the rotation vectors and the translations.
            device: device passed to the component diffusers' `sample_prior`.
            reference_rigids: optional frames used where `diffuse_mask` is 0
                and for any component that has no diffuser.
            diffuse_mask: must be given exactly when `reference_rigids` is
                given; entries equal to 1 take the prior sample.
            as_tensor_7: if true, the result is returned via `to_tensor_7()`.

        Returns:
            Dict with the single key `rigids_t`.
        """
        if reference_rigids is not None:
            # NOTE(review): this compares `reference_rigids.shape[:-1]` with
            # `shape`; confirm what `Rigid.shape` reports in this project.
            assert reference_rigids.shape[:-1] == shape, f"reference_rigids.shape[:-1] = {reference_rigids.shape[:-1]}, shape = {shape}"
            assert diffuse_mask is not None, "diffuse_mask must be provided if reference_rigids is given"
            rot_ref = rotation3d.matrix_to_axis_angle(reference_rigids.get_rots().get_rot_mats())
            trans_ref = reference_rigids.get_trans()
            if self.trans_diffuser is not None:
                # Bring the reference into the same (scaled) units as the
                # prior sample so the two can be mixed; undone below.
                trans_ref = self.trans_diffuser.scale(trans_ref)
        else:
            # sanity check
            assert diffuse_mask is None, "diffuse_mask must be None if reference_rigids is None"
            assert self.rot_diffuser is not None and self.trans_diffuser is not None

        # sample from prior
        trans_shape, rot_shape = shape + (3, ), shape + (3, )
        if self.rot_diffuser is not None:
            rot_sample = self.rot_diffuser.sample_prior(shape=rot_shape, device=device)
        else:
            rot_sample = rot_ref
        if self.trans_diffuser is not None:
            trans_sample = self.trans_diffuser.sample_prior(shape=trans_shape, device=device)
        else:
            trans_sample = trans_ref

        # apply mask
        if diffuse_mask is not None:
            mask = self._prepare_mask(diffuse_mask, trans_sample)
            rot_sample = apply_mask(rot_sample, rot_ref, mask)
            trans_sample = apply_mask(trans_sample, trans_ref, mask)

        # Without a translation diffuser nothing was scaled, so there is
        # nothing to undo.
        if self.trans_diffuser is not None:
            trans_sample = self.trans_diffuser.unscale(trans_sample)

        # assemble sampled rot and trans -> rigid
        rigids_t = assemble_rigid(rot_sample, trans_sample)
        if as_tensor_7:
            rigids_t = rigids_t.to_tensor_7()
        return {'rigids_t': rigids_t}