Upload 4 files
Browse files- hiera/__init__.py +43 -0
- hiera/hiera.py +535 -0
- hiera/hiera_mae.py +398 -0
- hiera/hiera_utils.py +287 -0
hiera/__init__.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
from .hiera import (
|
9 |
+
hiera_tiny_224,
|
10 |
+
hiera_small_224,
|
11 |
+
hiera_base_224,
|
12 |
+
hiera_base_plus_224,
|
13 |
+
hiera_large_224,
|
14 |
+
hiera_huge_224,
|
15 |
+
|
16 |
+
hiera_base_16x224,
|
17 |
+
hiera_base_plus_16x224,
|
18 |
+
hiera_large_16x224,
|
19 |
+
hiera_huge_16x224,
|
20 |
+
|
21 |
+
Hiera,
|
22 |
+
HieraBlock,
|
23 |
+
MaskUnitAttention,
|
24 |
+
Head,
|
25 |
+
PatchEmbed,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
from .hiera_mae import (
|
30 |
+
mae_hiera_tiny_224,
|
31 |
+
mae_hiera_small_224,
|
32 |
+
mae_hiera_base_224,
|
33 |
+
mae_hiera_base_plus_224,
|
34 |
+
mae_hiera_large_224,
|
35 |
+
mae_hiera_huge_224,
|
36 |
+
|
37 |
+
mae_hiera_base_16x224,
|
38 |
+
mae_hiera_base_plus_16x224,
|
39 |
+
mae_hiera_large_16x224,
|
40 |
+
mae_hiera_huge_16x224,
|
41 |
+
|
42 |
+
MaskedAutoencoderHiera,
|
43 |
+
)
|
hiera/hiera.py
ADDED
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
#
|
8 |
+
# Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
|
9 |
+
#
|
10 |
+
# Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
|
11 |
+
# Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
|
12 |
+
# Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
|
13 |
+
#
|
14 |
+
# Paper: https://arxiv.org/abs/2306.00989/
|
15 |
+
#
|
16 |
+
# References:
|
17 |
+
# slowfast: https://github.com/facebookresearch/SlowFast
|
18 |
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
19 |
+
# --------------------------------------------------------
|
20 |
+
|
21 |
+
import math
|
22 |
+
from functools import partial
|
23 |
+
from typing import List, Tuple, Callable, Optional
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
import torch.nn.functional as F
|
28 |
+
|
29 |
+
from timm.models.layers import DropPath, Mlp
|
30 |
+
|
31 |
+
from .hiera_utils import pretrained_model, conv_nd, do_pool, do_masked_conv, Unroll, Reroll
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
class MaskUnitAttention(nn.Module):
    """
    Multi-head self-attention that operates either inside each mask unit or
    globally over all tokens, with optional max-pooling of the queries.

    Note: the incoming tokens are expected to be flattened and already unrolled
    into mask units. See `Unroll` for more details.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        heads: int,
        q_stride: int = 1,
        window_size: int = 0,
        use_mask_unit_attn: bool = False,
    ):
        """
        Args:
        - dim, dim_out: Input and output feature dimensions.
        - heads: Number of attention heads (`dim_out // heads` channels each).
        - q_stride: When greater than 1, the queries are max-pooled with this
          stride. The stride is expected flattened (e.g., 2x2 = 4).
        - window_size: Flattened size of a mask unit *after* pooling (if any).
        - use_mask_unit_attn: True for Mask Unit attention, False for Global.
        """
        super().__init__()

        self.dim, self.dim_out = dim, dim_out
        self.heads, self.q_stride = heads, q_stride
        self.window_size = window_size
        self.use_mask_unit_attn = use_mask_unit_attn

        self.head_dim = dim_out // heads
        self.scale = self.head_dim ** -0.5

        # Creation order (qkv, then proj) is kept so seeded initialisation
        # and state_dict ordering stay the same.
        self.qkv = nn.Linear(dim, 3 * dim_out)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Attend over `x` of shape [batch, tokens, channels]. """
        batch, num_tokens, _ = x.shape
        if self.use_mask_unit_attn:
            num_windows = num_tokens // (self.q_stride * self.window_size)
        else:
            # Global attention is treated as a single window.
            num_windows = 1

        # [B, tokens // windows, windows, 3, heads, head_dim]
        #   -> [3, B, heads, windows, tokens // windows, head_dim]
        packed = self.qkv(x).reshape(
            batch, -1, num_windows, 3, self.heads, self.head_dim
        )
        q, k, v = packed.permute(3, 0, 4, 2, 1, 5).unbind(0)

        if self.q_stride > 1:
            # Refer to Unroll to see how this performs a maxpool-Nd
            pooled_shape = (
                batch, self.heads, num_windows, self.q_stride, -1, self.head_dim
            )
            q = q.view(pooled_shape).max(dim=3)[0]

        if hasattr(F, "scaled_dot_product_attention"):
            # Note: the original paper did *not* use SDPA, it's a free boost!
            out = F.scaled_dot_product_attention(q, k, v)
        else:
            weights = ((q * self.scale) @ k.transpose(-1, -2)).softmax(dim=-1)
            out = weights @ v

        # Merge heads back into the channel axis before the output projection.
        out = out.transpose(1, 3).reshape(batch, -1, self.dim_out)
        return self.proj(out)
|
109 |
+
|
110 |
+
|
111 |
+
class HieraBlock(nn.Module):
    """
    Pre-norm transformer block: `MaskUnitAttention` followed by an MLP, each
    wrapped in a residual connection with optional stochastic depth.

    When `dim != dim_out`, the residual branch is linearly projected and pooled
    with `do_pool` so it can be added to the attention output.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        act_layer: nn.Module = nn.GELU,
        q_stride: int = 1,
        window_size: int = 0,
        use_mask_unit_attn: bool = False,
    ):
        """
        Args:
        - dim, dim_out: Input and output feature dimensions of the block.
        - heads: Number of attention heads.
        - mlp_ratio: Hidden width of the MLP as a multiple of `dim_out`.
        - drop_path: DropPath rate; values <= 0 disable it.
        - norm_layer: Factory called with a feature dimension to build each norm.
        - act_layer: Activation class handed to the MLP.
        - q_stride, window_size, use_mask_unit_attn: Forwarded unchanged to
          `MaskUnitAttention`.
        """
        super().__init__()

        self.dim, self.dim_out = dim, dim_out

        self.norm1 = norm_layer(dim)
        self.attn = MaskUnitAttention(
            dim,
            dim_out,
            heads,
            q_stride=q_stride,
            window_size=window_size,
            use_mask_unit_attn=use_mask_unit_attn,
        )

        self.norm2 = norm_layer(dim_out)
        hidden_features = int(dim_out * mlp_ratio)
        self.mlp = Mlp(dim_out, hidden_features, act_layer=act_layer)

        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # The residual projection only exists when the width changes.
        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run attention (+ optional q pooling) and the MLP on `x`."""
        normed = self.norm1(x)

        # Attention + Q Pooling
        if self.dim != self.dim_out:
            # The skip path is built from the *normalised* input here.
            x = do_pool(self.proj(normed), stride=self.attn.q_stride)
        attended = self.attn(normed)
        x = x + self.drop_path(attended)

        # MLP
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
|
152 |
+
|
153 |
+
|
154 |
+
class Head(nn.Module):
    """
    Classification head: optional dropout, then a linear projection to
    `num_classes`. Outside of training mode `act_func` is applied to the
    projected output; in training mode the raw projection is returned.
    """

    def __init__(
        self,
        dim: int,
        num_classes: int,
        dropout_rate: float = 0.0,
        act_func: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.softmax(dim=-1),
    ):
        """
        Args:
        - dim: Input feature dimension.
        - num_classes: Output dimension of the projection.
        - dropout_rate: Dropout probability; values <= 0 disable dropout.
        - act_func: Applied to the projection output when not training
          (softmax over the last dim by default).
        """
        super().__init__()
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = nn.Identity()
        self.projection = nn.Linear(dim, num_classes)
        # act_func is for eval and testing only
        self.act_func = act_func

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x`; additionally apply `act_func` when not in training mode."""
        logits = self.projection(self.dropout(x))
        if self.training:
            return logits
        return self.act_func(logits)
|
174 |
+
|
175 |
+
|
176 |
+
class PatchEmbed(nn.Module):
    """Convolutional patch embedding for inputs with 1, 2 or 3 spatial dimensions."""

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel: Tuple[int, ...],
        stride: Tuple[int, ...],
        padding: Tuple[int, ...],
    ):
        """
        Args:
        - dim_in, dim_out: Input and output channel counts of the convolution.
        - kernel, stride, padding: Per-dimension convolution settings; the
          length of `kernel` selects the number of spatial dimensions.
        """
        super().__init__()

        # The dimensionality of the convolution follows the kernel's length.
        self.spatial_dims = len(kernel)
        conv_cls = conv_nd(self.spatial_dims)
        self.proj = conv_cls(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embed `x` (optionally masked) and return tokens as [batch, tokens, channels]."""
        embedded = do_masked_conv(x, self.proj, mask)
        # Collapse every spatial axis into one token axis, then go channels-last.
        return embedded.flatten(2).transpose(1, 2)
|
205 |
+
|
206 |
+
|
207 |
+
class Hiera(nn.Module):
    """
    Hierarchical transformer backbone built from `HieraBlock`s, with a patch
    embedding in front and a normalisation + classification `Head` at the end.

    Tokens are unrolled into mask units before the blocks run; `reroll` is used
    to recover per-stage intermediate features when they are requested.
    """

    def __init__(
        self,
        input_size: Tuple[int, ...] = (224, 224),
        in_chans: int = 3,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        num_classes: int = 1000,
        stages: Tuple[int, ...] = (2, 3, 16, 3),
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, ...] = (2, 2),
        mask_unit_size: Tuple[int, ...] = (8, 8),  # must divide q_stride ** (#stages-1)
        # mask_unit_attn: which stages use mask unit attention?
        mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
        dim_mul: float = 2.0,
        head_mul: float = 2.0,
        patch_kernel: Tuple[int, ...] = (7, 7),
        patch_stride: Tuple[int, ...] = (4, 4),
        patch_padding: Tuple[int, ...] = (3, 3),
        mlp_ratio: float = 4.0,
        drop_path_rate: float = 0.0,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
        head_dropout: float = 0.0,
        head_init_scale: float = 0.001,
        sep_pos_embed: bool = False,
    ):
        """
        Args:
        - input_size: Spatial (or temporal + spatial) size of the input.
        - in_chans: Number of input channels.
        - embed_dim, num_heads: Width and head count of the first stage; they
          are multiplied by `dim_mul` / `head_mul` at every stage transition.
        - num_classes: Output size of the classification head.
        - stages: Number of blocks in each stage.
        - q_pool: Number of stage transitions that pool the queries
          (must be smaller than `len(stages)`).
        - q_stride: Per-dimension query pooling stride.
        - mask_unit_size: Per-dimension mask unit size, in tokens.
        - mask_unit_attn: Per stage, whether mask unit attention is used.
        - patch_kernel, patch_stride, patch_padding: Patch embedding conv settings.
        - mlp_ratio: MLP hidden width multiplier inside each block.
        - drop_path_rate: Final stochastic depth rate (ramped linearly from 0).
        - norm_layer: Factory used to build the normalisation layers.
        - head_dropout: Dropout rate of the classification head.
        - head_init_scale: Factor applied to the head's weight and bias after init.
        - sep_pos_embed: Use separate spatial and temporal position embeddings;
          this path indexes three token dimensions.
        """
        super().__init__()

        total_blocks = sum(stages)
        self.patch_stride = patch_stride
        self.tokens_spatial_shape = [
            size // stride for size, stride in zip(input_size, patch_stride)
        ]
        token_count = math.prod(self.tokens_spatial_shape)
        unit_tokens = math.prod(mask_unit_size)
        pool_factor = math.prod(q_stride)

        assert q_pool < len(stages)
        self.q_pool = q_pool
        self.q_stride = q_stride
        self.mu_size = unit_tokens
        self.mask_unit_size = mask_unit_size
        self.mask_spatial_shape = [
            tokens // unit
            for tokens, unit in zip(self.tokens_spatial_shape, self.mask_unit_size)
        ]

        # Index of the last block of every stage.
        self.stage_ends = []
        blocks_so_far = 0
        for stage_depth in stages:
            blocks_so_far += stage_depth
            self.stage_ends.append(blocks_so_far - 1)

        self.patch_embed = PatchEmbed(
            in_chans, embed_dim, patch_kernel, patch_stride, patch_padding
        )

        self.sep_pos_embed = sep_pos_embed
        if sep_pos_embed:
            spatial_tokens = self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2]
            self.pos_embed_spatial = nn.Parameter(
                torch.zeros(1, spatial_tokens, embed_dim)
            )
            self.pos_embed_temporal = nn.Parameter(
                torch.zeros(1, self.tokens_spatial_shape[0], embed_dim)
            )
        else:
            self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))

        # Setup roll and reroll modules
        unroll_schedule = [q_stride] * len(self.stage_ends[:-1])
        self.unroll = Unroll(input_size, patch_stride, unroll_schedule)
        self.reroll = Reroll(
            input_size,
            patch_stride,
            [q_stride] * len(self.stage_ends[:-1]),
            self.stage_ends,
            q_pool,
        )

        # q_pool locations: the first block after each of the first q_pool stages
        pool_block_ids = [end + 1 for end in self.stage_ends[:q_pool]]
        # stochastic depth decay rule
        drop_rates = torch.linspace(0, drop_path_rate, total_blocks).tolist()

        # Transformer blocks
        stage_idx = 0
        self.blocks = nn.ModuleList()

        for block_idx in range(total_blocks):
            # Mask unit or global attention.
            # Read *before* the stage index advances: the choice lags by one
            # block, so that global attention is applied post pooling on the
            # lower resolution.
            window_attn = mask_unit_attn[stage_idx]
            pools_q = block_idx in pool_block_ids

            block_dim_out = embed_dim
            if block_idx - 1 in self.stage_ends:
                # First block of a new stage: widen and add heads.
                block_dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                stage_idx += 1
                if pools_q:
                    unit_tokens //= pool_factor

            self.blocks.append(
                HieraBlock(
                    dim=embed_dim,
                    dim_out=block_dim_out,
                    heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    drop_path=drop_rates[block_idx],
                    norm_layer=norm_layer,
                    q_stride=(pool_factor if pools_q else 1),
                    window_size=unit_tokens,
                    use_mask_unit_attn=window_attn,
                )
            )
            embed_dim = block_dim_out

        self.norm = norm_layer(embed_dim)
        self.head = Head(embed_dim, num_classes, dropout_rate=head_dropout)

        # Initialize everything
        if sep_pos_embed:
            nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
            nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
        else:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)
        for head_param in (self.head.projection.weight, self.head.projection.bias):
            head_param.data.mul_(head_init_scale)

    def _init_weights(self, m, init_bias=0.02):
        """Initialise one submodule; intended for use with `nn.Module.apply`."""
        conv_or_linear = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
        if isinstance(m, conv_or_linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Only Linear biases are set here; conv biases are left as created.
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, init_bias)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, init_bias)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of the position embedding parameters."""
        if self.sep_pos_embed:
            return ["pos_embed_spatial", "pos_embed_temporal"]
        return ["pos_embed"]

    def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
        """
        Generates a random boolean mask over mask units; a `mask_ratio`
        fraction of them is dropped.
        1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc.
        """
        batch = x.shape[0]
        # Tokens selected for masking at mask unit level
        unit_count = math.prod(self.mask_spatial_shape)  # num_mask_units
        keep_count = int(unit_count * (1 - mask_ratio))
        scores = torch.rand(batch, unit_count, device=x.device)

        # Rank the noise per sample (ascending: small is keep, large is remove)
        order = torch.argsort(scores, dim=1)
        ranks = torch.argsort(order, dim=1)

        # Binary mask in sorted order: 1 is *keep*, 0 is *remove*
        # Note this is opposite to original MAE
        keep_sorted = torch.zeros([batch, unit_count], device=x.device)
        keep_sorted[:, :keep_count] = 1

        # Unshuffle to get the binary mask
        return torch.gather(keep_sorted, dim=1, index=ranks).bool()

    def get_pos_embed(self) -> torch.Tensor:
        """Return the position embedding, combining the separate ones if used."""
        if not self.sep_pos_embed:
            return self.pos_embed

        shape = self.tokens_spatial_shape
        spatial = self.pos_embed_spatial.repeat(1, shape[0], 1)
        temporal = torch.repeat_interleave(
            self.pos_embed_temporal,
            shape[1] * shape[2],
            dim=1,
        )
        return spatial + temporal

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        return_intermediates: bool = False,
    ) -> torch.Tensor:
        """
        mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
        Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.

        Without a mask the tokens are mean-pooled, normalised and passed
        through the head; with a mask the block output tokens are returned
        as-is. With `return_intermediates=True` the result is a tuple
        `(x, intermediates)`.
        """
        # Slowfast training passes in a list
        if isinstance(x, list):
            x = x[0]

        conv_mask = None
        if mask is not None:
            # B, C, *mask_spatial_shape
            conv_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)

        x = self.patch_embed(x, mask=conv_mask)
        x = self.unroll(x + self.get_pos_embed())

        # Discard masked tokens
        if mask is not None:
            keep = mask[..., None].tile(1, self.mu_size, x.shape[2])
            x = x[keep].view(x.shape[0], -1, x.shape[-1])

        intermediates = []
        for idx, blk in enumerate(self.blocks):
            x = blk(x)
            if return_intermediates and idx in self.stage_ends:
                intermediates.append(self.reroll(x, idx, mask=mask))

        if mask is None:
            x = self.head(self.norm(x.mean(dim=1)))

        # x may not always be in spatial order here.
        # e.g. if q_pool = 2, mask_unit_size = (8, 8), and
        # q_stride = (2, 2), not all unrolls were consumed,
        # intermediates[-1] is x in spatial order
        if return_intermediates:
            return x, intermediates
        return x
|
436 |
+
|
437 |
+
|
438 |
+
# Image models
|
439 |
+
|
440 |
+
@pretrained_model(
    {
        "mae_in1k_ft_in1k": "https://huggingface.co/merve/hiera-tiny-ft-224-in1k/resolve/main/hiera_tiny_224.pth",
        "mae_in1k": "https://huggingface.co/merve/hiera-tiny-224-in1k/resolve/main/mae_hiera_tiny_224.pth",
    },
    default="mae_in1k_ft_in1k",
)
def hiera_tiny_224(**kwdargs):
    """Build a `Hiera` with embed_dim=96, num_heads=1 and stages (1, 2, 7, 2).

    Extra keyword arguments are forwarded to `Hiera`. Checkpoint handling is
    done by the `pretrained_model` decorator (defined elsewhere).
    """
    config = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
    return Hiera(**config, **kwdargs)
|
446 |
+
|
447 |
+
|
448 |
+
@pretrained_model(
    {
        "mae_in1k_ft_in1k": "https://huggingface.co/merve/hiera-small-ft-224-in1k/resolve/main/hiera_small_224.pth",
        "mae_in1k": "https://huggingface.co/merve/hiera-small-224-in1k/resolve/main/mae_hiera_small_224.pth",
    },
    default="mae_in1k_ft_in1k",
)
def hiera_small_224(**kwdargs):
    """Build a `Hiera` with embed_dim=96, num_heads=1 and stages (1, 2, 11, 2).

    Extra keyword arguments are forwarded to `Hiera`. Checkpoint handling is
    done by the `pretrained_model` decorator (defined elsewhere).
    """
    config = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2))
    return Hiera(**config, **kwdargs)
|
454 |
+
|
455 |
+
|
456 |
+
@pretrained_model(
    {
        "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth",
        "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
    },
    default="mae_in1k_ft_in1k",
)
def hiera_base_224(**kwdargs):
    """Build a `Hiera` with embed_dim=96, num_heads=1 and stages (2, 3, 16, 3).

    Extra keyword arguments are forwarded to `Hiera`. Checkpoint handling is
    done by the `pretrained_model` decorator (defined elsewhere).
    """
    config = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
    return Hiera(**config, **kwdargs)
|
462 |
+
|
463 |
+
|
464 |
+
@pretrained_model(
    {
        "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth",
        "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
    },
    default="mae_in1k_ft_in1k",
)
def hiera_base_plus_224(**kwdargs):
    """Build a `Hiera` with embed_dim=112, num_heads=2 and stages (2, 3, 16, 3).

    Extra keyword arguments are forwarded to `Hiera`. Checkpoint handling is
    done by the `pretrained_model` decorator (defined elsewhere).
    """
    config = dict(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3))
    return Hiera(**config, **kwdargs)
|
470 |
+
|
471 |
+
|
472 |
+
@pretrained_model(
    {
        "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth",
        "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
    },
    default="mae_in1k_ft_in1k",
)
def hiera_large_224(**kwdargs):
    """Build a `Hiera` with embed_dim=144, num_heads=2 and stages (2, 6, 36, 4).

    Extra keyword arguments are forwarded to `Hiera`. Checkpoint handling is
    done by the `pretrained_model` decorator (defined elsewhere).
    """
    config = dict(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4))
    return Hiera(**config, **kwdargs)
|
478 |
+
|
479 |
+
|
480 |
+
@pretrained_model(
    {
        "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth",
        "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
    },
    default="mae_in1k_ft_in1k",
)
def hiera_huge_224(**kwdargs):
    """Build a `Hiera` with embed_dim=256, num_heads=4 and stages (2, 6, 36, 4).

    Extra keyword arguments are forwarded to `Hiera`. Checkpoint handling is
    done by the `pretrained_model` decorator (defined elsewhere).
    """
    config = dict(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4))
    return Hiera(**config, **kwdargs)
|
486 |
+
|
487 |
+
|
488 |
+
# Video models
|
489 |
+
|
490 |
+
@pretrained_model(
    {
        "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth",
        "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth",
    },
    default="mae_k400_ft_k400",
)
def hiera_base_16x224(num_classes: int = 400, **kwdargs):
    """Build a `Hiera` configured for (16, 224, 224) inputs.

    Fixes the 3-dimensional input size, pooling stride, mask unit size and
    patch embedding settings, and enables separate position embeddings. Model
    width, heads and stages come from `Hiera`'s defaults unless overridden
    through `kwdargs`.

    Args:
    - num_classes: Size of the classification head (K400 has 400 classes).
    - **kwdargs: Further keyword arguments forwarded to `Hiera`.
    """
    video_config = dict(
        num_classes=num_classes,  # K400 has 400 classes
        input_size=(16, 224, 224),
        q_stride=(1, 2, 2),
        mask_unit_size=(1, 8, 8),
        patch_kernel=(3, 7, 7),
        patch_stride=(2, 4, 4),
        patch_padding=(1, 3, 3),
        sep_pos_embed=True,
    )
    return Hiera(**video_config, **kwdargs)
|
506 |
+
|
507 |
+
|
508 |
+
@pretrained_model(
    {
        "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_16x224.pth",
        "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth",
    },
    default="mae_k400_ft_k400",
)
def hiera_base_plus_16x224(**kwdargs):
    """Build the 16x224 model with embed_dim=112, num_heads=2, stages (2, 3, 16, 3).

    Delegates to `hiera_base_16x224` for the input and patch settings; extra
    keyword arguments are passed through to it.
    """
    overrides = dict(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3))
    return hiera_base_16x224(**overrides, **kwdargs)
|
516 |
+
|
517 |
+
|
518 |
+
@pretrained_model(
    {
        "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_large_16x224.pth",
        "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth",
    },
    default="mae_k400_ft_k400",
)
def hiera_large_16x224(**kwdargs):
    """Build the 16x224 model with embed_dim=144, num_heads=2, stages (2, 6, 36, 4).

    Delegates to `hiera_base_16x224` for the input and patch settings; extra
    keyword arguments are passed through to it.
    """
    overrides = dict(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4))
    return hiera_base_16x224(**overrides, **kwdargs)
|
526 |
+
|
527 |
+
|
528 |
+
@pretrained_model(
    {
        "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_16x224.pth",
        "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth",
    },
    default="mae_k400_ft_k400",
)
def hiera_huge_16x224(**kwdargs):
    """Build the 16x224 model with embed_dim=256, num_heads=4, stages (2, 6, 36, 4).

    Delegates to `hiera_base_16x224` for the input and patch settings; extra
    keyword arguments are passed through to it.
    """
    overrides = dict(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4))
    return hiera_base_16x224(**overrides, **kwdargs)
|
hiera/hiera_mae.py
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# References:
|
8 |
+
# mae: https://github.com/facebookresearch/mae
|
9 |
+
# slowfast: https://github.com/facebookresearch/SlowFast
|
10 |
+
# --------------------------------------------------------
|
11 |
+
|
12 |
+
|
13 |
+
from functools import partial
|
14 |
+
from typing import Tuple, Optional
|
15 |
+
|
16 |
+
import math
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from .hiera import Hiera, HieraBlock
|
21 |
+
from .hiera_utils import pretrained_model, undo_windowing, conv_nd
|
22 |
+
|
23 |
+
|
24 |
+
def apply_fusion_head(head: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Apply a channels-first `head` to channels-last mask-unit features.

    `x` is laid out as [B, #MUs, *spatial, C]. The batch and mask-unit axes are
    merged and the channel axis is moved to position 1 before calling `head`;
    afterwards the original layout is restored with the head's output sizes.
    An `nn.Identity` head returns `x` untouched.
    """
    if isinstance(head, nn.Identity):
        return x

    batch, units = x.shape[:2]

    # Apply head, e.g [B, #MUs, My, Mx, C] -> head([B * #MUs, C, My, Mx])
    channel_axis = x.dim() - 2  # position of C once B and #MUs are merged
    to_channels_first = [0, channel_axis, *range(1, channel_axis)]
    merged = x.reshape(batch * units, *x.shape[2:])
    fused = head(merged.permute(to_channels_first))

    # Restore original layout, e.g. [B * #MUs, C', My', Mx'] -> [B, #MUs, My', Mx', C']
    to_channels_last = [0, *range(2, fused.dim()), 1]
    restored_shape = (batch, units, *fused.shape[2:], fused.shape[1])
    return fused.permute(to_channels_last).reshape(restored_shape)
|
37 |
+
|
38 |
+
|
39 |
+
class MaskedAutoencoderHiera(Hiera):
|
40 |
+
"""Masked Autoencoder with Hiera backbone"""
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
in_chans: int = 3,
|
45 |
+
patch_stride: Tuple[int, ...] = (4, 4),
|
46 |
+
mlp_ratio: float = 4.0,
|
47 |
+
decoder_embed_dim: int = 512,
|
48 |
+
decoder_depth: int = 8,
|
49 |
+
decoder_num_heads: int = 16,
|
50 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
|
51 |
+
**kwdargs,
|
52 |
+
):
|
53 |
+
super().__init__(
|
54 |
+
in_chans=in_chans,
|
55 |
+
patch_stride=patch_stride,
|
56 |
+
mlp_ratio=mlp_ratio,
|
57 |
+
norm_layer=norm_layer,
|
58 |
+
**kwdargs,
|
59 |
+
)
|
60 |
+
|
61 |
+
del self.norm, self.head
|
62 |
+
encoder_dim_out = self.blocks[-1].dim_out
|
63 |
+
self.encoder_norm = norm_layer(encoder_dim_out)
|
64 |
+
self.mask_unit_spatial_shape_final = [
|
65 |
+
i // s ** (self.q_pool) for i, s in zip(self.mask_unit_size, self.q_stride)
|
66 |
+
]
|
67 |
+
self.tokens_spatial_shape_final = [
|
68 |
+
i // s ** (self.q_pool)
|
69 |
+
for i, s in zip(self.tokens_spatial_shape, self.q_stride)
|
70 |
+
]
|
71 |
+
# --------------------------------------------------------------------------
|
72 |
+
# Multi-scale fusion heads
|
73 |
+
curr_mu_size = self.mask_unit_size
|
74 |
+
self.multi_scale_fusion_heads = nn.ModuleList()
|
75 |
+
|
76 |
+
for i in self.stage_ends[: self.q_pool]: # resolution constant after q_pool
|
77 |
+
kernel = [
|
78 |
+
i // s for i, s in zip(curr_mu_size, self.mask_unit_spatial_shape_final)
|
79 |
+
]
|
80 |
+
curr_mu_size = [i // s for i, s in zip(curr_mu_size, self.q_stride)]
|
81 |
+
self.multi_scale_fusion_heads.append(
|
82 |
+
conv_nd(len(self.q_stride))(
|
83 |
+
self.blocks[i].dim_out,
|
84 |
+
encoder_dim_out,
|
85 |
+
kernel_size=kernel,
|
86 |
+
stride=kernel,
|
87 |
+
)
|
88 |
+
)
|
89 |
+
self.multi_scale_fusion_heads.append(nn.Identity()) # final stage, no transform
|
90 |
+
|
91 |
+
# --------------------------------------------------------------------------
|
92 |
+
# MAE decoder specifics
|
93 |
+
self.decoder_embed = nn.Linear(encoder_dim_out, decoder_embed_dim)
|
94 |
+
|
95 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
96 |
+
|
97 |
+
self.decoder_pos_embed = nn.Parameter(
|
98 |
+
torch.zeros(
|
99 |
+
1, math.prod(self.tokens_spatial_shape_final), decoder_embed_dim
|
100 |
+
)
|
101 |
+
)
|
102 |
+
|
103 |
+
self.decoder_blocks = nn.ModuleList(
|
104 |
+
[
|
105 |
+
HieraBlock(
|
106 |
+
dim=decoder_embed_dim,
|
107 |
+
dim_out=decoder_embed_dim,
|
108 |
+
heads=decoder_num_heads,
|
109 |
+
norm_layer=norm_layer,
|
110 |
+
mlp_ratio=mlp_ratio,
|
111 |
+
)
|
112 |
+
for i in range(decoder_depth)
|
113 |
+
]
|
114 |
+
)
|
115 |
+
self.decoder_norm = norm_layer(decoder_embed_dim)
|
116 |
+
|
117 |
+
self.pred_stride = patch_stride[-1] * (
|
118 |
+
self.q_stride[-1] ** self.q_pool
|
119 |
+
) # patch stride of prediction
|
120 |
+
|
121 |
+
self.decoder_pred = nn.Linear(
|
122 |
+
decoder_embed_dim,
|
123 |
+
(self.pred_stride ** min(2, len(self.q_stride))) * in_chans,
|
124 |
+
) # predictor
|
125 |
+
# --------------------------------------------------------------------------
|
126 |
+
|
127 |
+
self.initialize_weights()
|
128 |
+
|
129 |
+
def initialize_weights(self):
|
130 |
+
nn.init.trunc_normal_(self.mask_token, std=0.02)
|
131 |
+
nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
|
132 |
+
self.apply(self._mae_init_weights)
|
133 |
+
|
134 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
135 |
+
w = self.patch_embed.proj.weight.data
|
136 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
137 |
+
|
138 |
+
def _mae_init_weights(self, m: nn.Module):
|
139 |
+
if isinstance(m, nn.Linear):
|
140 |
+
nn.init.xavier_uniform_(m.weight)
|
141 |
+
if m.bias is not None:
|
142 |
+
nn.init.constant_(m.bias, 0)
|
143 |
+
elif isinstance(m, nn.LayerNorm):
|
144 |
+
nn.init.constant_(m.bias, 0)
|
145 |
+
nn.init.constant_(m.weight, 1.0)
|
146 |
+
|
147 |
+
def get_pixel_label_2d(
|
148 |
+
self, input_img: torch.Tensor, mask: torch.Tensor, norm: bool = True
|
149 |
+
) -> torch.Tensor:
|
150 |
+
# mask (boolean tensor): True must correspond to *masked*
|
151 |
+
input_img = input_img.permute(0, 2, 3, 1)
|
152 |
+
|
153 |
+
size = self.pred_stride
|
154 |
+
label = input_img.unfold(1, size, size).unfold(2, size, size)
|
155 |
+
label = label.flatten(1, 2).flatten(2)
|
156 |
+
label = label[mask]
|
157 |
+
if norm:
|
158 |
+
mean = label.mean(dim=-1, keepdim=True)
|
159 |
+
var = label.var(dim=-1, keepdim=True)
|
160 |
+
label = (label - mean) / (var + 1.0e-6) ** 0.5
|
161 |
+
|
162 |
+
return label
|
163 |
+
|
164 |
+
def get_pixel_label_3d(
    self, input_vid: torch.Tensor, mask: torch.Tensor, norm: bool = True
) -> torch.Tensor:
    """Build per-patch pixel targets for the selected prediction patches (3d).

    Args:
        input_vid: video batch indexed as ``[:, :, time, :, :]``; spatial
            patches of side ``self.pred_stride`` are cut along dims 3 and 4.
        mask: boolean tensor used to index the ``[B, #patches]`` leading dims;
            True must correspond to *masked* patches.
        norm: when True, each selected patch is standardized with its own
            mean and variance (epsilon 1e-6).

    Returns:
        A 2-D tensor with one row per selected patch.
    """
    # Time-strided loss: only the first frame of each temporal token is used.
    frames = input_vid[:, :, ::self.patch_stride[0], :, :]

    patch = self.pred_stride
    windows = frames.unfold(3, patch, patch).unfold(4, patch, patch)

    # Channels go LAST here, unlike the 2d variant. According to the original
    # author's note this was a mistake made during training, so the ordering
    # must be kept as is.
    windows = windows.permute(0, 2, 3, 4, 5, 6, 1)

    per_patch = windows.flatten(1, 3).flatten(2)
    label = per_patch[mask]

    if not norm:
        return label

    mean = label.mean(dim=-1, keepdim=True)
    var = label.var(dim=-1, keepdim=True)
    return (label - mean) / (var + 1.0e-6) ** 0.5
|
184 |
+
|
185 |
+
|
186 |
+
def forward_encoder(
    self, x: torch.Tensor, mask_ratio: float, mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode ``x`` under a mask and fuse the multi-scale features.

    Args:
        x: input batch forwarded to the parent class' ``forward``.
        mask_ratio: passed to ``self.get_random_mask`` when no mask is given.
        mask: optional precomputed mask; drawn randomly when ``None``.

    Returns:
        ``(fused, mask)``: the normalized sum of the fusion-head outputs and
        the mask that was used.
    """
    if mask is None:
        mask = self.get_random_mask(x, mask_ratio)  # [B, #MUs_all]

    # Multi-scale representations from the encoder.
    _, intermediates = super().forward(x, mask, return_intermediates=True)

    # Resolution is unchanged after the q_pool stages, so keep only the first
    # q_pool feature maps plus the final one.
    selected = intermediates[: self.q_pool] + intermediates[-1:]

    # Multi-scale fusion: accumulate every head's output into one tensor.
    fused = 0.0
    for fusion_head, features in zip(self.multi_scale_fusion_heads, selected):
        fused += apply_fusion_head(fusion_head, features)

    return self.encoder_norm(fused), mask
|
206 |
+
|
207 |
+
def forward_decoder(
    self, x: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decode encoder tokens into per-token pixel predictions.

    Encoder outputs are written into the positions where ``mask`` is True;
    the remaining positions receive the learned mask token. The result is
    put back in spatial order, flattened, given positional embeddings, run
    through the decoder blocks and projected by ``self.decoder_pred``.

    Returns:
        ``(pred, mask)`` where ``mask`` has been brought to the resolution of
        the prediction and flattened to ``[B, -1]``.
    """
    # Embed tokens.
    x = self.decoder_embed(x)

    # Combine visible and mask tokens.
    # x:    [B, #MUs, *mask_unit_spatial_shape_final, encoder_dim_out]
    # mask: [B, #MUs_all]
    unit_shape = x.shape[2:]
    x_dec = torch.zeros(*mask.shape, *unit_shape, device=x.device, dtype=x.dtype)

    # Give the mask token enough leading singleton dims to broadcast.
    leading_dims = len(mask.shape) + len(x.shape[2:-1])
    mask_tokens = self.mask_token.view((1,) * leading_dims + (-1,))

    # Expand the per-unit mask over every element of a unit.
    mask = mask.reshape(mask.shape + (1,) * len(unit_shape))
    mask = mask.expand((-1,) * 2 + unit_shape).bool()

    # Scatter encoder outputs into the True positions, then fill the False
    # positions with the mask token.
    x_dec[mask] = x.flatten()
    x_dec = ~mask * mask_tokens + mask * x_dec

    # Get back spatial order.
    tokens_shape = self.tokens_spatial_shape_final
    unit_spatial_shape = self.mask_unit_spatial_shape_final
    x = undo_windowing(x_dec, tokens_shape, unit_spatial_shape)
    mask = undo_windowing(mask[..., 0:1], tokens_shape, unit_spatial_shape)

    # Flatten.
    x = x.reshape(x.shape[0], -1, x.shape[-1])
    mask = mask.view(x.shape[0], -1)

    # Add pos embed.
    x = x + self.decoder_pos_embed

    # Apply decoder blocks.
    for block in self.decoder_blocks:
        x = block(x)
    x = self.decoder_norm(x)

    # Predictor projection.
    pred = self.decoder_pred(x)

    return pred, mask
|
254 |
+
|
255 |
+
def forward_loss(
    self, x: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Mean squared error between predictions and pixel labels.

    Note: in mask, 0 is *visible*, 1 is *masked*

    x: e.g. [B, 3, H, W]
    pred: [B * num_pred_tokens, num_pixels_in_pred_patch * in_chans]
    label: [B * num_pred_tokens, num_pixels_in_pred_patch * in_chans]

    Returns ``(loss, pred, label)`` where ``pred`` has been indexed by
    ``mask``. Raises ``NotImplementedError`` unless ``self.q_stride`` has
    length 2 or 3.
    """
    ndim = len(self.q_stride)
    if ndim not in (2, 3):
        raise NotImplementedError

    make_label = self.get_pixel_label_2d if ndim == 2 else self.get_pixel_label_3d
    label = make_label(x, mask)

    selected = pred[mask]
    squared_error = (selected - label) ** 2

    return squared_error.mean(), selected, label
|
276 |
+
|
277 |
+
def forward(
    self,
    x: torch.Tensor,
    mask_ratio: float = 0.6,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run encoder, decoder and loss.

    Returns the values produced by ``forward_loss`` followed by the mask
    returned by ``forward_encoder``.
    """
    latent, mask = self.forward_encoder(x, mask_ratio, mask=mask)

    # pred_mask is the mask at the resolution of the *prediction*.
    pred, pred_mask = self.forward_decoder(latent, mask)

    # Toggle the mask so that labels are generated for the *masked* tokens.
    loss_outputs = self.forward_loss(x, pred, ~pred_mask)

    return (*loss_outputs, mask)
|
291 |
+
|
292 |
+
|
293 |
+
|
294 |
+
|
295 |
+
# Image Models
|
296 |
+
|
297 |
+
@pretrained_model(
    {"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth"},
    default="mae_in1k",
)
def mae_hiera_tiny_224(**kwargs):
    """Build a ``MaskedAutoencoderHiera`` with the tiny configuration.

    Fixed settings: embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), q_pool=2.
    Remaining keyword arguments are forwarded to the constructor unchanged.
    """
    return MaskedAutoencoderHiera(
        embed_dim=96,
        num_heads=1,
        stages=(1, 2, 7, 2),
        q_pool=2,
        **kwargs,
    )
|
304 |
+
|
305 |
+
|
306 |
+
@pretrained_model(
    {"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth"},
    default="mae_in1k",
)
def mae_hiera_small_224(**kwargs):
    """Build a ``MaskedAutoencoderHiera`` with the small configuration.

    Fixed settings: embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), q_pool=2.
    Remaining keyword arguments are forwarded to the constructor unchanged.
    """
    return MaskedAutoencoderHiera(
        embed_dim=96,
        num_heads=1,
        stages=(1, 2, 11, 2),
        q_pool=2,
        **kwargs,
    )
|
313 |
+
|
314 |
+
|
315 |
+
@pretrained_model(
    {"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth"},
    default="mae_in1k",
)
def mae_hiera_base_224(**kwargs):
    """Build a ``MaskedAutoencoderHiera`` with the base configuration.

    Fixed settings: embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), q_pool=2.
    Remaining keyword arguments are forwarded to the constructor unchanged.
    """
    return MaskedAutoencoderHiera(
        embed_dim=96,
        num_heads=1,
        stages=(2, 3, 16, 3),
        q_pool=2,
        **kwargs,
    )
|
322 |
+
|
323 |
+
|
324 |
+
@pretrained_model(
    {"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth"},
    default="mae_in1k",
)
def mae_hiera_base_plus_224(**kwargs):
    """Build a ``MaskedAutoencoderHiera`` with the base-plus configuration.

    Fixed settings: embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), q_pool=2.
    Remaining keyword arguments are forwarded to the constructor unchanged.
    """
    return MaskedAutoencoderHiera(
        embed_dim=112,
        num_heads=2,
        stages=(2, 3, 16, 3),
        q_pool=2,
        **kwargs,
    )
|
331 |
+
|
332 |
+
|
333 |
+
@pretrained_model(
    {"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth"},
    default="mae_in1k",
)
def mae_hiera_large_224(**kwargs):
    """Build a ``MaskedAutoencoderHiera`` with the large configuration.

    Fixed settings: embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), q_pool=2.
    Remaining keyword arguments are forwarded to the constructor unchanged.
    """
    return MaskedAutoencoderHiera(
        embed_dim=144,
        num_heads=2,
        stages=(2, 6, 36, 4),
        q_pool=2,
        **kwargs,
    )
|
340 |
+
|
341 |
+
|
342 |
+
@pretrained_model(
    {"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth"},
    default="mae_in1k",
)
def mae_hiera_huge_224(**kwargs):
    """Build a ``MaskedAutoencoderHiera`` with the huge configuration.

    Fixed settings: embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), q_pool=2.
    Remaining keyword arguments are forwarded to the constructor unchanged.
    """
    return MaskedAutoencoderHiera(
        embed_dim=256,
        num_heads=4,
        stages=(2, 6, 36, 4),
        q_pool=2,
        **kwargs,
    )
|
349 |
+
|
350 |
+
|
351 |
+
|
352 |
+
# Video Models
|
353 |
+
|
354 |
+
@pretrained_model(
    {"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth"},
    default="mae_k400",
)
def mae_hiera_base_16x224(num_classes: int = 400, **kwdargs):
    """Build the base video ``MaskedAutoencoderHiera`` (input 16 x 224 x 224).

    The 3d input size, strides, mask-unit size and patch-embedding geometry
    are fixed here; ``num_classes`` and any other keyword arguments are
    forwarded to the constructor. The other video factories in this file
    delegate to this function.
    """
    video_settings = dict(
        input_size=(16, 224, 224),
        q_stride=(1, 2, 2),
        mask_unit_size=(1, 8, 8),
        patch_kernel=(3, 7, 7),
        patch_stride=(2, 4, 4),
        patch_padding=(1, 3, 3),
        sep_pos_embed=True,
        q_pool=2,
    )
    return MaskedAutoencoderHiera(
        num_classes=num_classes,  # K400 has 400 classes
        **video_settings,
        **kwdargs,
    )
|
370 |
+
|
371 |
+
|
372 |
+
@pretrained_model({
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth",
}, default="mae_k400")
def mae_hiera_base_plus_16x224(**kwdargs):
    """Build the base-plus video masked autoencoder.

    Delegates to ``mae_hiera_base_16x224`` with embed_dim=112, num_heads=2 and
    stages=(2, 3, 16, 3); all other keyword arguments are forwarded.

    A second, stacked ``@pretrained_model(None)`` decorator used to sit
    directly above this function. It only added a pass-through wrapper (the
    outer decorator consumes ``pretrained``/``checkpoint``/``strict`` before
    the inner one is called), so it was removed to match the other video
    factories.
    """
    return mae_hiera_base_16x224(
        embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs
    )
|
380 |
+
|
381 |
+
|
382 |
+
@pretrained_model({
    "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth",
}, default="mae_k400")
def mae_hiera_large_16x224(**kwdargs):
    """Build the large video masked autoencoder.

    Delegates to ``mae_hiera_base_16x224`` with embed_dim=144, num_heads=2 and
    stages=(2, 6, 36, 4); all other keyword arguments are forwarded.

    A second, stacked ``@pretrained_model(None)`` decorator used to sit
    directly above this function. It only added a pass-through wrapper (the
    outer decorator consumes ``pretrained``/``checkpoint``/``strict`` before
    the inner one is called), so it was removed to match the other video
    factories.
    """
    return mae_hiera_base_16x224(
        embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs
    )
|
390 |
+
|
391 |
+
|
392 |
+
@pretrained_model(
    {"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth"},
    default="mae_k400",
)
def mae_hiera_huge_16x224(**kwdargs):
    """Build the huge video masked autoencoder.

    Delegates to ``mae_hiera_base_16x224`` with embed_dim=256, num_heads=4 and
    stages=(2, 6, 36, 4); all other keyword arguments are forwarded.
    """
    return mae_hiera_base_16x224(
        embed_dim=256,
        num_heads=4,
        stages=(2, 6, 36, 4),
        **kwdargs,
    )
|
hiera/hiera_utils.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
#
|
8 |
+
# Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
|
9 |
+
#
|
10 |
+
# Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
|
11 |
+
# Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
|
12 |
+
# Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
|
13 |
+
#
|
14 |
+
# Paper: https://arxiv.org/abs/2306.00989/
|
15 |
+
#
|
16 |
+
# References:
|
17 |
+
# slowfast: https://github.com/facebookresearch/SlowFast
|
18 |
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
19 |
+
# --------------------------------------------------------
|
20 |
+
|
21 |
+
import math
|
22 |
+
from typing import List, Tuple, Optional, Type, Callable, Dict
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.nn as nn
|
26 |
+
import torch.nn.functional as F
|
27 |
+
|
28 |
+
|
29 |
+
def pretrained_model(checkpoints: Optional[Dict[str, str]], default: Optional[str] = None) -> Callable:
    """Decorator factory that adds optional pretrained-weight loading to a model builder.

    The decorated builder gains three keyword parameters:

    * ``pretrained`` (bool, default False): download and load weights.
    * ``checkpoint`` (str, default ``default``): key into ``checkpoints``.
    * ``strict`` (bool, default True): forwarded to ``load_state_dict``.

    All other keyword arguments are passed to the wrapped builder.

    Args:
        checkpoints: mapping from checkpoint name to URL, or ``None`` when the
            model has no pretrained weights.
        default: checkpoint name used when the caller does not pass one.

    Raises (from the decorated builder, only when ``pretrained=True``):
        RuntimeError: if ``checkpoints`` is ``None``, if no checkpoint name is
            available, or if the name is not a key of ``checkpoints``.
    """
    # Local import so that this function does not depend on a new file-level import.
    import functools

    def inner(model_func: Callable) -> Callable:
        # functools.wraps keeps the builder's __name__/__doc__ on the wrapper;
        # without it every decorated factory would be reported as "model_def".
        @functools.wraps(model_func)
        def model_def(pretrained: bool = False, checkpoint: Optional[str] = default, strict: bool = True, **kwdargs) -> nn.Module:
            if pretrained:
                if checkpoints is None:
                    raise RuntimeError("This model currently doesn't have pretrained weights available.")
                elif checkpoint is None:
                    raise RuntimeError("No checkpoint specified.")
                elif checkpoint not in checkpoints:
                    raise RuntimeError(f"Invalid checkpoint specified ({checkpoint}). Options are: {list(checkpoints.keys())}.")

                state_dict = torch.hub.load_state_dict_from_url(checkpoints[checkpoint], map_location="cpu")

                if "head.projection.weight" in state_dict["model_state"]:
                    # Set the number of classes equal to the state_dict only if the user doesn't want to overwrite it
                    if "num_classes" not in kwdargs:
                        kwdargs["num_classes"] = state_dict["model_state"]["head.projection.weight"].shape[0]
                    # If the user specified a different number of classes, remove the projection weights or else we'll error out
                    elif kwdargs["num_classes"] != state_dict["model_state"]["head.projection.weight"].shape[0]:
                        del state_dict["model_state"]["head.projection.weight"]
                        del state_dict["model_state"]["head.projection.bias"]

            model = model_func(**kwdargs)
            if pretrained:
                # Disable being strict when trying to load an encoder-decoder model into an encoder-only model
                if "decoder_pos_embed" in state_dict["model_state"] and not hasattr(model, "decoder_pos_embed"):
                    strict = False

                model.load_state_dict(state_dict["model_state"], strict=strict)

            return model

        return model_def

    return inner
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
def conv_nd(n: int) -> Type[nn.Module]:
    """
    Return the convolution class for ``n`` dimensions (e.g. ``nn.Conv2d`` for n=2).

    Supported up to n=3; n=0 maps to ``nn.Identity``. ``n`` is used directly as
    a sequence index into the table below.
    A 4d Hiera could probably be supported by extending the table (no promises).
    """
    conv_by_dim = (nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d)
    return conv_by_dim[n]
|
75 |
+
|
76 |
+
|
77 |
+
def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor:
    """Max-pool over groups of ``stride`` tokens.

    ``x`` is viewed as ``[x.shape[0], stride, -1, x.shape[-1]]`` and reduced
    with a max over dim 1. Refer to ``Unroll`` for why this token layout makes
    the reduction equivalent to a maxpool-Nd.
    """
    batch, channels = x.shape[0], x.shape[-1]
    grouped = x.view(batch, stride, -1, channels)
    pooled, _ = grouped.max(dim=1)
    return pooled
|
80 |
+
|
81 |
+
|
82 |
+
def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
    """Resize ``mask`` so that its trailing dims equal ``target_size``.

    target_size: [(T), (H), W]
    (spatial) mask: [B, C, (t), (h), w]

    ``None`` is passed through. A mask that already has the target size is
    returned as is; otherwise it is cast to float and resized with
    ``F.interpolate`` using that function's default mode.
    """
    if mask is None:
        return None

    spatial = mask.shape[2:]
    assert len(spatial) == len(target_size)

    if spatial == target_size:
        return mask
    return F.interpolate(mask.float(), size=target_size)
|
92 |
+
|
93 |
+
|
94 |
+
def do_masked_conv(
    x: torch.Tensor, conv: nn.Module, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Apply ``conv`` to ``x`` after zeroing out the masked regions.

    Zeroing the input first prevents leakage of masked regions when the
    convolution uses overlapping kernels. With ``conv=None`` the input is
    returned untouched; with ``mask=None`` the convolution is applied to the
    unmodified input.
    """
    if conv is None:
        return x

    if mask is not None:
        # Bring the mask to x's spatial size, then multiply by its boolean form.
        resized = get_resized_mask(target_size=x.shape[2:], mask=mask)
        x = x * resized.bool()

    return conv(x)
|
107 |
+
|
108 |
+
|
109 |
+
def undo_windowing(
    x: torch.Tensor, shape: List[int], mu_shape: List[int]
) -> torch.Tensor:
    """
    Undo the mask-unit window layout and return a spatially ordered tensor.

    Args:
        x: organized by mask units windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C]
        shape: current spatial shape, if it were not organized into mask unit
            windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C].
        mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx]
    Returns:
        x: e.g. in 2d, [B, #MUy*MUy, #MUx*MUx, C]
    """
    ndim = len(shape)
    batch, channels = x.shape[0], x.shape[-1]

    # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C]
    units_per_dim = [full // unit for full, unit in zip(shape, mu_shape)]
    x = x.view(batch, *units_per_dim, *mu_shape, channels)

    # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C]
    # Interleave each "number of units" axis with its "unit size" axis.
    order = [0]
    for axis in range(1, 1 + ndim):
        order += [axis, axis + ndim]
    order.append(len(x.shape) - 1)

    return x.permute(order).reshape(batch, *shape, channels)
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
class Unroll(nn.Module):
    """
    Reorders the tokens such that patches are contiguous in memory.
    E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as
                           [B, (Sy, Sx, H // Sy, W // Sx), C]

    This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1).
    Not only is this faster, but it also makes it easy to support inputs of arbitrary
    dimensions in addition to patch-wise sparsity.

    Performing this operation multiple times in sequence puts entire windows as contiguous
    in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
    size 8x8 would be contiguous in memory, allowing operations like mask unit attention
    computed easily and efficiently, while also allowing max to be applied sequentially.

    Note: This means that intermediate values of the model are not in HxW order, so they
    need to be re-rolled if you want to use the intermediate values as a HxW feature map.
    The last block of the network is fine though, since by then the strides are all consumed.
    """

    def __init__(
        self,
        input_size: Tuple[int, ...],
        patch_stride: Tuple[int, ...],
        unroll_schedule: List[Tuple[int, ...]],
    ):
        super().__init__()
        # Token grid size: the input size divided by the patch stride, per dim.
        self.size = [dim // stride for dim, stride in zip(input_size, patch_stride)]
        self.schedule = unroll_schedule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input: Flattened patch embeddings [B, N, C]
        Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
        """
        batch, _, channels = x.shape

        grid = self.size
        x = x.view(batch, *grid, channels)

        for strides in self.schedule:
            # Split every spatial dim into (coarse extent, stride).
            # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C]
            grid = [extent // stride for extent, stride in zip(grid, strides)]
            split_shape = [batch]
            for extent, stride in zip(grid, strides):
                split_shape += [extent, stride]
            split_shape.append(channels)
            x = x.view(split_shape)

            # Bring the stride axes in front of the coarse axes.
            # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C]
            last = len(split_shape) - 1
            stride_axes = list(range(2, last, 2))
            coarse_axes = list(range(1, last, 2))
            x = x.permute([0, *stride_axes, *coarse_axes, last])

            # Fold the stride axes into the batch dimension.
            x = x.flatten(0, len(strides))
            batch *= math.prod(strides)

        return x.reshape(-1, math.prod(self.size), channels)
|
207 |
+
|
208 |
+
|
209 |
+
class Reroll(nn.Module):
    """
    Undoes the "unroll" operation so that you can use intermediate features.
    """

    def __init__(
        self,
        input_size: Tuple[int, ...],
        patch_stride: Tuple[int, ...],
        unroll_schedule: List[Tuple[int, ...]],
        stage_ends: List[int],
        q_pool: int,
    ):
        super().__init__()
        self.size = [dim // stride for dim, stride in zip(input_size, patch_stride)]

        # Per block index: (remaining unroll schedule to reverse, token grid size).
        # The first stage has to reverse everything,
        # the next stage has to reverse all but the first unroll, etc.
        self.schedule = {}
        pooling_ends = stage_ends[:q_pool]
        remaining = unroll_schedule
        grid = self.size
        for block_idx in range(stage_ends[-1] + 1):
            self.schedule[block_idx] = remaining, grid
            # The schedule is unchanged if no pooling happens at a stage end.
            if block_idx not in pooling_ends:
                continue
            if len(remaining) > 0:
                grid = [extent // stride for extent, stride in zip(grid, remaining[0])]
            remaining = remaining[1:]

    def forward(
        self, x: torch.Tensor, block_idx: int, mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Roll the given tensor back up to spatial order assuming it's from the given block.

        If no mask is provided:
            - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc.
        If a mask is provided:
            - Returns [B, #MUs, MUy, MUx, C] for 2d, etc.
        """
        schedule, grid = self.schedule[block_idx]
        batch, num_tokens, channels = x.shape

        ndim = len(grid)
        mu_shape = [1] * ndim

        for strides in schedule:
            # Extract the current patch from N.
            x = x.view(
                batch, *strides, num_tokens // math.prod(strides), *mu_shape, channels
            )

            # Move that patch into the current MU.
            # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C]
            rank = len(x.shape)
            order = [0, 1 + ndim]
            for stride_axis, mu_axis in zip(range(1, 1 + ndim), range(2 + ndim, rank - 1)):
                order += [stride_axis, mu_axis]
            order.append(rank - 1)
            x = x.permute(order)

            # Reshape to [B, N//(Sy*Sx), *MU, C]
            mu_shape = [mu_shape[i] * strides[i] for i in range(ndim)]
            x = x.reshape(batch, -1, *mu_shape, channels)
            num_tokens = x.shape[1]

        # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C])
        x = x.view(batch, num_tokens, *mu_shape, channels)

        # If masked, return [B, #MUs, MUy, MUx, C]
        if mask is not None:
            return x

        # If not masked, we can return [B, H, W, C]
        return undo_windowing(x, grid, mu_shape)
|