import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffab.modules.common.layers import clampped_one_hot
from diffab.modules.common.so3 import ApproxAngularDistribution, random_normal_so3, so3vec_to_rotation, rotation_to_so3vec


class VarianceSchedule(nn.Module):
    """Cosine variance schedule (the ``improved DDPM'' schedule) shared by the
    position, rotation, and amino-acid transitions below."""

    def __init__(self, num_steps=100, s=0.01):
        super().__init__()
        T = num_steps
        t = torch.arange(0, num_steps+1, dtype=torch.float)
        f_t = torch.cos( (np.pi / 2) * ((t/T) + s) / (1 + s) ) ** 2
        alpha_bars = f_t / f_t[0]

        betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
        betas = torch.cat([torch.zeros([1]), betas], dim=0)
        betas = betas.clamp_max(0.999)

        sigmas = torch.zeros_like(betas)
        for i in range(1, betas.size(0)):
            sigmas[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i]
        sigmas = torch.sqrt(sigmas)

        self.register_buffer('betas', betas)
        self.register_buffer('alpha_bars', alpha_bars)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('sigmas', sigmas)
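

# Illustrative sketch (not part of the training/sampling code): how the cosine
# schedule above behaves. The printed values are approximate and assume the
# default settings (num_steps=100, s=0.01).
def _demo_variance_schedule(num_steps=100):
    sched = VarianceSchedule(num_steps=num_steps)
    # Index 0 is a padding entry (beta_0 = 0); diffusion steps run 1..num_steps.
    # betas grow from a small value at t=1 up to the 0.999 clamp at t=num_steps.
    print('beta[1]  =', sched.betas[1].item(), ' beta[T]  =', sched.betas[-1].item())
    # alpha_bar decays monotonically from 1 (t=0) towards 0 (t=num_steps).
    print('abar[1]  =', sched.alpha_bars[1].item(), ' abar[T]  =', sched.alpha_bars[-1].item())
    # sigmas are the reverse-process noise scales used by the denoise() methods below.
    print('sigma[1] =', sched.sigmas[1].item(), ' sigma[T] =', sched.sigmas[-1].item())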


class PositionTransition(nn.Module):

    def __init__(self, num_steps, var_sched_opt={}):
        super().__init__()
        self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)

    def add_noise(self, p_0, mask_generate, t):
        """
        Args:
            p_0:    (N, L, 3).
            mask_generate:    (N, L).
            t:  (N,).
        """
        alpha_bar = self.var_sched.alpha_bars[t]

        c0 = torch.sqrt(alpha_bar).view(-1, 1, 1)
        c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1)

        e_rand = torch.randn_like(p_0)
        p_noisy = c0*p_0 + c1*e_rand
        p_noisy = torch.where(mask_generate[..., None].expand_as(p_0), p_noisy, p_0)

        return p_noisy, e_rand

    def denoise(self, p_t, eps_p, mask_generate, t):
        """
        Args:
            p_t:    (N, L, 3).
            eps_p:  Predicted noise, (N, L, 3).
            mask_generate:    (N, L).
            t:  (N,).
        """
        # IMPORTANT:
        #   clamping alpha is to fix the instability issue at the first step (t=T)
        #   it seems like a problem with the ``improved ddpm''.
        alpha = self.var_sched.alphas[t].clamp_min(
            self.var_sched.alphas[-2]
        )
        alpha_bar = self.var_sched.alpha_bars[t]
        sigma = self.var_sched.sigmas[t].view(-1, 1, 1)

        c0 = ( 1.0 / torch.sqrt(alpha + 1e-8) ).view(-1, 1, 1)
        c1 = ( (1 - alpha) / torch.sqrt(1 - alpha_bar + 1e-8) ).view(-1, 1, 1)

        z = torch.where(
            (t > 1)[:, None, None].expand_as(p_t),
            torch.randn_like(p_t),
            torch.zeros_like(p_t),  # No noise is added at the last step (t = 1).
        )

        p_next = c0 * (p_t - c1 * eps_p) + sigma * z
        p_next = torch.where(mask_generate[..., None].expand_as(p_t), p_next, p_t)
        return p_next
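

# Illustrative sketch (not part of the training/sampling code): how PositionTransition
# is typically exercised. The random tensors stand in for real CA coordinates, and the
# zero eps_pred stands in for the epsilon predicted by a denoising network.
def _demo_position_transition(num_steps=100, N=2, L=16):
    trans = PositionTransition(num_steps)
    mask_generate = torch.ones(N, L, dtype=torch.bool)

    # Training-style call: pick a random timestep per sample and noise the positions.
    p_0 = torch.randn(N, L, 3)
    t = torch.randint(1, num_steps + 1, size=(N,))
    p_noisy, e_rand = trans.add_noise(p_0, mask_generate, t)

    # Sampling-style loop: start from Gaussian noise and denoise step by step.
    p_t = torch.randn(N, L, 3)
    for step in range(num_steps, 0, -1):
        t_step = torch.full((N,), step, dtype=torch.long)
        eps_pred = torch.zeros_like(p_t)   # placeholder for the network prediction
        p_t = trans.denoise(p_t, eps_pred, mask_generate, t_step)
    return p_noisy, e_rand, p_t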


class RotationTransition(nn.Module):

    def __init__(self, num_steps, var_sched_opt={}, angular_distrib_fwd_opt={}, angular_distrib_inv_opt={}):
        super().__init__()
        self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)

        # Forward (perturb)
        c1 = torch.sqrt(1 - self.var_sched.alpha_bars)  # (T,).
        self.angular_distrib_fwd = ApproxAngularDistribution(c1.tolist(), **angular_distrib_fwd_opt)

        # Inverse (generate)
        sigma = self.var_sched.sigmas
        self.angular_distrib_inv = ApproxAngularDistribution(sigma.tolist(), **angular_distrib_inv_opt)

        self.register_buffer('_dummy', torch.empty([0, ]))

    def add_noise(self, v_0, mask_generate, t):
        """
        Args:
            v_0:    (N, L, 3).
            mask_generate:    (N, L).
            t:  (N,).
        """
        N, L = mask_generate.size()
        alpha_bar = self.var_sched.alpha_bars[t]
        c0 = torch.sqrt(alpha_bar).view(-1, 1, 1)
        c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1)

        # Noise rotation
        e_scaled = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_fwd, device=self._dummy.device)  # (N, L, 3)
        e_normal = e_scaled / (c1 + 1e-8)
        E_scaled = so3vec_to_rotation(e_scaled)  # (N, L, 3, 3)

        # Scaled true rotation
        R0_scaled = so3vec_to_rotation(c0 * v_0)  # (N, L, 3, 3)

        R_noisy = E_scaled @ R0_scaled
        v_noisy = rotation_to_so3vec(R_noisy)
        v_noisy = torch.where(mask_generate[..., None].expand_as(v_0), v_noisy, v_0)

        return v_noisy, e_scaled

    def denoise(self, v_t, v_next, mask_generate, t):
        """
        Args:
            v_t:    (N, L, 3).
            v_next: Predicted orientation for the next (less noisy) step, (N, L, 3).
            mask_generate:    (N, L).
            t:  (N,).
        """
        N, L = mask_generate.size()
        e = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_inv, device=self._dummy.device)  # (N, L, 3)
        e = torch.where(
            (t > 1)[:, None, None].expand(N, L, 3),
            e,
            torch.zeros_like(e),    # Simply denoise and don't add noise at the last step
        )
        E = so3vec_to_rotation(e)

        R_next = E @ so3vec_to_rotation(v_next)
        v_next = rotation_to_so3vec(R_next)
        v_next = torch.where(mask_generate[..., None].expand_as(v_next), v_next, v_t)
        return v_next
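

# Illustrative sketch (not part of the training/sampling code): noising and denoising
# orientations represented as so(3) vectors with RotationTransition. Zero vectors
# (identity rotations) stand in for real residue orientations, and feeding the noisy
# orientation back into denoise() stands in for a network-predicted v_next.
def _demo_rotation_transition(num_steps=100, N=2, L=16):
    trans = RotationTransition(num_steps)
    mask_generate = torch.ones(N, L, dtype=torch.bool)

    v_0 = torch.zeros(N, L, 3)                 # identity orientations in so(3)-vector form
    t = torch.randint(1, num_steps + 1, size=(N,))
    v_noisy, e_scaled = trans.add_noise(v_0, mask_generate, t)

    # One reverse step: perturb the (here, placeholder) predicted orientation with
    # inverse-process angular noise, except at t = 1.
    v_prev = trans.denoise(v_noisy, v_noisy, mask_generate, t)
    return v_noisy, v_prev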


class AminoacidCategoricalTransition(nn.Module):

    def __init__(self, num_steps, num_classes=20, var_sched_opt={}):
        super().__init__()
        self.num_classes = num_classes
        self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)

    def _sample(self, c):
        """
        Args:
            c:  (N, L, K).
        Returns:
            x:  (N, L).
        """
        N, L, K = c.size()
        c = c.view(N*L, K) + 1e-8
        x = torch.multinomial(c, 1).view(N, L)
        return x

    def add_noise(self, x_0, mask_generate, t):
        """
        Args:
            x_0:    (N, L)
            mask_generate:    (N, L).
            t:  (N,).
        Returns:
            c_t:    Probability, (N, L, K).
            x_t:    Sample, LongTensor, (N, L).
        """
        N, L = x_0.size()
        K = self.num_classes
        c_0 = clampped_one_hot(x_0, num_classes=K).float()  # (N, L, K).
        alpha_bar = self.var_sched.alpha_bars[t][:, None, None]  # (N, 1, 1)
        c_noisy = (alpha_bar*c_0) + ( (1-alpha_bar)/K )
        c_t = torch.where(mask_generate[..., None].expand(N, L, K), c_noisy, c_0)
        x_t = self._sample(c_t)
        return c_t, x_t

    def posterior(self, x_t, x_0, t):
        """
        Args:
            x_t:    Category LongTensor (N, L) or Probability FloatTensor (N, L, K).
            x_0:    Category LongTensor (N, L) or Probability FloatTensor (N, L, K).
            t:  (N,).
        Returns:
            theta:  Posterior probability at (t-1)-th step, (N, L, K).
        """
        K = self.num_classes

        if x_t.dim() == 3:
            c_t = x_t   # When x_t is probability distribution.
        else:
            c_t = clampped_one_hot(x_t, num_classes=K).float()  # (N, L, K)

        if x_0.dim() == 3:
            c_0 = x_0   # When x_0 is probability distribution.
        else:
            c_0 = clampped_one_hot(x_0, num_classes=K).float()  # (N, L, K)

        # Posterior of the uniform categorical diffusion step:
        #   q(x_{t-1} | x_t, x_0) ∝ [alpha_t * c_t + (1 - alpha_t)/K] * [alpha_bar_{t-1} * c_0 + (1 - alpha_bar_{t-1})/K]
        alpha = self.var_sched.alphas[t][:, None, None]            # (N, 1, 1)
        alpha_bar = self.var_sched.alpha_bars[t-1][:, None, None]  # (N, 1, 1)

        theta = ((alpha*c_t) + (1-alpha)/K) * ((alpha_bar*c_0) + (1-alpha_bar)/K)  # (N, L, K)
        theta = theta / (theta.sum(dim=-1, keepdim=True) + 1e-8)
        return theta

    def denoise(self, x_t, c_0_pred, mask_generate, t):
        """
        Args:
            x_t:    (N, L).
            c_0_pred:   Normalized probability predicted by networks, (N, L, K).
            mask_generate:    (N, L).
            t:  (N,).
        Returns:
            post:   Posterior probability at (t-1)-th step, (N, L, K).
            x_next: Sample at (t-1)-th step, LongTensor, (N, L).
        """
        c_t = clampped_one_hot(x_t, num_classes=self.num_classes).float()  # (N, L, K)
        post = self.posterior(c_t, c_0_pred, t=t)   # (N, L, K)
        post = torch.where(mask_generate[..., None].expand(post.size()), post, c_t)
        x_next = self._sample(post)
        return post, x_next
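

# Illustrative sketch (not part of the training/sampling code): the categorical
# diffusion over amino-acid types. A random sequence stands in for real residue
# types and a uniform distribution stands in for the network's predicted c_0.
def _demo_aminoacid_transition(num_steps=100, N=2, L=16):
    K = 20
    trans = AminoacidCategoricalTransition(num_steps, num_classes=K)
    mask_generate = torch.ones(N, L, dtype=torch.bool)

    x_0 = torch.randint(0, K, size=(N, L))     # placeholder sequence of type indices
    t = torch.randint(1, num_steps + 1, size=(N,))
    c_t, x_t = trans.add_noise(x_0, mask_generate, t)

    # One reverse step: combine x_t with a predicted c_0 distribution through the
    # posterior and sample the (t-1)-step types.
    c_0_pred = torch.full((N, L, K), 1.0 / K)  # placeholder for the network prediction
    post, x_next = trans.denoise(x_t, c_0_pred, mask_generate, t)
    return post, x_next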