File size: 20,368 Bytes
2fd6166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
"""
two separate diffusion models for human+object, with cross attention to communicate between them
"""
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple
import math

from model.pvcnn.modules import Attention, PVConv, BallQueryHO
from model.pvcnn.pvcnn_utils import create_mlp_components, create_pointnet2_sa_components, create_pointnet2_fp_modules
from model.pvcnn.pvcnn_utils import get_timestep_embedding
import torch.nn.functional as F
from .pos_enc import get_embedder


def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.

    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.

    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
            and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
            and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
            shape :math:`(Nt, Ns)`.

        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
            have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    if attn_mask is not None:
        attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
    else:
        attn = torch.bmm(q, k.transpose(-2, -1))

    attn = F.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p) # this is only for training?
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn


class PVCNN2HumObj(nn.Module):
    sa_blocks = [
        # (out_channel, num_blocks, voxel_reso), (num_centers, radius, num_neighbors, out_channels)
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    fp_blocks = [
        # (out, in_channels), (out_channels, num_blocks, voxel_resolution)
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(
            self,
            num_classes: int,
            embed_dim: int,
            use_att: bool = True,
            dropout: float = 0.1,
            extra_feature_channels: int = 3,
            width_multiplier: int = 1,
            voxel_resolution_multiplier: int = 1,
            attn_type: str = 'simple-cross', #
            attn_weight: float=1.0, # attention feature weight
            multires: int = 10, # positional encoding resolution
            num_neighbours: int = 32 # ball query neighbours
    ):
        super(PVCNN2HumObj, self).__init__()
        assert extra_feature_channels >= 0
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.width_multiplier = width_multiplier
        self.num_neighbours = num_neighbours
        self.in_channels = extra_feature_channels + 3

        self.attn_type = attn_type # how to compute attention
        self.attn_weight = attn_weight

        # separate human/object model
        classifier, embedf, fp_layers, global_att, sa_layers = self.make_modules(dropout, embed_dim,
                                                                                 extra_feature_channels, num_classes,
                                                                                 use_att, voxel_resolution_multiplier,
                                                                                 width_multiplier)

        self.sa_layers_hum = sa_layers
        self.global_att_hum = global_att
        self.fp_layers_hum = fp_layers
        self.classifier_hum = classifier
        self.embedf_hum = embedf
        self.posi_encoder, _ = get_embedder(multires)

        classifier, embedf, fp_layers, global_att, sa_layers = self.make_modules(dropout, embed_dim,
                                                                                 extra_feature_channels, num_classes,
                                                                                 use_att, voxel_resolution_multiplier,
                                                                                 width_multiplier)

        self.sa_layers_obj = sa_layers
        self.global_att_obj = global_att
        self.fp_layers_obj = fp_layers
        self.classifier_obj = classifier
        self.embedf_obj = embedf

        self.make_coord_attn()
        assert self.attn_type == 'coord3d+posenc-learnable', f'unknown attention type {self.attn_type}'

    def make_modules(self, dropout, embed_dim, extra_feature_channels, num_classes, use_att,
                     voxel_resolution_multiplier, width_multiplier):
        """
        make module for human/object
        :param dropout:
        :param embed_dim:
        :param extra_feature_channels:
        :param num_classes:
        :param use_att:
        :param voxel_resolution_multiplier:
        :param width_multiplier:
        :return:
        """
        in_ch_multiplier = 1
        extra_in_channel = 63  # the segmentation+positional feature is projected to dim 63

        # Create PointNet-2 model
        sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
            sa_blocks_config=self.sa_blocks,
            extra_feature_channels=extra_feature_channels,
            with_se=True,
            embed_dim=embed_dim,
            use_att=use_att,
            dropout=dropout,
            width_multiplier=width_multiplier,
            voxel_resolution_multiplier=voxel_resolution_multiplier,
            in_ch_multiplier=in_ch_multiplier,
            extra_in_channel=extra_in_channel
        )
        sa_layers = nn.ModuleList(sa_layers)
        # Additional global attention module, default true
        if self.attn_type == 'coord3d+posenc+rgb':
            # reduce channel number, only for the global attention layer, the decoders remain unchanged
            global_att = None if not use_att else Attention(channels_sa_features//2, 8, D=1)
        else:
            global_att = None if not use_att else Attention(channels_sa_features, 8, D=1)

        # Only use extra features in the last fp module
        sa_in_channels[0] = extra_feature_channels
        fp_layers, channels_fp_features = create_pointnet2_fp_modules(
            fp_blocks=self.fp_blocks,
            in_channels=channels_sa_features,
            sa_in_channels=sa_in_channels,
            with_se=True,
            embed_dim=embed_dim,
            use_att=use_att,
            dropout=dropout,
            width_multiplier=width_multiplier,
            voxel_resolution_multiplier=voxel_resolution_multiplier,
            in_ch_multiplier=in_ch_multiplier,
            extra_in_channel=extra_in_channel
        )
        fp_layers = nn.ModuleList(fp_layers)

        # Create MLP layers for output prediction
        layers, _ = create_mlp_components(
            in_channels=channels_fp_features,
            out_channels=[128, dropout, num_classes],  # was 0.5
            classifier=True,
            dim=2,
            width_multiplier=width_multiplier
        )
        classifier = nn.Sequential(*layers)  # applied to point features directly
        # Time embedding function
        embedf = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(embed_dim, embed_dim),
        )
        return classifier, embedf, fp_layers, global_att, sa_layers

    def make_coord_attn(self):
        "learnable attention only on point coordinate + positional encoding "
        pvconv_encoders = []
        for i, (conv_configs, sa_configs) in enumerate(self.sa_blocks):
            # should use point net out channel
            out_channel = 63
            layer = nn.MultiheadAttention(out_channel, 1, batch_first=True, kdim=out_channel, vdim=out_channel+2)
            pvconv_encoders.append(layer)  # only one block for conv
        pvconv_decoders = []
        for fp_configs, conv_configs in self.fp_blocks:
            out_channel = 63
            layer = nn.MultiheadAttention(out_channel, 1, batch_first=True, kdim=out_channel, vdim=out_channel + 2)
            pvconv_decoders.append(layer)
        self.cross_conv_encoders = nn.ModuleList(pvconv_encoders)
        self.cross_conv_decoders = nn.ModuleList(pvconv_decoders)

    def forward(self, inputs_hum: torch.Tensor, inputs_obj: torch.Tensor, t: torch.Tensor, norm_params=None):
        """

        :param inputs: (B, N, D), N is the number of points, D is the conditional feature dimension
        :param t: (B, ) timestamps
        :param norm_params: (2, B, 4), transformation parameters that move points back to H+O joint space, first 3 values are cent, the last is radius/scale
        :return: (B, N, D_out) x2
        """
        inputs_hum = inputs_hum.transpose(1, 2)
        inputs_obj = inputs_obj.transpose(1, 2)

        # Embed timesteps, sinusoidal encoding
        t_emb_init = get_timestep_embedding(self.embed_dim, t, inputs_hum.device).float()
        t_emb_hum = self.embedf_hum(t_emb_init)[:, :, None].expand(-1, -1, inputs_hum.shape[-1]).float()
        t_emb_obj = self.embedf_obj(t_emb_init)[:, :, None].expand(-1, -1, inputs_obj.shape[-1]).float()

        # Separate input coordinates and features
        coords_hum, coords_obj = inputs_hum[:, :3, :].contiguous(), inputs_obj[:, :3, :].contiguous()  # (B, 3, N) range (-3.5, 3.5)
        features_hum, features_obj = inputs_hum, inputs_obj  # (B, 3 + S, N)

        DEBUG = False

        # Encoder: Downscaling layers
        coords_list_hum, coords_list_obj = [], []
        in_features_list_hum, in_features_list_obj = [], []
        for i, (sa_blocks_h, sa_blocks_o) in enumerate(zip(self.sa_layers_hum, self.sa_layers_obj)):
            in_features_list_hum.append(features_hum)
            coords_list_hum.append(coords_hum)
            in_features_list_obj.append(features_obj)
            coords_list_obj.append(coords_obj)
            if i == 0:
                # First step no timestamp embedding
                features_hum, coords_hum, t_emb_hum = sa_blocks_h((features_hum, coords_hum, t_emb_hum))
                features_obj, coords_obj, t_emb_obj = sa_blocks_o((features_obj, coords_obj, t_emb_obj))
            else:
                features_hum, coords_hum, t_emb_hum = sa_blocks_h((torch.cat([features_hum, t_emb_hum], dim=1), coords_hum, t_emb_hum))
                features_obj, coords_obj, t_emb_obj = sa_blocks_o((torch.cat([features_obj, t_emb_obj], dim=1), coords_obj, t_emb_obj))

            if i < len(self.sa_layers_hum)-1:
                features_hum, features_obj = self.add_attn_feature(features_hum, features_obj,
                                                               self.transform_coords(coords_hum, norm_params, 0),
                                                               self.transform_coords(coords_obj, norm_params, 1),
                                                               self.cross_conv_encoders[i],
                                                                   temb_hum=t_emb_hum,
                                                                   temb_obj=t_emb_obj)

        # for debug: save some point clouds
        if DEBUG:
            for i, (ch, co) in enumerate(zip(coords_list_hum, coords_list_obj)):
                import trimesh
                ch_ho = self.transform_coords(ch, norm_params, 0)
                co_ho = self.transform_coords(co, norm_params, 1)
                points = torch.cat([ch_ho, co_ho], -1).transpose(1, 2)
                L = ch_ho.shape[-1]
                vc = np.concatenate(
                    [np.zeros((L, 3)) + np.array([0.5, 1.0, 0]),
                     np.zeros((L, 3)) + np.array([0.05, 1.0, 1.0])]
                )
                trimesh.PointCloud(points[0].cpu().numpy(), colors=vc).export(
                    f'/BS/xxie-2/work/pc2-diff/experiments/debug/meshes/encoder_step{i:02d}.ply')

        # Replace the input features
        in_features_list_hum[0] = inputs_hum[:, 3:, :].contiguous()
        in_features_list_obj[0] = inputs_obj[:, 3:, :].contiguous()

        # Apply global attention layer
        if self.global_att_hum is not None:
            features_hum = self.global_att_hum(features_hum)

        if self.global_att_obj is not None:
            features_obj = self.global_att_obj(features_obj)
        # Do cross attention after self-attention
        if self.attn_type in ['coord3d+posenc-learnable']:
            features_hum, features_obj = self.add_attn_feature(features_hum, features_obj,
                                                               self.transform_coords(coords_hum, norm_params, 0),
                                                               self.transform_coords(coords_obj, norm_params, 1),
                                                               self.cross_conv_encoders[-1] if self.attn_type in [
                                                                   'coord3d+posenc-learnable'] else None,
                                                               temb_hum=t_emb_hum,
                                                               temb_obj=t_emb_obj)

        # Upscaling layers
        for fp_idx, (fp_blocks_h, fp_blocks_o) in enumerate(zip(self.fp_layers_hum, self.fp_layers_obj)):
            features_hum, coords_hum, t_emb_hum = fp_blocks_h(
                (  # this is a tuple because of nn.Sequential
                    coords_list_hum[-1 - fp_idx],  # reverse coords list from above
                    coords_hum,  # original point coordinates
                    torch.cat([features_hum, t_emb_hum], dim=1),  # keep concatenating upsampled features with timesteps
                    in_features_list_hum[-1 - fp_idx],  # reverse features list from above
                    t_emb_hum  # original timestep embedding
                )
                # this is where point voxel convolution is carried out, the point feature network preserves the order.
            )
            features_obj, coords_obj, t_emb_obj = fp_blocks_o(
                (  # this is a tuple because of nn.Sequential
                    coords_list_obj[-1 - fp_idx],  # reverse coords list from above
                    coords_obj,  # original point coordinates
                    torch.cat([features_obj, t_emb_obj], dim=1),  # keep concatenating upsampled features with timesteps
                    in_features_list_obj[-1 - fp_idx],  # reverse features list from above
                    t_emb_obj  # original timestep embedding
                )
                # this is where point voxel convolution is carried out, the point feature network preserves the order.
            )

            # these features are reused as input for next layer
            # add attention except for the last layer
            if fp_idx < len(self.fp_layers_hum) - 1:
                # Perform cross attention between human and object branches
                features_hum, features_obj = self.add_attn_feature(features_hum, features_obj,
                                                                   self.transform_coords(coords_hum, norm_params, 0),
                                                                   self.transform_coords(coords_obj, norm_params, 1),
                                                                   self.cross_conv_decoders[fp_idx] if self.attn_type in ['coord3d+posenc-learnable'] else None,
                                                                   temb_hum=t_emb_hum,
                                                                   temb_obj=t_emb_obj
                                                                   )

            if DEBUG:
                import trimesh
                ch_ho = self.transform_coords(coords_hum, norm_params, 0)
                co_ho = self.transform_coords(coords_obj, norm_params, 1)
                points = torch.cat([ch_ho, co_ho], -1).transpose(1, 2)
                L = ch_ho.shape[-1]
                vc = np.concatenate(
                    [np.zeros((L, 3)) + np.array([0.5, 1.0, 0]),
                     np.zeros((L, 3)) + np.array([0.05, 1.0, 1.0])]
                )
                trimesh.PointCloud(points[0].cpu().numpy(), colors=vc).export(
                    f'/BS/xxie-2/work/pc2-diff/experiments/debug/meshes/decoder_step{fp_idx:02d}.ply')

        if DEBUG:
            exit(0)
        # Output MLP layers
        output_hum = self.classifier_hum(features_hum).transpose(1, 2) # convert back to (B, N, D) format
        output_obj = self.classifier_obj(features_obj).transpose(1, 2)

        return output_hum, output_obj

    def transform_coords(self, coords, norm_params, target_ind):
        """
        transform coordinates such that the points align back to H+O interaction space
        :param coords: (B, 3, N)
        :param norm_params: (2, B, 4)
        :param target_ind: 0 or 1
        :return:
        """
        scale = norm_params[target_ind, :, 3:].unsqueeze(1)
        cent = norm_params[target_ind, :, :3].unsqueeze(-1)
        coords_ho = coords * 2 * scale + cent
        return coords_ho

    def add_attn_feature(self, features_hum, features_obj,
                         coords_hum=None, coords_obj=None,
                         attn_module=None,
                         temb_hum=None, temb_obj=None):
        """
        compute cross attention between human and object points
        :param features_hum: (B, D, N)
        :param features_obj: (B, D, N)
        :param coords_hum: (B, 3, N), human points in the H+O frame
        :param coords_obj: (B, 3, N), object points in the H+O frame
        :param temb: time embedding
        :return: cross attended human object features.
        """
        B, D, N = features_hum.shape
        # the attn_module is learnable, only difference is the number of output feature dimension
        onehot_hum, onehot_obj = self.get_onehot_feat(features_hum)
        pos_hum = self.posi_encoder(coords_hum.permute(0, 2, 1)).permute(0, 2, 1)
        pos_obj = self.posi_encoder(coords_obj.permute(0, 2, 1)).permute(0, 2, 1)
        feat_hum = torch.cat([pos_hum, onehot_hum], 1)
        feat_obj = torch.cat([pos_obj, onehot_obj], 1)  # (B, 65, N)

        attn_h2o = attn_module(pos_obj.permute(0, 2, 1),
                               pos_hum.permute(0, 2, 1),
                               feat_hum.permute(0, 2, 1))[0].permute(0, 2, 1)
        attn_o2h = attn_module(pos_hum.permute(0, 2, 1),
                               pos_obj.permute(0, 2, 1),
                               feat_obj.permute(0, 2, 1))[0].permute(0, 2, 1)
        features_hum = torch.cat([features_hum, attn_o2h * self.attn_weight], 1)
        features_obj = torch.cat([features_obj, attn_h2o * self.attn_weight], 1)

        return features_hum, features_obj

    def get_onehot_feat(self, features_hum):
        """
        compute a onehot feature vector to identify this is human or object
        :param features_hum:
        :return: (B, 2, N) x2 for human and object
        """
        B, D, N = features_hum.shape
        onehot_hum = torch.zeros(B, 2, N).to(features_hum.device)
        onehot_hum[:, 0] = 1.
        onehot_obj = torch.zeros(B, 2, N).to(features_hum.device)
        onehot_obj[:, 1] = 1.0
        return onehot_hum, onehot_obj