Holmes
test
ca7299e
"""Inspired by https://github.com/jasonkyuyim/se3_diffusion/blob/master/data/r3_diffuser.py"""
from math import sqrt
import torch
from src.utils.tensor_utils import inflate_array_like
class R3Diffuser:
    """Variance-preserving SDE (VP-SDE) diffuser for translations in R^3.

    The noise schedule is linear in time, ``b(t) = min_b + t * (max_b - min_b)``
    for ``t`` in [0, 1]. ``forward_marginal`` and ``reverse`` multiply incoming
    coordinates by ``coordinate_scaling`` before doing the SDE arithmetic and
    divide the returned positions by it again.
    """

    def __init__(
        self,
        min_b: float = 0.1,
        max_b: float = 20.0,
        coordinate_scaling: float = 1.0,
    ):
        """Stores the schedule endpoints and the coordinate scaling factor.

        Args:
            min_b: schedule value at t = 0.
            max_b: schedule value at t = 1.
            coordinate_scaling: factor applied by ``scale`` and removed by
                ``unscale``.
        """
        self.min_b, self.max_b = min_b, max_b
        self.coordinate_scaling = coordinate_scaling

    def scale(self, x):
        """Multiplies ``x`` by ``coordinate_scaling``."""
        return x * self.coordinate_scaling

    def unscale(self, x):
        """Divides ``x`` by ``coordinate_scaling`` (inverse of ``scale``)."""
        return x / self.coordinate_scaling

    def b_t(self, t: torch.Tensor):
        """Evaluates the linear schedule b(t).

        Args:
            t: tensor of continuous times; every entry must lie in [0, 1].
        Returns:
            ``min_b + t * (max_b - min_b)``, same shape as ``t``.
        Raises:
            ValueError: if any entry of ``t`` is below 0 or above 1.
        """
        out_of_range = torch.any(t < 0) or torch.any(t > 1)
        if out_of_range:
            raise ValueError(f'Invalid t={t}')
        span = self.max_b - self.min_b
        return self.min_b + t * span

    def diffusion_coef(self, t):
        """Diffusion coefficient g(t) = sqrt(b(t))."""
        return torch.sqrt(self.b_t(t))

    def drift_coef(self, x, t):
        """Forward drift f(x, t) = -0.5 * b(t) * x."""
        return -0.5 * self.b_t(t) * x

    def sample_prior(self, shape, device=None):
        """Draws a standard-normal tensor of the given shape on ``device``."""
        return torch.randn(size=shape, device=device)

    def marginal_b_t(self, t):
        """Integral of the schedule from 0 to t.

        Returns ``t * min_b + 0.5 * t**2 * (max_b - min_b)``. Unlike ``b_t``,
        no range check is performed on ``t``.
        """
        linear_term = t * self.min_b
        quadratic_term = 0.5 * (t ** 2) * (self.max_b - self.min_b)
        return linear_term + quadratic_term

    def calc_trans_0(self, score_t, x_t, t):
        """Recovers x_0 from a score and x_t by inverting the ``score`` formula.

        Two trailing axes are appended to the integrated schedule so it
        broadcasts against ``score_t`` / ``x_t``. No coordinate scaling is
        applied here: inputs and output are used as given.

        Args:
            score_t: score at time t.
            x_t: positions at time t.
            t: continuous time (presumably one value per batch element, to
                match the two appended axes -- confirm against callers).
        Returns:
            ``(score_t * var + x_t) / exp(-0.5 * B)`` with ``B`` the integrated
            schedule and ``var = 1 - exp(-B)``.
        """
        integrated_b = self.marginal_b_t(t)[..., None, None]
        cond_var = 1 - torch.exp(-integrated_b)
        mean_decay = torch.exp(-0.5 * integrated_b)
        return (score_t * cond_var + x_t) / mean_decay

    def forward_marginal(
        self,
        x_0: torch.Tensor,
        t: torch.Tensor
    ):
        """Samples from the marginal p(x(t) | x(0)).

        Args:
            x_0: [..., n, 3] initial positions in Angstroms.
            t: continuous time in [0, 1].
        Returns:
            x_t: [..., n, 3] positions at time t in Angstroms (unscaled).
            score_t: [..., n, 3] conditional score at time t, computed on the
                scaled coordinates.
        """
        # Project helper; presumably reshapes t so it broadcasts against x_0.
        t = inflate_array_like(t, x_0)
        x_0 = self.scale(x_0)

        integrated_b = self.marginal_b_t(t)
        loc = torch.exp(-0.5 * integrated_b) * x_0
        scale = torch.sqrt(1 - torch.exp(-integrated_b))

        noise = torch.randn_like(x_0)
        x_t = noise * scale + loc
        # The score is evaluated before unscaling, i.e. in scaled coordinates.
        score_t = self.score(x_t, x_0, t)
        return self.unscale(x_t), score_t

    def conditional_var(self, t, use_torch=False):
        """Conditional variance of p(x_t | x_0).

        Var[x_t | x_0] = conditional_var(t) * I

        ``use_torch`` is accepted but not used by this implementation.
        """
        return 1.0 - torch.exp(-self.marginal_b_t(t))

    def score_scaling(self, t: torch.Tensor):
        """Returns ``1 / sqrt(conditional_var(t))``."""
        return 1.0 / torch.sqrt(self.conditional_var(t))

    def score(self, x_t, x_0, t, scale=False):
        """Score of the Gaussian p(x_t | x_0) evaluated at ``x_t``.

        Args:
            x_t: positions at time t.
            x_0: initial positions.
            t: continuous time; passed through ``inflate_array_like``.
            scale: if True, both ``x_t`` and ``x_0`` are first multiplied by
                ``coordinate_scaling``.
        Returns:
            ``-(x_t - exp(-0.5 * B) * x_0) / conditional_var(t)`` with ``B``
            the integrated schedule.
        """
        t = inflate_array_like(t, x_t)
        if scale:
            x_t, x_0 = self.scale(x_t), self.scale(x_0)
        mean_t = torch.exp(-0.5 * self.marginal_b_t(t)) * x_0
        return -(x_t - mean_t) / self.conditional_var(t)

    def reverse(
        self,
        x_t: torch.Tensor,
        score_t: torch.Tensor,
        t: torch.Tensor,
        dt: float,
        mask: torch.Tensor = None,
        center: bool = True,
        noise_scale: float = 1.0,
        probability_flow: bool = True,
    ):
        """Simulates one step of the reverse-time dynamics.

        Args:
            x_t: [..., 3] current positions at time t in Angstroms.
            score_t: [..., 3] translation score at time t.
            t: continuous time in [0, 1].
            dt: continuous step size in [0, 1].
            mask: True indicates which residues to diffuse; when None every
                position is updated.
            center: if True, subtract the centre of mass (over dim -2) from
                the updated positions.
            noise_scale: multiplier on the sampled Gaussian noise (only has an
                effect when ``probability_flow`` is False).
            probability_flow: if True take a deterministic step without the
                noise term; otherwise take a stochastic reverse-SDE step.
        Returns:
            [..., 3] positions at the next step (time t - dt), unscaled.
        """
        t = inflate_array_like(t, x_t)
        x_t = self.scale(x_t)
        drift = self.drift_coef(x_t, t)
        diffusion = self.diffusion_coef(t)
        # Noise is drawn even on the deterministic path, so RNG state advances
        # identically for both modes.
        z = noise_scale * torch.randn_like(score_t)

        # NOTE(review): on the probability-flow path the *whole* reverse drift
        # (f - g^2 * score) is halved. The textbook probability-flow ODE drift
        # is f - 0.5 * g^2 * score. Confirm this is intended before changing.
        drift_weight = 0.5 if probability_flow else 1.
        rev_drift = (drift - diffusion ** 2 * score_t) * dt * drift_weight
        rev_diffusion = 0. if probability_flow else (diffusion * sqrt(dt) * z)
        perturb = rev_drift + rev_diffusion

        if mask is None:
            mask = torch.ones_like(x_t[..., 0])
        else:
            # Masked-out positions receive no update.
            perturb *= mask[..., None]

        x_t_1 = x_t - perturb  # reverse in time
        if center:
            # NOTE(review): the numerator sums over all positions, while the
            # denominator counts only mask entries; an all-zero mask divides
            # by zero. Confirm masked-out coordinates are expected to be zero.
            n_active = torch.sum(mask, dim=-1)[..., None]
            com = torch.sum(x_t_1, dim=-2) / n_active  # reduce length dim
            x_t_1 -= com[..., None, :]
        return self.unscale(x_t_1)

    def distribution(self, x_t, score_t, t, mask, dt):
        """Mean and standard deviation of a single reverse-SDE step.

        Unlike ``reverse``, ``t`` is not passed through ``inflate_array_like``
        here, so it must already broadcast against ``x_t``. The returned mean
        is left in scaled coordinates (no ``unscale`` is applied).

        Args:
            x_t: current positions (unscaled).
            score_t: score at time t.
            t: continuous time in [0, 1].
            mask: optional mask multiplied into the mean; None to skip.
            dt: step size.
        Returns:
            mu: ``scale(x_t) - (f - g**2 * score_t) * dt``, times the mask if
                one is given.
            std: ``g * sqrt(dt)``.
        """
        x_t = self.scale(x_t)
        drift = self.drift_coef(x_t, t)
        diffusion = self.diffusion_coef(t)
        std = diffusion * sqrt(dt)
        mu = x_t - (drift - diffusion**2 * score_t) * dt
        if mask is not None:
            mu *= mask[..., None]
        return mu, std