import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
from tqdm.auto import tqdm
from diffab.modules.common.geometry import apply_rotation_to_vector, quaternion_1ijk_to_rotation_matrix
from diffab.modules.common.so3 import so3vec_to_rotation, rotation_to_so3vec, random_uniform_so3
from diffab.modules.encoders.ga import GAEncoder
from .transition import RotationTransition, PositionTransition, AminoacidCategoricalTransition
def rotation_matrix_cosine_loss(R_pred, R_true):
"""
Args:
R_pred: (*, 3, 3).
R_true: (*, 3, 3).
Returns:
Per-matrix losses, (*, ).
"""
size = list(R_pred.shape[:-2])
ncol = R_pred.numel() // 3
RT_pred = R_pred.transpose(-2, -1).reshape(ncol, 3) # (ncol, 3)
RT_true = R_true.transpose(-2, -1).reshape(ncol, 3) # (ncol, 3)
ones = torch.ones([ncol, ], dtype=torch.long, device=R_pred.device)
    loss = F.cosine_embedding_loss(RT_pred, RT_true, ones, reduction='none')  # (ncol, )
loss = loss.reshape(size + [3]).sum(dim=-1) # (*, )
return loss
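# Sanity-check sketch (illustrative, not from the original source): identical
# rotations incur zero loss.
#   R = torch.eye(3).expand(4, 3, 3)
#   rotation_matrix_cosine_loss(R, R)   # -> tensor([0., 0., 0., 0.])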
class EpsilonNet(nn.Module):
def __init__(self, res_feat_dim, pair_feat_dim, num_layers, encoder_opt={}):
super().__init__()
self.current_sequence_embedding = nn.Embedding(25, res_feat_dim) # 22 is padding
self.res_feat_mixer = nn.Sequential(
nn.Linear(res_feat_dim * 2, res_feat_dim), nn.ReLU(),
nn.Linear(res_feat_dim, res_feat_dim),
)
self.encoder = GAEncoder(res_feat_dim, pair_feat_dim, num_layers, **encoder_opt)
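        # Three prediction heads, each fed the encoded residue features plus a
        # timestep embedding: eps_crd_net predicts position noise in the local
        # residue frame, eps_rot_net predicts a rotation update as a (1, i, j, k)
        # quaternion, and eps_seq_net predicts per-residue amino acid
        # distributions (softmax over the 20 types).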
self.eps_crd_net = nn.Sequential(
nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(),
nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(),
nn.Linear(res_feat_dim, 3)
)
self.eps_rot_net = nn.Sequential(
nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(),
nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(),
nn.Linear(res_feat_dim, 3)
)
self.eps_seq_net = nn.Sequential(
nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(),
nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(),
nn.Linear(res_feat_dim, 20), nn.Softmax(dim=-1)
)
def forward(self, v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res):
"""
Args:
v_t: (N, L, 3).
p_t: (N, L, 3).
s_t: (N, L).
res_feat: (N, L, res_dim).
pair_feat: (N, L, L, pair_dim).
beta: (N,).
mask_generate: (N, L).
mask_res: (N, L).
        Returns:
            v_next: UPDATED (not epsilon) SO3-vector of orientations, (N, L, 3).
            R_next: Rotation matrices corresponding to v_next, (N, L, 3, 3).
            eps_pos: Predicted position noise, (N, L, 3).
            c_denoised: Predicted amino acid distributions (already softmax-ed), (N, L, 20).
"""
N, L = mask_res.size()
R = so3vec_to_rotation(v_t) # (N, L, 3, 3)
# s_t = s_t.clamp(min=0, max=19) # TODO: clamping is good but ugly.
res_feat = self.res_feat_mixer(torch.cat([res_feat, self.current_sequence_embedding(s_t)], dim=-1)) # [Important] Incorporate sequence at the current step.
res_feat = self.encoder(R, p_t, res_feat, pair_feat, mask_res)
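        # Timestep embedding (beta, sin(beta), cos(beta)), broadcast to every residue.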
t_embed = torch.stack([beta, torch.sin(beta), torch.cos(beta)], dim=-1)[:, None, :].expand(N, L, 3)
in_feat = torch.cat([res_feat, t_embed], dim=-1)
# Position changes
eps_crd = self.eps_crd_net(in_feat) # (N, L, 3)
eps_pos = apply_rotation_to_vector(R, eps_crd) # (N, L, 3)
eps_pos = torch.where(mask_generate[:, :, None].expand_as(eps_pos), eps_pos, torch.zeros_like(eps_pos))
# New orientation
eps_rot = self.eps_rot_net(in_feat) # (N, L, 3)
U = quaternion_1ijk_to_rotation_matrix(eps_rot) # (N, L, 3, 3)
R_next = R @ U
v_next = rotation_to_so3vec(R_next) # (N, L, 3)
v_next = torch.where(mask_generate[:, :, None].expand_as(v_next), v_next, v_t)
# New sequence categorical distributions
c_denoised = self.eps_seq_net(in_feat) # Already softmax-ed, (N, L, 20)
return v_next, R_next, eps_pos, c_denoised
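# Shape sketch (illustrative dims, not from the original source): with N=2 and
# L=16, EpsilonNet.forward maps v_t (2, 16, 3), p_t (2, 16, 3) and s_t (2, 16)
# to v_next (2, 16, 3), R_next (2, 16, 3, 3), eps_pos (2, 16, 3) and
# c_denoised (2, 16, 20).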
class FullDPM(nn.Module):
def __init__(
self,
res_feat_dim,
pair_feat_dim,
num_steps,
eps_net_opt={},
trans_rot_opt={},
trans_pos_opt={},
trans_seq_opt={},
position_mean=[0.0, 0.0, 0.0],
position_scale=[10.0],
):
super().__init__()
self.eps_net = EpsilonNet(res_feat_dim, pair_feat_dim, **eps_net_opt)
self.num_steps = num_steps
self.trans_rot = RotationTransition(num_steps, **trans_rot_opt)
self.trans_pos = PositionTransition(num_steps, **trans_pos_opt)
self.trans_seq = AminoacidCategoricalTransition(num_steps, **trans_seq_opt)
self.register_buffer('position_mean', torch.FloatTensor(position_mean).view(1, 1, -1))
self.register_buffer('position_scale', torch.FloatTensor(position_scale).view(1, 1, -1))
self.register_buffer('_dummy', torch.empty([0, ]))
def _normalize_position(self, p):
p_norm = (p - self.position_mean) / self.position_scale
return p_norm
def _unnormalize_position(self, p_norm):
p = p_norm * self.position_scale + self.position_mean
return p
def forward(self, v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, denoise_structure, denoise_sequence, t=None):
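        """
        Computes the diffusion training losses at timestep t
        (sampled uniformly at random when t is None).
        Returns:
            A dict with scalar loss terms: 'rot', 'pos' and 'seq'.
        """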
N, L = res_feat.shape[:2]
        if t is None:
t = torch.randint(0, self.num_steps, (N,), dtype=torch.long, device=self._dummy.device)
p_0 = self._normalize_position(p_0)
if denoise_structure:
# Add noise to rotation
R_0 = so3vec_to_rotation(v_0)
v_noisy, _ = self.trans_rot.add_noise(v_0, mask_generate, t)
# Add noise to positions
p_noisy, eps_p = self.trans_pos.add_noise(p_0, mask_generate, t)
else:
R_0 = so3vec_to_rotation(v_0)
v_noisy = v_0.clone()
p_noisy = p_0.clone()
eps_p = torch.zeros_like(p_noisy)
if denoise_sequence:
# Add noise to sequence
_, s_noisy = self.trans_seq.add_noise(s_0, mask_generate, t)
else:
s_noisy = s_0.clone()
beta = self.trans_pos.var_sched.betas[t]
v_pred, R_pred, eps_p_pred, c_denoised = self.eps_net(
v_noisy, p_noisy, s_noisy, res_feat, pair_feat, beta, mask_generate, mask_res
        )   # (N, L, 3), (N, L, 3, 3), (N, L, 3), (N, L, 20)
loss_dict = {}
# Rotation loss
loss_rot = rotation_matrix_cosine_loss(R_pred, R_0) # (N, L)
loss_rot = (loss_rot * mask_generate).sum() / (mask_generate.sum().float() + 1e-8)
loss_dict['rot'] = loss_rot
# Position loss
loss_pos = F.mse_loss(eps_p_pred, eps_p, reduction='none').sum(dim=-1) # (N, L)
loss_pos = (loss_pos * mask_generate).sum() / (mask_generate.sum().float() + 1e-8)
loss_dict['pos'] = loss_pos
# Sequence categorical loss
post_true = self.trans_seq.posterior(s_noisy, s_0, t)
log_post_pred = torch.log(self.trans_seq.posterior(s_noisy, c_denoised, t) + 1e-8)
kldiv = F.kl_div(
input=log_post_pred,
target=post_true,
reduction='none',
log_target=False
).sum(dim=-1) # (N, L)
loss_seq = (kldiv * mask_generate).sum() / (mask_generate.sum().float() + 1e-8)
loss_dict['seq'] = loss_seq
return loss_dict
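    # Training sketch (assumed usage, not part of the original module): the
    # three terms are typically combined into a single objective, e.g.
    #   loss = loss_dict['rot'] + loss_dict['pos'] + loss_dict['seq']
    # (the actual weighting is up to the training configuration).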
@torch.no_grad()
def sample(
self,
v, p, s,
res_feat, pair_feat,
mask_generate, mask_res,
sample_structure=True, sample_sequence=True,
pbar=False,
):
"""
Args:
v: Orientations of contextual residues, (N, L, 3).
p: Positions of contextual residues, (N, L, 3).
            s: Sequence of contextual residues, (N, L).
        Returns:
            A trajectory dict mapping each timestep t to a tuple (v, p, s),
            where entry 0 is the final denoised sample.
        """
N, L = v.shape[:2]
p = self._normalize_position(p)
# Set the orientation and position of residues to be predicted to random values
if sample_structure:
v_rand = random_uniform_so3([N, L], device=self._dummy.device)
p_rand = torch.randn_like(p)
v_init = torch.where(mask_generate[:, :, None].expand_as(v), v_rand, v)
p_init = torch.where(mask_generate[:, :, None].expand_as(p), p_rand, p)
else:
v_init, p_init = v, p
if sample_sequence:
            s_rand = torch.randint_like(s, low=0, high=20)  # high is exclusive; 0..19 covers all 20 amino acid types.
s_init = torch.where(mask_generate, s_rand, s)
else:
s_init = s
traj = {self.num_steps: (v_init, self._unnormalize_position(p_init), s_init)}
if pbar:
pbar = functools.partial(tqdm, total=self.num_steps, desc='Sampling')
else:
pbar = lambda x: x
for t in pbar(range(self.num_steps, 0, -1)):
v_t, p_t, s_t = traj[t]
p_t = self._normalize_position(p_t)
beta = self.trans_pos.var_sched.betas[t].expand([N, ])
t_tensor = torch.full([N, ], fill_value=t, dtype=torch.long, device=self._dummy.device)
v_next, R_next, eps_p, c_denoised = self.eps_net(
v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res
            )   # (N, L, 3), (N, L, 3, 3), (N, L, 3), (N, L, 20)
v_next = self.trans_rot.denoise(v_t, v_next, mask_generate, t_tensor)
p_next = self.trans_pos.denoise(p_t, eps_p, mask_generate, t_tensor)
_, s_next = self.trans_seq.denoise(s_t, c_denoised, mask_generate, t_tensor)
if not sample_structure:
v_next, p_next = v_t, p_t
if not sample_sequence:
s_next = s_t
traj[t-1] = (v_next, self._unnormalize_position(p_next), s_next)
            traj[t] = tuple(x.cpu() for x in traj[t])   # Move previous states to CPU memory.
return traj
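    # Usage note (sketch): the returned trajectory is indexed by timestep, with
    # the fully denoised sample at key 0, e.g. v0, p0, s0 = traj[0].
    # Intermediate entries (keys > 0) have already been moved to CPU memory.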
@torch.no_grad()
def optimize(
self,
v, p, s,
opt_step: int,
res_feat, pair_feat,
mask_generate, mask_res,
sample_structure=True, sample_sequence=True,
pbar=False,
):
"""
        Description:
            First adds noise to the given structure and sequence up to step
            `opt_step`, then denoises them back to step 0.
        """
N, L = v.shape[:2]
p = self._normalize_position(p)
t = torch.full([N, ], fill_value=opt_step, dtype=torch.long, device=self._dummy.device)
# Set the orientation and position of residues to be predicted to random values
if sample_structure:
# Add noise to rotation
v_noisy, _ = self.trans_rot.add_noise(v, mask_generate, t)
# Add noise to positions
p_noisy, _ = self.trans_pos.add_noise(p, mask_generate, t)
v_init = torch.where(mask_generate[:, :, None].expand_as(v), v_noisy, v)
p_init = torch.where(mask_generate[:, :, None].expand_as(p), p_noisy, p)
else:
v_init, p_init = v, p
if sample_sequence:
_, s_noisy = self.trans_seq.add_noise(s, mask_generate, t)
s_init = torch.where(mask_generate, s_noisy, s)
else:
s_init = s
traj = {opt_step: (v_init, self._unnormalize_position(p_init), s_init)}
if pbar:
pbar = functools.partial(tqdm, total=opt_step, desc='Optimizing')
else:
pbar = lambda x: x
for t in pbar(range(opt_step, 0, -1)):
v_t, p_t, s_t = traj[t]
p_t = self._normalize_position(p_t)
beta = self.trans_pos.var_sched.betas[t].expand([N, ])
t_tensor = torch.full([N, ], fill_value=t, dtype=torch.long, device=self._dummy.device)
v_next, R_next, eps_p, c_denoised = self.eps_net(
v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res
            )   # (N, L, 3), (N, L, 3, 3), (N, L, 3), (N, L, 20)
v_next = self.trans_rot.denoise(v_t, v_next, mask_generate, t_tensor)
p_next = self.trans_pos.denoise(p_t, eps_p, mask_generate, t_tensor)
_, s_next = self.trans_seq.denoise(s_t, c_denoised, mask_generate, t_tensor)
if not sample_structure:
v_next, p_next = v_t, p_t
if not sample_sequence:
s_next = s_t
traj[t-1] = (v_next, self._unnormalize_position(p_next), s_next)
            traj[t] = tuple(x.cpu() for x in traj[t])   # Move previous states to CPU memory.
return traj
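

# Minimal usage sketch (not part of the original module). All dimensions and
# option values below are illustrative assumptions, not the project's configs.
if __name__ == '__main__':
    N, L, res_dim, pair_dim = 2, 16, 128, 64
    model = FullDPM(
        res_feat_dim=res_dim,
        pair_feat_dim=pair_dim,
        num_steps=100,
        eps_net_opt={'num_layers': 6},
    )
    v = random_uniform_so3([N, L], device='cpu')      # Random initial orientations.
    p = torch.randn(N, L, 3)                          # CA positions.
    s = torch.randint(0, 20, (N, L))                  # Amino acid types.
    res_feat = torch.randn(N, L, res_dim)
    pair_feat = torch.randn(N, L, L, pair_dim)
    mask_generate = torch.zeros(N, L, dtype=torch.bool)
    mask_generate[:, L // 2:] = True                  # Regenerate the second half.
    mask_res = torch.ones(N, L, dtype=torch.bool)

    # Training losses at a random timestep.
    loss_dict = model(v, p, s, res_feat, pair_feat, mask_generate, mask_res,
                      denoise_structure=True, denoise_sequence=True)

    # Reverse diffusion; traj[0] holds the final sampled structure and sequence.
    traj = model.sample(v, p, s, res_feat, pair_feat, mask_generate, mask_res)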