DiffAb / diffab /modules /common /structure.py
luost26's picture
Update
753e275
raw
history blame
2.87 kB
import torch
from torch.nn import Module, Linear, LayerNorm, Sequential, ReLU
from ..common.geometry import compose_rotation_and_translation, quaternion_to_rotation_matrix, repr_6d_to_rotation_matrix
class FrameRotationTranslationPrediction(Module):
    """Predict a rotation and a translation update from node features.

    A small network maps the last feature dimension to ``d + 3`` values,
    where ``d`` is the number of rotation parameters for the chosen
    representation and the trailing 3 values are the translation.

    Args:
        feat_dim: Size of the input feature dimension.
        rot_repr: Rotation parameterisation, ``'quaternion'`` (3 predicted
            values) or ``'6d'`` (6 predicted values).
        nn_type: ``'linear'`` for a single linear layer, ``'mlp'`` for a
            three-layer ReLU MLP whose hidden width equals ``feat_dim``.

    Raises:
        ValueError: If ``rot_repr`` or ``nn_type`` is not a supported value.
    """

    # Number of network outputs consumed by each rotation representation.
    # Three further outputs are always reserved for the translation.
    _ROT_PARAM_DIMS = {'quaternion': 3, '6d': 6}

    def __init__(self, feat_dim, rot_repr, nn_type='mlp'):
        super().__init__()
        # Raise instead of ``assert`` so the check survives ``python -O`` and
        # matches the ValueError used for ``nn_type`` below.
        if rot_repr not in self._ROT_PARAM_DIMS:
            raise ValueError('Unknown rot_repr: %s' % rot_repr)
        self.rot_repr = rot_repr
        out_dim = self._ROT_PARAM_DIMS[rot_repr] + 3

        if nn_type == 'linear':
            self.nn = Linear(feat_dim, out_dim)
        elif nn_type == 'mlp':
            self.nn = Sequential(
                Linear(feat_dim, feat_dim), ReLU(),
                Linear(feat_dim, feat_dim), ReLU(),
                Linear(feat_dim, out_dim)
            )
        else:
            raise ValueError('Unknown nn_type: %s' % nn_type)

    def forward(self, x):
        """Compute the rotation and translation deltas for ``x``.

        Args:
            x: Feature tensor whose last dimension is ``feat_dim``.

        Returns:
            Tuple ``(R_delta, t_delta)``: the output of the rotation
            conversion helper (the caller annotates it as ``(..., 3, 3)``)
            and the last 3 network outputs, ``(..., 3)``.

        Raises:
            ValueError: If ``self.rot_repr`` was changed to an unsupported
                value after construction.
        """
        y = self.nn(x)  # (..., d+3)
        if self.rot_repr == 'quaternion':
            # The first quaternion component is fixed to 1 and only the other
            # three are predicted.
            # NOTE(review): presumably (w, x, y, z) ordering, normalised
            # inside quaternion_to_rotation_matrix -- confirm in geometry.py.
            quaternion = torch.cat([torch.ones_like(y[..., :1]), y[..., 0:3]], dim=-1)
            R_delta = quaternion_to_rotation_matrix(quaternion)
            t_delta = y[..., 3:6]
        elif self.rot_repr == '6d':
            R_delta = repr_6d_to_rotation_matrix(y[..., 0:6])
            t_delta = y[..., 6:9]
        else:
            # Previously this fell through and returned None implicitly.
            raise ValueError('Unknown rot_repr: %s' % self.rot_repr)
        return R_delta, t_delta
class FrameUpdate(Module):
    """Refine node features, then update each residue frame where masked.

    Node features go through a residual MLP with layer normalisation, are
    turned into a rotation/translation delta, and that delta is composed with
    the current frame. Frames outside ``mask_generate`` are returned unchanged.

    Args:
        node_feat_dim: Size of the node feature dimension.
        rot_repr: Rotation parameterisation forwarded to
            ``FrameRotationTranslationPrediction``.
        rot_tran_nn_type: Network type forwarded to
            ``FrameRotationTranslationPrediction`` as ``nn_type``.
    """

    def __init__(self, node_feat_dim, rot_repr='quaternion', rot_tran_nn_type='mlp'):
        super().__init__()
        dim = node_feat_dim
        # Submodules are created in a fixed order so that parameter ordering
        # and seeded initialisation stay stable.
        self.transition_mlp = Sequential(
            Linear(dim, dim),
            ReLU(),
            Linear(dim, dim),
            ReLU(),
            Linear(dim, dim),
        )
        self.transition_layer_norm = LayerNorm(dim)
        self.rot_tran = FrameRotationTranslationPrediction(
            dim,
            rot_repr,
            nn_type=rot_tran_nn_type,
        )

    def forward(self, R, t, x, mask_generate):
        """
        Args:
            R:  Frame basis matrices, (N, L, 3, 3_index).
            t:  Frame external (absolute) coordinates, (N, L, 3). Unit: Angstrom.
            x:  Node-wise features, (N, L, F).
            mask_generate:  Masks, (N, L). Used as the ``torch.where``
                condition, so it is expected to be a boolean tensor.
        Returns:
            R': Updated basis matrices, (N, L, 3, 3_index).
            t': Updated coordinates, (N, L, 3).
        """
        # Residual transition followed by layer normalisation.
        residual = self.transition_mlp(x)
        feats = self.transition_layer_norm(x + residual)

        # Per-node frame delta: (N, L, 3, 3) and (N, L, 3).
        delta_rot, delta_trans = self.rot_tran(feats)
        composed_R, composed_t = compose_rotation_and_translation(
            R, t, delta_rot, delta_trans
        )

        # Keep the composed frame only where the mask is set.
        select_R = mask_generate[:, :, None, None].expand_as(R)
        select_t = mask_generate[:, :, None].expand_as(t)
        updated_R = torch.where(select_R, composed_R, R)
        updated_t = torch.where(select_t, composed_t, t)
        return updated_R, updated_t