merve (HF staff) committed on
Commit e5d3156
1 Parent(s): d4d71ea

Upload 4 files

Files changed (4)
  1. hiera/__init__.py +43 -0
  2. hiera/hiera.py +535 -0
  3. hiera/hiera_mae.py +398 -0
  4. hiera/hiera_utils.py +287 -0
hiera/__init__.py ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

from .hiera import (
    hiera_tiny_224,
    hiera_small_224,
    hiera_base_224,
    hiera_base_plus_224,
    hiera_large_224,
    hiera_huge_224,

    hiera_base_16x224,
    hiera_base_plus_16x224,
    hiera_large_16x224,
    hiera_huge_16x224,

    Hiera,
    HieraBlock,
    MaskUnitAttention,
    Head,
    PatchEmbed,
)


from .hiera_mae import (
    mae_hiera_tiny_224,
    mae_hiera_small_224,
    mae_hiera_base_224,
    mae_hiera_base_plus_224,
    mae_hiera_large_224,
    mae_hiera_huge_224,

    mae_hiera_base_16x224,
    mae_hiera_base_plus_16x224,
    mae_hiera_large_16x224,
    mae_hiera_huge_16x224,

    MaskedAutoencoderHiera,
)
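
A minimal usage sketch for the package above (illustrative, not part of the uploaded files), assuming torch and timm are installed and this repo is importable as `hiera`:

import torch
from hiera import hiera_base_224

# Build the ImageNet-1k Hiera-B classifier defined in hiera/hiera.py below.
model = hiera_base_224(pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # [1, 1000]; Head applies softmax only in eval mode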
hiera/hiera.py ADDED
@@ -0,0 +1,535 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
#
# Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
#
# Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
# Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
# Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
#
# Paper: https://arxiv.org/abs/2306.00989/
#
# References:
# slowfast: https://github.com/facebookresearch/SlowFast
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------

import math
from functools import partial
from typing import List, Tuple, Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import DropPath, Mlp

from .hiera_utils import pretrained_model, conv_nd, do_pool, do_masked_conv, Unroll, Reroll


class MaskUnitAttention(nn.Module):
    """
    Computes either Mask Unit or Global Attention. Also is able to perform q pooling.

    Note: this assumes the tokens have already been flattened and unrolled into mask units.
    See `Unroll` for more details.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        heads: int,
        q_stride: int = 1,
        window_size: int = 0,
        use_mask_unit_attn: bool = False,
    ):
        """
        Args:
        - dim, dim_out: The input and output feature dimensions.
        - heads: The number of attention heads.
        - q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4).
        - window_size: The current (flattened) size of a mask unit *after* pooling (if any).
        - use_mask_unit_attn: Use Mask Unit or Global Attention.
        """
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.heads = heads
        self.q_stride = q_stride

        self.head_dim = dim_out // heads
        self.scale = (self.head_dim) ** -0.5

        self.qkv = nn.Linear(dim, 3 * dim_out)
        self.proj = nn.Linear(dim_out, dim_out)

        self.window_size = window_size
        self.use_mask_unit_attn = use_mask_unit_attn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Input should be of shape [batch, tokens, channels]. """
        B, N, _ = x.shape
        num_windows = (
            (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
        )

        qkv = (
            self.qkv(x)
            .reshape(B, -1, num_windows, 3, self.heads, self.head_dim)
            .permute(3, 0, 4, 2, 1, 5)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.q_stride > 1:
            # Refer to Unroll to see how this performs a maxpool-Nd
            q = (
                q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim)
                .max(dim=3)
                .values
            )

        if hasattr(F, "scaled_dot_product_attention"):
            # Note: the original paper did *not* use SDPA, it's a free boost!
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            attn = (q * self.scale) @ k.transpose(-1, -2)
            attn = attn.softmax(dim=-1)
            x = (attn @ v)

        x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
        x = self.proj(x)
        return x


class HieraBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        act_layer: nn.Module = nn.GELU,
        q_stride: int = 1,
        window_size: int = 0,
        use_mask_unit_attn: bool = False,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out

        self.norm1 = norm_layer(dim)
        self.attn = MaskUnitAttention(
            dim, dim_out, heads, q_stride, window_size, use_mask_unit_attn
        )

        self.norm2 = norm_layer(dim_out)
        self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer)

        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention + Q Pooling
        x_norm = self.norm1(x)
        if self.dim != self.dim_out:
            x = do_pool(self.proj(x_norm), stride=self.attn.q_stride)
        x = x + self.drop_path(self.attn(x_norm))

        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Head(nn.Module):
    def __init__(
        self,
        dim: int,
        num_classes: int,
        dropout_rate: float = 0.0,
        act_func: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.softmax(dim=-1),
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
        self.projection = nn.Linear(dim, num_classes)
        # act_func is applied for eval and testing only
        self.act_func = act_func

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dropout(x)
        x = self.projection(x)
        if not self.training:
            x = self.act_func(x)
        return x


class PatchEmbed(nn.Module):
    """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel: Tuple[int, ...],
        stride: Tuple[int, ...],
        padding: Tuple[int, ...],
    ):
        super().__init__()

        # Support any number of spatial dimensions
        self.spatial_dims = len(kernel)
        self.proj = conv_nd(self.spatial_dims)(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        x = do_masked_conv(x, self.proj, mask)
        x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
        return x


class Hiera(nn.Module):
    def __init__(
        self,
        input_size: Tuple[int, ...] = (224, 224),
        in_chans: int = 3,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        num_classes: int = 1000,
        stages: Tuple[int, ...] = (2, 3, 16, 3),
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, ...] = (2, 2),
        mask_unit_size: Tuple[int, ...] = (8, 8),  # must divide q_stride ** (#stages-1)
        # mask_unit_attn: which stages use mask unit attention?
        mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
        dim_mul: float = 2.0,
        head_mul: float = 2.0,
        patch_kernel: Tuple[int, ...] = (7, 7),
        patch_stride: Tuple[int, ...] = (4, 4),
        patch_padding: Tuple[int, ...] = (3, 3),
        mlp_ratio: float = 4.0,
        drop_path_rate: float = 0.0,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
        head_dropout: float = 0.0,
        head_init_scale: float = 0.001,
        sep_pos_embed: bool = False,
    ):
        super().__init__()

        depth = sum(stages)
        self.patch_stride = patch_stride
        self.tokens_spatial_shape = [i // s for i, s in zip(input_size, patch_stride)]
        num_tokens = math.prod(self.tokens_spatial_shape)
        flat_mu_size = math.prod(mask_unit_size)
        flat_q_stride = math.prod(q_stride)

        assert q_pool < len(stages)
        self.q_pool, self.q_stride = q_pool, q_stride
        self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
        self.mask_spatial_shape = [
            i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)
        ]
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]

        self.patch_embed = PatchEmbed(
            in_chans, embed_dim, patch_kernel, patch_stride, patch_padding
        )

        self.sep_pos_embed = sep_pos_embed
        if sep_pos_embed:
            self.pos_embed_spatial = nn.Parameter(
                torch.zeros(
                    1,
                    self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
                    embed_dim,
                )
            )
            self.pos_embed_temporal = nn.Parameter(
                torch.zeros(1, self.tokens_spatial_shape[0], embed_dim)
            )
        else:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))

        # Setup roll and reroll modules
        self.unroll = Unroll(
            input_size, patch_stride, [q_stride] * len(self.stage_ends[:-1])
        )
        self.reroll = Reroll(
            input_size,
            patch_stride,
            [q_stride] * len(self.stage_ends[:-1]),
            self.stage_ends,
            q_pool,
        )
        # q_pool locations
        q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Transformer blocks
        cur_stage = 0
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # Mask unit or global attention.
            # Lag by 1 block, so that global attention is
            # applied post pooling, on the lower resolution.
            use_mask_unit_attn = mask_unit_attn[cur_stage]

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1
                if i in q_pool_blocks:
                    flat_mu_size //= flat_q_stride

            block = HieraBlock(
                dim=embed_dim,
                dim_out=dim_out,
                heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                q_stride=(flat_q_stride if i in q_pool_blocks else 1),
                window_size=flat_mu_size,
                use_mask_unit_attn=use_mask_unit_attn,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        self.norm = norm_layer(embed_dim)
        self.head = Head(embed_dim, num_classes, dropout_rate=head_dropout)

        # Initialize everything
        if sep_pos_embed:
            nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
            nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
        else:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(partial(self._init_weights))
        self.head.projection.weight.data.mul_(head_init_scale)
        self.head.projection.bias.data.mul_(head_init_scale)

    def _init_weights(self, m, init_bias=0.02):
        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, init_bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, init_bias)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        if self.sep_pos_embed:
            return ["pos_embed_spatial", "pos_embed_temporal"]
        else:
            return ["pos_embed"]

    def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
        """
        Generates a random mask; a mask_ratio fraction of mask units is dropped.
        1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc.
        """
        B = x.shape[0]
        # Tokens selected for masking at mask unit level
        num_windows = math.prod(self.mask_spatial_shape)  # num_mask_units
        len_keep = int(num_windows * (1 - mask_ratio))
        noise = torch.rand(B, num_windows, device=x.device)

        # Sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Generate the binary mask: 1 is *keep*, 0 is *remove*
        # Note this is opposite to original MAE
        mask = torch.zeros([B, num_windows], device=x.device)
        mask[:, :len_keep] = 1
        # Unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return mask.bool()

    def get_pos_embed(self) -> torch.Tensor:
        if self.sep_pos_embed:
            return self.pos_embed_spatial.repeat(
                1, self.tokens_spatial_shape[0], 1
            ) + torch.repeat_interleave(
                self.pos_embed_temporal,
                self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
                dim=1,
            )
        else:
            return self.pos_embed

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        return_intermediates: bool = False,
    ) -> torch.Tensor:
        """
        mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
        Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
        """
        # Slowfast training passes in a list
        if isinstance(x, list):
            x = x[0]
        intermediates = []

        x = self.patch_embed(
            x,
            mask=mask.view(
                x.shape[0], 1, *self.mask_spatial_shape
            )  # B, C, *mask_spatial_shape
            if mask is not None
            else None,
        )
        x = x + self.get_pos_embed()
        x = self.unroll(x)

        # Discard masked tokens
        if mask is not None:
            x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(
                x.shape[0], -1, x.shape[-1]
            )

        for i, blk in enumerate(self.blocks):
            x = blk(x)

            if return_intermediates and i in self.stage_ends:
                intermediates.append(self.reroll(x, i, mask=mask))

        if mask is None:
            x = x.mean(dim=1)
            x = self.norm(x)
            x = self.head(x)

        # x may not always be in spatial order here.
        # e.g. if q_pool = 2, mask_unit_size = (8, 8), and
        # q_stride = (2, 2), not all unrolls were consumed,
        # so intermediates[-1] is x in spatial order.
        if return_intermediates:
            return x, intermediates

        return x


# Image models

@pretrained_model({
    "mae_in1k_ft_in1k": "https://huggingface.co/merve/hiera-tiny-ft-224-in1k/resolve/main/hiera_tiny_224.pth",
    "mae_in1k": "https://huggingface.co/merve/hiera-tiny-224-in1k/resolve/main/mae_hiera_tiny_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_tiny_224(**kwdargs):
    return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), **kwdargs)


@pretrained_model({
    "mae_in1k_ft_in1k": "https://huggingface.co/merve/hiera-small-ft-224-in1k/resolve/main/hiera_small_224.pth",
    "mae_in1k": "https://huggingface.co/merve/hiera-small-224-in1k/resolve/main/mae_hiera_small_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_small_224(**kwdargs):
    return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), **kwdargs)


@pretrained_model({
    "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth",
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_base_224(**kwdargs):
    return Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), **kwdargs)


@pretrained_model({
    "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth",
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_base_plus_224(**kwdargs):
    return Hiera(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs)


@pretrained_model({
    "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth",
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_large_224(**kwdargs):
    return Hiera(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs)


@pretrained_model({
    "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth",
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_huge_224(**kwdargs):
    return Hiera(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs)


# Video models

@pretrained_model({
    "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth",
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_base_16x224(num_classes: int = 400, **kwdargs):
    return Hiera(
        num_classes=num_classes,  # K400 has 400 classes
        input_size=(16, 224, 224),
        q_stride=(1, 2, 2),
        mask_unit_size=(1, 8, 8),
        patch_kernel=(3, 7, 7),
        patch_stride=(2, 4, 4),
        patch_padding=(1, 3, 3),
        sep_pos_embed=True,
        **kwdargs
    )


@pretrained_model({
    "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_16x224.pth",
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_base_plus_16x224(**kwdargs):
    return hiera_base_16x224(
        embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs
    )


@pretrained_model({
    "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_large_16x224.pth",
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_large_16x224(**kwdargs):
    return hiera_base_16x224(
        embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs
    )


@pretrained_model({
    "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_16x224.pth",
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_huge_16x224(**kwdargs):
    return hiera_base_16x224(
        embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs
    )
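
The `@pretrained_model` decorators above turn each builder into a loader that also accepts `pretrained`, `checkpoint`, and `strict` (see hiera/hiera_utils.py below). A minimal sketch of loading a checkpoint and pulling stage-end feature maps (illustrative, not part of the uploaded files):

import torch
from hiera import hiera_tiny_224

# Downloads the fine-tuned IN1k weights listed in the decorator's checkpoint dict.
model = hiera_tiny_224(pretrained=True, checkpoint="mae_in1k_ft_in1k").eval()
with torch.no_grad():
    preds, intermediates = model(torch.randn(1, 3, 224, 224), return_intermediates=True)
# preds: [1, 1000]; intermediates: one [B, H, W, C] map per stage end
# (56x56x96, 28x28x192, 14x14x384, 7x7x768 for this 224-resolution config).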
hiera/hiera_mae.py ADDED
@@ -0,0 +1,398 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# mae: https://github.com/facebookresearch/mae
# slowfast: https://github.com/facebookresearch/SlowFast
# --------------------------------------------------------


from functools import partial
from typing import Tuple, Optional

import math
import torch
import torch.nn as nn

from .hiera import Hiera, HieraBlock
from .hiera_utils import pretrained_model, undo_windowing, conv_nd


def apply_fusion_head(head: nn.Module, x: torch.Tensor) -> torch.Tensor:
    if isinstance(head, nn.Identity):
        return x

    B, num_mask_units = x.shape[0:2]
    # Apply head, e.g [B, #MUs, My, Mx, C] -> head([B * #MUs, C, My, Mx])
    permute = [0] + [len(x.shape) - 2] + list(range(1, len(x.shape) - 2))
    x = head(x.reshape(B * num_mask_units, *x.shape[2:]).permute(permute))

    # Restore original layout, e.g. [B * #MUs, C', My', Mx'] -> [B, #MUs, My', Mx', C']
    permute = [0] + list(range(2, len(x.shape))) + [1]
    x = x.permute(permute).reshape(B, num_mask_units, *x.shape[2:], x.shape[1])
    return x


class MaskedAutoencoderHiera(Hiera):
    """Masked Autoencoder with Hiera backbone"""

    def __init__(
        self,
        in_chans: int = 3,
        patch_stride: Tuple[int, ...] = (4, 4),
        mlp_ratio: float = 4.0,
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_num_heads: int = 16,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
        **kwdargs,
    ):
        super().__init__(
            in_chans=in_chans,
            patch_stride=patch_stride,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            **kwdargs,
        )

        del self.norm, self.head
        encoder_dim_out = self.blocks[-1].dim_out
        self.encoder_norm = norm_layer(encoder_dim_out)
        self.mask_unit_spatial_shape_final = [
            i // s ** (self.q_pool) for i, s in zip(self.mask_unit_size, self.q_stride)
        ]
        self.tokens_spatial_shape_final = [
            i // s ** (self.q_pool)
            for i, s in zip(self.tokens_spatial_shape, self.q_stride)
        ]
        # --------------------------------------------------------------------------
        # Multi-scale fusion heads
        curr_mu_size = self.mask_unit_size
        self.multi_scale_fusion_heads = nn.ModuleList()

        for i in self.stage_ends[: self.q_pool]:  # resolution constant after q_pool
            kernel = [
                i // s for i, s in zip(curr_mu_size, self.mask_unit_spatial_shape_final)
            ]
            curr_mu_size = [i // s for i, s in zip(curr_mu_size, self.q_stride)]
            self.multi_scale_fusion_heads.append(
                conv_nd(len(self.q_stride))(
                    self.blocks[i].dim_out,
                    encoder_dim_out,
                    kernel_size=kernel,
                    stride=kernel,
                )
            )
        self.multi_scale_fusion_heads.append(nn.Identity())  # final stage, no transform

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(encoder_dim_out, decoder_embed_dim)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(
                1, math.prod(self.tokens_spatial_shape_final), decoder_embed_dim
            )
        )

        self.decoder_blocks = nn.ModuleList(
            [
                HieraBlock(
                    dim=decoder_embed_dim,
                    dim_out=decoder_embed_dim,
                    heads=decoder_num_heads,
                    norm_layer=norm_layer,
                    mlp_ratio=mlp_ratio,
                )
                for i in range(decoder_depth)
            ]
        )
        self.decoder_norm = norm_layer(decoder_embed_dim)

        self.pred_stride = patch_stride[-1] * (
            self.q_stride[-1] ** self.q_pool
        )  # patch stride of prediction

        self.decoder_pred = nn.Linear(
            decoder_embed_dim,
            (self.pred_stride ** min(2, len(self.q_stride))) * in_chans,
        )  # predictor
        # --------------------------------------------------------------------------

        self.initialize_weights()

    def initialize_weights(self):
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
        self.apply(self._mae_init_weights)

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def _mae_init_weights(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_pixel_label_2d(
        self, input_img: torch.Tensor, mask: torch.Tensor, norm: bool = True
    ) -> torch.Tensor:
        # mask (boolean tensor): True must correspond to *masked*
        input_img = input_img.permute(0, 2, 3, 1)

        size = self.pred_stride
        label = input_img.unfold(1, size, size).unfold(2, size, size)
        label = label.flatten(1, 2).flatten(2)
        label = label[mask]
        if norm:
            mean = label.mean(dim=-1, keepdim=True)
            var = label.var(dim=-1, keepdim=True)
            label = (label - mean) / (var + 1.0e-6) ** 0.5

        return label

    def get_pixel_label_3d(
        self, input_vid: torch.Tensor, mask: torch.Tensor, norm: bool = True
    ) -> torch.Tensor:
        # mask (boolean tensor): True must correspond to *masked*

        # We use time strided loss, only take the first frame from each token
        input_vid = input_vid[:, :, ::self.patch_stride[0], :, :]

        size = self.pred_stride
        label = input_vid.unfold(3, size, size).unfold(4, size, size)
        label = label.permute(0, 2, 3, 4, 5, 6, 1)  # Different from 2d, mistake during training lol
        label = label.flatten(1, 3).flatten(2)
        label = label[mask]

        if norm:
            mean = label.mean(dim=-1, keepdim=True)
            var = label.var(dim=-1, keepdim=True)
            label = (label - mean) / (var + 1.0e-6) ** 0.5

        return label

    def forward_encoder(
        self, x: torch.Tensor, mask_ratio: float, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        if mask is None:
            mask = self.get_random_mask(x, mask_ratio)  # [B, #MUs_all]

        # Get multi-scale representations from encoder
        _, intermediates = super().forward(x, mask, return_intermediates=True)
        # Resolution unchanged after q_pool stages, so skip those features
        intermediates = intermediates[: self.q_pool] + intermediates[-1:]

        # Multi-scale fusion
        x = 0.0
        for head, interm_x in zip(self.multi_scale_fusion_heads, intermediates):
            x += apply_fusion_head(head, interm_x)

        x = self.encoder_norm(x)

        return x, mask

    def forward_decoder(
        self, x: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Embed tokens
        x = self.decoder_embed(x)

        # Combine visible and mask tokens

        # x: [B, #MUs, *mask_unit_spatial_shape_final, encoder_dim_out]
        # mask: [B, #MUs_all]
        x_dec = torch.zeros(*mask.shape, *x.shape[2:], device=x.device, dtype=x.dtype)
        mask_tokens = self.mask_token.view(
            (1,) * (len(mask.shape) + len(x.shape[2:-1])) + (-1,)
        )
        mask = mask.reshape(mask.shape + (1,) * len(x.shape[2:]))
        mask = mask.expand((-1,) * 2 + x.shape[2:]).bool()
        x_dec[mask] = x.flatten()
        x_dec = ~mask * mask_tokens + mask * x_dec

        # Get back spatial order
        x = undo_windowing(
            x_dec,
            self.tokens_spatial_shape_final,
            self.mask_unit_spatial_shape_final,
        )
        mask = undo_windowing(
            mask[..., 0:1],
            self.tokens_spatial_shape_final,
            self.mask_unit_spatial_shape_final,
        )

        # Flatten
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        mask = mask.view(x.shape[0], -1)

        # Add pos embed
        x = x + self.decoder_pos_embed

        # Apply decoder blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # Predictor projection
        x = self.decoder_pred(x)

        return x, mask

    def forward_loss(
        self, x: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Note: in mask, 0 is *visible*, 1 is *masked*

        x: e.g. [B, 3, H, W]
        pred: [B * num_pred_tokens, num_pixels_in_pred_patch * in_chans]
        label: [B * num_pred_tokens, num_pixels_in_pred_patch * in_chans]
        """
        if len(self.q_stride) == 2:
            label = self.get_pixel_label_2d(x, mask)
        elif len(self.q_stride) == 3:
            label = self.get_pixel_label_3d(x, mask)
        else:
            raise NotImplementedError

        pred = pred[mask]
        loss = (pred - label) ** 2

        return loss.mean(), pred, label

    def forward(
        self,
        x: torch.Tensor,
        mask_ratio: float = 0.6,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

        latent, mask = self.forward_encoder(x, mask_ratio, mask=mask)
        pred, pred_mask = self.forward_decoder(
            latent, mask
        )  # pred_mask is mask at resolution of *prediction*

        # Toggle mask, to generate labels for *masked* tokens
        return *self.forward_loss(x, pred, ~pred_mask), mask


# Image Models

@pretrained_model({
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth",
}, default="mae_in1k")
def mae_hiera_tiny_224(**kwargs):
    return MaskedAutoencoderHiera(
        embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), q_pool=2, **kwargs,
    )


@pretrained_model({
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth",
}, default="mae_in1k")
def mae_hiera_small_224(**kwargs):
    return MaskedAutoencoderHiera(
        embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), q_pool=2, **kwargs,
    )


@pretrained_model({
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
}, default="mae_in1k")
def mae_hiera_base_224(**kwargs):
    return MaskedAutoencoderHiera(
        embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), q_pool=2, **kwargs,
    )


@pretrained_model({
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
}, default="mae_in1k")
def mae_hiera_base_plus_224(**kwargs):
    return MaskedAutoencoderHiera(
        embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), q_pool=2, **kwargs,
    )


@pretrained_model({
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
}, default="mae_in1k")
def mae_hiera_large_224(**kwargs):
    return MaskedAutoencoderHiera(
        embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), q_pool=2, **kwargs,
    )


@pretrained_model({
    "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
}, default="mae_in1k")
def mae_hiera_huge_224(**kwargs):
    return MaskedAutoencoderHiera(
        embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), q_pool=2, **kwargs,
    )


# Video Models

@pretrained_model({
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth",
}, default="mae_k400")
def mae_hiera_base_16x224(num_classes: int = 400, **kwdargs):
    return MaskedAutoencoderHiera(
        num_classes=num_classes,  # K400 has 400 classes
        input_size=(16, 224, 224),
        q_stride=(1, 2, 2),
        mask_unit_size=(1, 8, 8),
        patch_kernel=(3, 7, 7),
        patch_stride=(2, 4, 4),
        patch_padding=(1, 3, 3),
        sep_pos_embed=True,
        q_pool=2,
        **kwdargs
    )


@pretrained_model({
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth",
}, default="mae_k400")
def mae_hiera_base_plus_16x224(**kwdargs):
    return mae_hiera_base_16x224(
        embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs
    )


@pretrained_model({
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth",
}, default="mae_k400")
def mae_hiera_large_16x224(**kwdargs):
    return mae_hiera_base_16x224(
        embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs
    )


@pretrained_model({
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth",
}, default="mae_k400")
def mae_hiera_huge_16x224(**kwdargs):
    return mae_hiera_base_16x224(
        embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs
    )
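
A minimal sketch of one MAE forward/backward pass with the model above (illustrative, not part of the uploaded files):

import torch
from hiera import mae_hiera_tiny_224

model = mae_hiera_tiny_224(pretrained=False)
x = torch.randn(2, 3, 224, 224)
# forward returns (loss, pred, label, mask); mask is True for *kept* mask units.
loss, pred, label, mask = model(x, mask_ratio=0.6)
loss.backward()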
hiera/hiera_utils.py ADDED
@@ -0,0 +1,287 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
#
# Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
#
# Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
# Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
# Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
#
# Paper: https://arxiv.org/abs/2306.00989/
#
# References:
# slowfast: https://github.com/facebookresearch/SlowFast
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------

import math
from typing import List, Tuple, Optional, Type, Callable, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


def pretrained_model(checkpoints: Dict[str, str], default: str = None) -> Callable:
    """ Loads a Hiera model from a pretrained source (if pretrained=True). Use "checkpoint" to specify the checkpoint. """

    def inner(model_func: Callable) -> Callable:
        def model_def(pretrained: bool = False, checkpoint: str = default, strict: bool = True, **kwdargs) -> nn.Module:
            if pretrained:
                if checkpoints is None:
                    raise RuntimeError("This model currently doesn't have pretrained weights available.")
                elif checkpoint is None:
                    raise RuntimeError("No checkpoint specified.")
                elif checkpoint not in checkpoints:
                    raise RuntimeError(f"Invalid checkpoint specified ({checkpoint}). Options are: {list(checkpoints.keys())}.")

                state_dict = torch.hub.load_state_dict_from_url(checkpoints[checkpoint], map_location="cpu")

                if "head.projection.weight" in state_dict["model_state"]:
                    # Set the number of classes equal to the state_dict only if the user doesn't want to overwrite it
                    if "num_classes" not in kwdargs:
                        kwdargs["num_classes"] = state_dict["model_state"]["head.projection.weight"].shape[0]
                    # If the user specified a different number of classes, remove the projection weights or else we'll error out
                    elif kwdargs["num_classes"] != state_dict["model_state"]["head.projection.weight"].shape[0]:
                        del state_dict["model_state"]["head.projection.weight"]
                        del state_dict["model_state"]["head.projection.bias"]

            model = model_func(**kwdargs)
            if pretrained:
                # Disable being strict when trying to load an encoder-decoder model into an encoder-only model
                if "decoder_pos_embed" in state_dict["model_state"] and not hasattr(model, "decoder_pos_embed"):
                    strict = False

                model.load_state_dict(state_dict["model_state"], strict=strict)

            return model

        return model_def

    return inner


def conv_nd(n: int) -> Type[nn.Module]:
    """
    Returns an n-dimensional conv (e.g., Conv2d for n=2). Works up to n=3.
    If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises)
    """
    return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]


def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor:
    # Refer to `Unroll` to see how this performs a maxpool-Nd
    return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values


def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
    # target_size: [(T), (H), W]
    # (spatial) mask: [B, C, (t), (h), w]
    if mask is None:
        return mask

    assert len(mask.shape[2:]) == len(target_size)
    if mask.shape[2:] != target_size:
        return F.interpolate(mask.float(), size=target_size)
    return mask


def do_masked_conv(
    x: torch.Tensor, conv: nn.Module, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Zero out the masked regions of the input before the conv.
    Prevents leakage of masked regions when using overlapping kernels.
    """
    if conv is None:
        return x
    if mask is None:
        return conv(x)

    mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
    return conv(x * mask.bool())


def undo_windowing(
    x: torch.Tensor, shape: List[int], mu_shape: List[int]
) -> torch.Tensor:
    """
    Restore spatial organization by undoing windowed organization of mask units.

    Args:
        x: organized by mask unit windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C]
        shape: current spatial shape, if it were not organized into mask unit
            windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C].
        mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx]
    Returns:
        x: e.g. in 2d, [B, #MUy*MUy, #MUx*MUx, C]
    """
    D = len(shape)
    B, C = x.shape[0], x.shape[-1]
    # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C]
    num_MUs = [s // mu for s, mu in zip(shape, mu_shape)]
    x = x.view(B, *num_MUs, *mu_shape, C)

    # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C]
    permute = (
        [0]
        + sum(
            [list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))],
            [],
        )
        + [len(x.shape) - 1]
    )
    x = x.permute(permute).reshape(B, *shape, C)

    return x


class Unroll(nn.Module):
    """
    Reorders the tokens such that patches are contiguous in memory.
    E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as
    [B, (Sy, Sx, H // Sy, W // Sx), C]

    This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1).
    Not only is this faster, but it also makes it easy to support inputs of arbitrary
    dimensions in addition to patch-wise sparsity.

    Performing this operation multiple times in sequence puts entire windows as contiguous
    in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
    size 8x8 would be contiguous in memory, allowing operations like mask unit attention
    to be computed easily and efficiently, while also allowing max to be applied sequentially.

    Note: This means that intermediate values of the model are not in HxW order, so they
    need to be re-rolled if you want to use the intermediate values as a HxW feature map.
    The last block of the network is fine though, since by then the strides are all consumed.
    """

    def __init__(
        self,
        input_size: Tuple[int, ...],
        patch_stride: Tuple[int, ...],
        unroll_schedule: List[Tuple[int, ...]],
    ):
        super().__init__()
        self.size = [i // s for i, s in zip(input_size, patch_stride)]
        self.schedule = unroll_schedule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input: Flattened patch embeddings [B, N, C]
        Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
        """
        B, _, C = x.shape

        cur_size = self.size
        x = x.view(*([B] + cur_size + [C]))

        for strides in self.schedule:
            # Move patches with the given strides to the batch dimension

            # Create a view of the tensor with the patch stride as separate dims
            # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C]
            cur_size = [i // s for i, s in zip(cur_size, strides)]
            new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C]
            x = x.view(new_shape)

            # Move the patch stride into the batch dimension
            # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C]
            L = len(new_shape)
            permute = (
                [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1]
            )
            x = x.permute(permute)

            # Now finally flatten the relevant dims into the batch dimension
            x = x.flatten(0, len(strides))
            B *= math.prod(strides)

        x = x.reshape(-1, math.prod(self.size), C)
        return x


class Reroll(nn.Module):
    """
    Undoes the "unroll" operation so that you can use intermediate features.
    """

    def __init__(
        self,
        input_size: Tuple[int, ...],
        patch_stride: Tuple[int, ...],
        unroll_schedule: List[Tuple[int, ...]],
        stage_ends: List[int],
        q_pool: int,
    ):
        super().__init__()
        self.size = [i // s for i, s in zip(input_size, patch_stride)]

        # The first stage has to reverse everything
        # The next stage has to reverse all but the first unroll, etc.
        self.schedule = {}
        size = self.size
        for i in range(stage_ends[-1] + 1):
            self.schedule[i] = unroll_schedule, size
            # schedule unchanged if no pooling at a stage end
            if i in stage_ends[:q_pool]:
                if len(unroll_schedule) > 0:
                    size = [n // s for n, s in zip(size, unroll_schedule[0])]
                unroll_schedule = unroll_schedule[1:]

    def forward(
        self, x: torch.Tensor, block_idx: int, mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Roll the given tensor back up to spatial order assuming it's from the given block.

        If no mask is provided:
            - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc.
        If a mask is provided:
            - Returns [B, #MUs, MUy, MUx, C] for 2d, etc.
        """
        schedule, size = self.schedule[block_idx]
        B, N, C = x.shape

        D = len(size)
        cur_mu_shape = [1] * D

        for strides in schedule:
            # Extract the current patch from N
            x = x.view(B, *strides, N // math.prod(strides), *cur_mu_shape, C)

            # Move that patch into the current MU
            # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C]
            L = len(x.shape)
            permute = (
                [0, 1 + D]
                + sum(
                    [list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))],
                    [],
                )
                + [L - 1]
            )
            x = x.permute(permute)

            # Reshape to [B, N//(Sy*Sx), *MU, C]
            for i in range(D):
                cur_mu_shape[i] *= strides[i]
            x = x.reshape(B, -1, *cur_mu_shape, C)
            N = x.shape[1]

        # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C])
        x = x.view(B, N, *cur_mu_shape, C)

        # If masked, return [B, #MUs, MUy, MUx, C]
        if mask is not None:
            return x

        # If not masked, we can return [B, H, W, C]
        x = undo_windowing(x, size, cur_mu_shape)

        return x
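
A minimal sketch of the Unroll/Reroll round trip defined above, using the default 224 image configuration (illustrative, not part of the uploaded files):

import torch
from hiera.hiera_utils import Unroll, Reroll

size, stride = (224, 224), (4, 4)            # 224 input, patch stride 4 -> 56x56 tokens
unroll = Unroll(size, stride, [(2, 2)] * 3)  # three 2x2 unrolls, as Hiera builds them
reroll = Reroll(size, stride, [(2, 2)] * 3, stage_ends=[1, 4, 20, 23], q_pool=3)
tokens = torch.randn(1, 56 * 56, 96)
rolled = reroll(unroll(tokens), block_idx=0)  # [1, 56, 56, 96], back in H x W order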