File size: 2,869 Bytes
753e275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from torch.nn import Module, Linear, LayerNorm, Sequential, ReLU

from ..common.geometry import compose_rotation_and_translation, quaternion_to_rotation_matrix, repr_6d_to_rotation_matrix


class FrameRotationTranslationPrediction(Module):
    """Predict a rotation matrix and a translation vector from node features.

    A small network (one linear layer, or a three-layer ReLU MLP) maps each
    feature vector to ``d + 3`` numbers: ``d`` rotation parameters followed by
    3 translation components. ``d`` is 3 for ``'quaternion'`` and 6 for
    ``'6d'``.

    Args:
        feat_dim: Size of the last dimension of the input features.
        rot_repr: Rotation parameterisation, ``'quaternion'`` or ``'6d'``.
        nn_type: Network type, ``'mlp'`` (default) or ``'linear'``.

    Raises:
        ValueError: If ``rot_repr`` or ``nn_type`` is not a supported value.
    """

    def __init__(self, feat_dim, rot_repr, nn_type='mlp'):
        super().__init__()
        # Validate with an explicit raise rather than `assert`: assertions are
        # removed under `python -O`, which would leave `out_dim` unbound below
        # and let `forward` silently return None for an unsupported value.
        if rot_repr == 'quaternion':
            rot_dim = 3  # the fourth quaternion component is fixed in forward()
        elif rot_repr == '6d':
            rot_dim = 6
        else:
            raise ValueError('Unknown rot_repr: %s' % rot_repr)
        self.rot_repr = rot_repr
        out_dim = rot_dim + 3  # rotation parameters + translation

        if nn_type == 'linear':
            self.nn = Linear(feat_dim, out_dim)
        elif nn_type == 'mlp':
            self.nn = Sequential(
                Linear(feat_dim, feat_dim), ReLU(),
                Linear(feat_dim, feat_dim), ReLU(),
                Linear(feat_dim, out_dim)
            )
        else:
            raise ValueError('Unknown nn_type: %s' % nn_type)

    def forward(self, x):
        """Compute the rotation and translation predicted from ``x``.

        Args:
            x: Features whose last dimension is ``feat_dim``, (..., F).
        Returns:
            R_delta: Result of the geometry helper for the chosen
                representation; used by callers as (..., 3, 3) matrices.
            t_delta: Translation components, (..., 3).
        Raises:
            ValueError: If ``self.rot_repr`` was changed to an unsupported
                value after construction.
        """
        y = self.nn(x)  # (..., d+3)
        if self.rot_repr == 'quaternion':
            # Only three components are predicted; a constant 1 is prepended
            # to form the 4-component quaternion passed to the helper.
            quaternion = torch.cat([torch.ones_like(y[..., :1]), y[..., 0:3]], dim=-1)
            R_delta = quaternion_to_rotation_matrix(quaternion)
            t_delta = y[..., 3:6]
        elif self.rot_repr == '6d':
            R_delta = repr_6d_to_rotation_matrix(y[..., 0:6])
            t_delta = y[..., 6:9]
        else:
            # Fail loudly instead of implicitly returning None.
            raise ValueError('Unknown rot_repr: %s' % self.rot_repr)
        return R_delta, t_delta


class FrameUpdate(Module):
    """Refine node features, then apply a predicted update to selected frames.

    The features pass through a residual MLP followed by layer normalisation.
    A rotation/translation head turns them into per-node deltas, which are
    composed with the incoming frames. Positions where ``mask_generate`` is
    False keep their original frame.
    """

    def __init__(self, node_feat_dim, rot_repr='quaternion', rot_tran_nn_type='mlp'):
        super().__init__()
        dim = node_feat_dim

        # Residual transition applied to the node features before prediction.
        self.transition_mlp = Sequential(
            Linear(dim, dim),
            ReLU(),
            Linear(dim, dim),
            ReLU(),
            Linear(dim, dim),
        )
        self.transition_layer_norm = LayerNorm(dim)

        # Head that converts features into a rotation and a translation delta.
        self.rot_tran = FrameRotationTranslationPrediction(
            dim,
            rot_repr,
            nn_type=rot_tran_nn_type,
        )

    def forward(self, R, t, x, mask_generate):
        """
        Args:
            R:  Frame basis matrices, (N, L, 3, 3_index).
            t:  Frame external (absolute) coordinates, (N, L, 3). Unit: Angstrom.
            x:  Node-wise features, (N, L, F).
            mask_generate:   Masks, (N, L); True marks frames to update.
        Returns:
            R': Updated basis matrices, (N, L, 3, 3_index).
            t': Updated coordinates, (N, L, 3).
        """
        # Residual connection, then normalisation.
        residual = self.transition_mlp(x)
        feats = self.transition_layer_norm(x + residual)

        # Per-node deltas: (N, L, 3, 3) and (N, L, 3).
        delta_rot, delta_trans = self.rot_tran(feats)
        cand_R, cand_t = compose_rotation_and_translation(R, t, delta_rot, delta_trans)

        # Expand the (N, L) mask over the trailing frame dimensions.
        select_R = mask_generate[:, :, None, None].expand_as(R)
        select_t = mask_generate[:, :, None].expand_as(t)

        # Take the candidate where selected; otherwise keep the input frame.
        return torch.where(select_R, cand_R, R), torch.where(select_t, cand_t, t)