import logging
from pathlib import Path

import einops
import torch
from omegaconf import OmegaConf
from timm.layers import trunc_normal_
from torch import nn

from mmaudio.ext.synchformer.utils import check_if_file_exists_else_download
from mmaudio.ext.synchformer.video_model_builder import VisionTransformer

FILE2URL = {
    # cfg
    'motionformer_224_16x4.yaml':
    'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml',
    'joint_224_16x4.yaml':
    'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml',
    'divided_224_16x4.yaml':
    'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml',
    # ckpt
    'ssv2_motionformer_224_16x4.pyth':
    'https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth',
    'ssv2_joint_224_16x4.pyth':
    'https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth',
    'ssv2_divided_224_16x4.pyth':
    'https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth',
}


class MotionFormer(VisionTransformer):
    ''' This class serves three purposes:
            1. Renames the class to MotionFormer.
            2. Downloads the cfg from the original repo and patches it if needed.
            3. Takes care of feature extraction by redefining .forward():
                - if `extract_features=True` and `factorize_space_time=False`,
                    the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
                - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D)
                    and spatial and temporal transformer encoder layers are used.
                - if `extract_features=True`, `factorize_space_time=True` and `add_global_repr=True`,
                    the output is of shape (B, D): spatial and temporal transformer encoder layers
                    are used, and a global representation is extracted from the segments (an extra
                    pos emb is added).
    '''
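
    # Illustrative usage sketch (not part of the original code): expected shapes for the
    # factorized feature-extraction mode described above, assuming the default
    # divided_224_16x4.yaml config (downloaded on first use if not present locally).
    #
    #   vfeat_extractor = MotionFormer(extract_features=True,
    #                                  factorize_space_time=True,
    #                                  agg_space_module='TransformerEncoderLayer',
    #                                  agg_time_module='TransformerEncoderLayer',
    #                                  add_global_repr=False)
    #   vid = torch.randn(2, 3, 3, 16, 224, 224)  # (B, S, C, T, H, W)
    #   out = vfeat_extractor(vid)                # (B, S, D) after spatial and temporal aggregation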

    def __init__(
        self,
        extract_features: bool = False,
        ckpt_path: str = None,
        factorize_space_time: bool = None,
        agg_space_module: str = None,
        agg_time_module: str = None,
        add_global_repr: bool = True,
        agg_segments_module: str = None,
        max_segments: int = None,
    ):
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.factorize_space_time = factorize_space_time

        if self.ckpt_path is not None:
            check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
            ckpt = torch.load(self.ckpt_path, map_location='cpu')
            mformer_ckpt2cfg = {
                'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml',
                'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml',
                'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml',
            }
            # init from motionformer ckpt or from our Stage I ckpt
            # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to
            # load the state dict differently
            was_pt_on_avclip = self.ckpt_path.endswith(
                '.pt')  # checks if it is a stage I ckpt (FIXME: a bit generic)
            if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())):
                cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name]
            elif was_pt_on_avclip:
                # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it)
                s1_cfg = ckpt.get('args', None)  # Stage I cfg
                if s1_cfg is not None:
                    s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path
                    # if the stage I ckpt was initialized from a motionformer ckpt or trained from scratch
                    if s1_vfeat_extractor_ckpt_path is not None:
                        cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
                    else:
                        cfg_fname = 'divided_224_16x4.yaml'
                else:
                    cfg_fname = 'divided_224_16x4.yaml'
            else:
                raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.')
        else:
            was_pt_on_avclip = False
            cfg_fname = 'divided_224_16x4.yaml'
            # logging.info(f'No ckpt_path provided, using {cfg_fname} config.')

        if cfg_fname in ['motionformer_224_16x4.yaml', 'divided_224_16x4.yaml']:
            pos_emb_type = 'separate'
        elif cfg_fname == 'joint_224_16x4.yaml':
            pos_emb_type = 'joint'

        self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname

        check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
        mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
        logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}')

        # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`)
        mformer_cfg.VIT.ATTN_DROPOUT = 0.0
        mformer_cfg.VIT.POS_EMBED = pos_emb_type
        mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
        mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none'  # guessing
        mformer_cfg.VIT.APPROX_ATTN_DIM = 64  # from ckpt['cfg']

        # finally init VisionTransformer with the cfg
        super().__init__(mformer_cfg)

        # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt
        if (self.ckpt_path is not None) and (not was_pt_on_avclip):
            _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False)
            if len(_ckpt_load_status.missing_keys) > 0 or len(
                    _ckpt_load_status.unexpected_keys) > 0:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. ' \
                                f'Missing keys: {_ckpt_load_status.missing_keys}, ' \
                                f'Unexpected keys: {_ckpt_load_status.unexpected_keys}')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        if self.extract_features:
            assert isinstance(self.norm,
                              nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights'
            # pre_logits is Sequential(nn.Linear(emb, emb), act) and `act` is tanh, but see the logger
            self.pre_logits = nn.Identity()
            # we don't need the classification head (saving memory)
            self.head = nn.Identity()
            self.head_drop = nn.Identity()
            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
            transf_enc_layer_kwargs = dict(
                d_model=self.embed_dim,
                nhead=self.num_heads,
                activation=nn.GELU(),
                batch_first=True,
                dim_feedforward=self.mlp_ratio * self.embed_dim,
                dropout=self.drop_rate,
                layer_norm_eps=1e-6,
                norm_first=True,
            )
            # define adapters if needed
            if self.factorize_space_time:
                if agg_space_module == 'TransformerEncoderLayer':
                    self.spatial_attn_agg = SpatialTransformerEncoderLayer(
                        **transf_enc_layer_kwargs)
                elif agg_space_module == 'AveragePooling':
                    self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t',
                                                           then_permute_pattern='BS D t -> BS t D')
                if agg_time_module == 'TransformerEncoderLayer':
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == 'AveragePooling':
                    self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
                elif 'Identity' in agg_time_module:
                    self.temp_attn_agg = nn.Identity()
            # define a global aggregation layer (aggregate over segments)
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == 'TransformerEncoderLayer':
                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                    # we need to add pos emb (PE) because previously we added the same PE for each segment
                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True,
                        pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
                        pos_max_len=pos_max_len,
                        **transf_enc_layer_kwargs)
                elif agg_segments_module == 'AveragePooling':
                    self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')

        if was_pt_on_avclip:
            # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
            # and keep only the state_dict of the feat extractor
            ckpt_weights = dict()
            for k, v in ckpt['state_dict'].items():
                if k.startswith(('module.v_encoder.', 'v_encoder.')):
                    k = k.replace('module.', '').replace('v_encoder.', '')
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n' \
                                f'Missing keys ({len(_load_status.missing_keys)}): ' \
                                f'{_load_status.missing_keys}, \n' \
                                f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
                                f'{_load_status.unexpected_keys} \n' \
                                f'temp_attn_agg keys are expected to be missing if the ckpt was pretrained contrastively.')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1,
        # but it is used to calculate the number of patches, so we need to keep it
        self.patch_embed.requires_grad_(False)

    def forward(self, x):
        '''
        x is of shape (B, S, C, T, H, W) where S is the number of segments.
        '''
        # Batch, Segments, Channels, T=frames, Height, Width
        B, S, C, T, H, W = x.shape
        # Motionformer expects a tensor of shape (1, B, C, T, H, W).
        # The first dimension (1) is a dummy dimension and won't be used:
        # see `video_model_builder.video_input`.
        # x = x.unsqueeze(0)  # (1, B, S, C, T, H, W)

        orig_shape = (B, S, C, T, H, W)
        x = x.view(B * S, C, T, H, W)  # flatten batch and segments
        x = self.forward_segments(x, orig_shape=orig_shape)
        # unpack the segments (keeping the remaining dimensions to support different shapes, e.g. (BS, D) or (BS, t, D))
        x = x.view(B, S, *x.shape[1:])
        # x is now of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`

        return x  # x is (B, S, ...)

    def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
        '''x is of shape (B*S, C, T, H, W) where S is the number of segments.'''
        x, x_mask = self.forward_features(x)

        assert self.extract_features

        # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
        x = x[:, 1:, :]  # without the CLS token for efficiency (should be safe for LayerNorm and FC)
        x = self.norm(x)
        x = self.pre_logits(x)
        if self.factorize_space_time:
            x = self.restore_spatio_temp_dims(x, orig_shape)  # (B*S, D, t, h, w) <- (B*S, t*h*w, D)

            x = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
            x = self.temp_attn_agg(x)  # (B*S, D) or (B*S, t, D) if `self.temp_attn_agg` is `Identity`

        return x

    def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
        '''
            feats are of shape (B*S, t*h*w, D) (the CLS token has already been removed;
            originally T = 1 + (224 // 16) * (224 // 16) * 8).
            Our goal is to reshape them to (B*S, D, t, h, w) where t, h, w are the temporal and
            spatial numbers of patches.
            From `self.patch_embed_3d`, it follows that we can reshape feats with:
                `feats.transpose(1, 2).view(B*S, D, t, h, w)`
        '''
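        # Worked example (assuming the default 224x224 input, 16x16 spatial patches and a temporal
        # patch size of 2 over 16-frame segments): t = 16 // 2 = 8, h = w = 224 // 16 = 14, so
        # feats of shape (B*S, 8*14*14, D) become (B*S, D, 8, 14, 14).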
        B, S, C, T, H, W = orig_shape
        D = self.embed_dim

        # num patches in each dimension
        t = T // self.patch_embed_3d.z_block_size
        h = self.patch_embed_3d.height
        w = self.patch_embed_3d.width

        feats = feats.permute(0, 2, 1)  # (B*S, D, T)
        feats = feats.view(B * S, D, t, h, w)  # (B*S, D, t, h, w)

        return feats


class BaseEncoderLayer(nn.TransformerEncoderLayer):
    '''
        This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token
        to the sequence and outputs the CLS token's representation.
        This base class parents both SpatialTransformerEncoderLayer and TemporalTransformerEncoderLayer
        for the RGB stream, and the frequency and temporal encoder layers for the audio stream.
        We also, optionally, add a positional embedding to the input sequence, which
        allows reusing it for global aggregation (of segments) for both streams.
    '''

    def __init__(self,
                 add_pos_emb: bool = False,
                 pos_emb_drop: float = None,
                 pos_max_len: int = None,
                 *args_transformer_enc,
                 **kwargs_transformer_enc):
        super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # add positional embedding
        self.add_pos_emb = add_pos_emb
        if add_pos_emb:
            self.pos_max_len = 1 + pos_max_len  # +1 (for CLS)
            self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim))
            self.pos_drop = nn.Dropout(pos_emb_drop)
            trunc_normal_(self.pos_emb, std=.02)

        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
        ''' x is of shape (B, N, D); if provided x_mask is of shape (B, N)'''
        batch_dim = x.shape[0]

        # add CLS token
        cls_tokens = self.cls_token.expand(batch_dim, -1, -1)  # expanding to match batch dimension
        x = torch.cat((cls_tokens, x), dim=-2)  # (batch_dim, 1+seq_len, D)
        if x_mask is not None:
            cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool,
                                  device=x_mask.device)  # 1=keep; 0=mask
            x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1)  # (batch_dim, 1+seq_len)
            B, N = x_mask_w_cls.shape
            # torch expects an (N, N) or (B*num_heads, N, N) attention mask where True means "mask out"
            x_mask_w_cls = x_mask_w_cls.reshape(B, 1, 1, N)\
                                       .expand(-1, self.self_attn.num_heads, N, -1)\
                                       .reshape(B * self.self_attn.num_heads, N, N)
            assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, 'x_mask_w_cls.dtype != bool'
            x_mask_w_cls = ~x_mask_w_cls  # invert mask (1=mask)
        else:
            x_mask_w_cls = None

        # add positional embedding
        if self.add_pos_emb:
            seq_len = x.shape[1]  # (don't even think about moving it before the CLS token concatenation)
            assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})'
            x = x + self.pos_emb[:, :seq_len, :]
            x = self.pos_drop(x)

        # apply encoder layer (calls nn.TransformerEncoderLayer.forward);
        x = super().forward(src=x, src_mask=x_mask_w_cls)  # (batch_dim, 1+seq_len, D)

        # the CLS token is expected to hold the aggregated representation of the sequence
        x = x[:, 0, :]  # (batch_dim, D)

        return x

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token', 'pos_emb'}


class SpatialTransformerEncoderLayer(BaseEncoderLayer):
    ''' Aggregates spatial dimensions by applying attention individually to each frame. '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        ''' x is of shape (B*S, D, t, h, w) where S is the number of segments.
            if specified, x_mask is of shape (B*S, t, h, w); 0=masked, 1=kept.
            Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame. '''
        BS, D, t, h, w = x.shape

        # time as a batch dimension and flatten spatial dimensions as sequence
        x = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D')
        # do the same for the mask
        if x_mask is not None:
            x_mask = einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)')

        # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
        x = super().forward(x=x, x_mask=x_mask)  # (B*S*t, D)

        # reshape back to (B*S, t, D)
        x = einops.rearrange(x, '(BS t) D -> BS t D', BS=BS, t=t)

        # (B*S, t, D)
        return x


class TemporalTransformerEncoderLayer(BaseEncoderLayer):
    ''' Aggregates temporal dimension with attention. Also used with pos emb as global aggregation
    in both streams. '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        ''' x is of shape (B*S, t, D) where S is the number of segments.
            Returns a tensor of shape (B*S, D) pooling temporal information. '''
        BS, t, D = x.shape

        # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
        x = super().forward(x)  # (B*S, D)

        return x  # (B*S, D)


class AveragePooling(nn.Module):

    def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
        ''' patterns are e.g. "bs t d -> bs d" '''
        super().__init__()
        # TODO: need to register them as buffers (but fails because these are strings)
        self.reduce_fn = 'mean'
        self.avg_pattern = avg_pattern
        self.then_permute_pattern = then_permute_pattern

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        x = einops.reduce(x, self.avg_pattern, self.reduce_fn)
        if self.then_permute_pattern is not None:
            x = einops.rearrange(x, self.then_permute_pattern)
        return x
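

# Example (illustrative, not part of the original module): AveragePooling reduces with einops
# and optionally permutes afterwards, e.g.
#
#   pool = AveragePooling(avg_pattern='BS D t h w -> BS D t',
#                         then_permute_pattern='BS D t -> BS t D')
#   y = pool(torch.randn(6, 768, 8, 14, 14))  # -> (6, 8, 768)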