Spaces:
Runtime error
Runtime error
import torch | |
from torch.nn import Module, Linear, LayerNorm, Sequential, ReLU | |
from ..common.geometry import compose_rotation_and_translation, quaternion_to_rotation_matrix, repr_6d_to_rotation_matrix | |
class FrameRotationTranslationPrediction(Module):
    """Map node features to a rotation matrix and a translation vector.

    The wrapped network emits ``d + 3`` channels per node: the first ``d``
    parameterise the rotation (``d = 3`` for ``'quaternion'``, ``d = 6`` for
    ``'6d'``) and the last 3 are returned unchanged as the translation.
    """

    def __init__(self, feat_dim, rot_repr, nn_type='mlp'):
        """
        Args:
            feat_dim: Size of the last dimension of the input features.
            rot_repr: Rotation parameterisation, ``'quaternion'`` or ``'6d'``.
            nn_type: ``'linear'`` for a single linear layer, or ``'mlp'`` for
                three linear layers with ReLU in between.
        Raises:
            AssertionError: If ``rot_repr`` is not one of the supported values.
            ValueError: If ``nn_type`` is not one of the supported values.
        """
        super().__init__()
        assert rot_repr in ('quaternion', '6d')
        self.rot_repr = rot_repr

        # Number of rotation channels; 3 translation channels are appended.
        if rot_repr == '6d':
            num_rot_params = 6
        elif rot_repr == 'quaternion':
            num_rot_params = 3
        out_dim = num_rot_params + 3

        if nn_type == 'mlp':
            # Layers are created in forward order so parameter initialisation
            # consumes the RNG in the same sequence as the layers are applied.
            layers = [
                Linear(feat_dim, feat_dim), ReLU(),
                Linear(feat_dim, feat_dim), ReLU(),
                Linear(feat_dim, out_dim),
            ]
            self.nn = Sequential(*layers)
        elif nn_type == 'linear':
            self.nn = Linear(feat_dim, out_dim)
        else:
            raise ValueError('Unknown nn_type: %s' % nn_type)

    def forward(self, x):
        """
        Args:
            x: Features whose last dimension is ``feat_dim``, (..., F).
        Returns:
            Tuple ``(R_delta, t_delta)``: the output of the rotation helper
            for the chosen representation, and the (..., 3) translation slice.
        """
        out = self.nn(x)    # (..., d+3)

        if self.rot_repr == '6d':
            rot_params = out[..., 0:6]
            t_delta = out[..., 6:9]
            return repr_6d_to_rotation_matrix(rot_params), t_delta

        if self.rot_repr == 'quaternion':
            rot_params = out[..., 0:3]
            t_delta = out[..., 3:6]
            # A constant 1 is prepended to the three predicted components
            # (presumably the scalar part of the quaternion -- confirm against
            # quaternion_to_rotation_matrix).
            leading_one = torch.ones_like(out[..., :1])
            quaternion = torch.cat([leading_one, rot_params], dim=-1)
            return quaternion_to_rotation_matrix(quaternion), t_delta
class FrameUpdate(Module):
    """Refine node features and use them to update per-node frames.

    Features pass through a residual MLP followed by layer normalisation; the
    result drives a ``FrameRotationTranslationPrediction`` head whose output
    is composed with the incoming frames. Only positions selected by
    ``mask_generate`` receive the updated frame.
    """

    def __init__(self, node_feat_dim, rot_repr='quaternion', rot_tran_nn_type='mlp'):
        """
        Args:
            node_feat_dim: Size of the node feature vectors.
            rot_repr: Rotation parameterisation forwarded to the prediction head.
            rot_tran_nn_type: Network type forwarded to the prediction head.
        """
        super().__init__()
        width = node_feat_dim
        self.transition_mlp = Sequential(
            Linear(width, width), ReLU(),
            Linear(width, width), ReLU(),
            Linear(width, width),
        )
        self.transition_layer_norm = LayerNorm(width)
        self.rot_tran = FrameRotationTranslationPrediction(
            width, rot_repr, nn_type=rot_tran_nn_type,
        )

    def forward(self, R, t, x, mask_generate):
        """
        Args:
            R:  Frame basis matrices, (N, L, 3, 3_index).
            t:  Frame external (absolute) coordinates, (N, L, 3). Unit: Angstrom.
            x:  Node-wise features, (N, L, F).
            mask_generate:  Masks, (N, L). Used directly as the condition of
                ``torch.where``, so presumably boolean -- confirm with callers.
        Returns:
            R': Updated basis matrices, (N, L, 3, 3_index).
            t': Updated coordinates, (N, L, 3).
        """
        # Residual transition block followed by layer normalisation.
        residual = x + self.transition_mlp(x)
        hidden = self.transition_layer_norm(residual)

        # Predicted per-node update, composed with the incoming frames.
        R_delta, t_delta = self.rot_tran(hidden)    # (N, L, 3, 3), (N, L, 3)
        R_candidate, t_candidate = compose_rotation_and_translation(R, t, R_delta, t_delta)

        # Expand the (N, L) mask to the shapes of R and t.
        select_R = mask_generate[:, :, None, None].expand_as(R)
        select_t = mask_generate[:, :, None].expand_as(t)

        # Where the mask is set take the candidate; elsewhere keep the input.
        R_out = torch.where(select_R, R_candidate, R)
        t_out = torch.where(select_t, t_candidate, t)
        return R_out, t_out