Bingchen Zhao committed
Commit
9ad81d2
1 Parent(s): 7e80db4

init commit

Files changed (2)
  1. app.py +108 -0
  2. model.py +846 -0
app.py ADDED
@@ -0,0 +1,108 @@
import torch
from model import MaskedAutoencoderViT, mae_vit_base_patch16
import numpy as np
from PIL import Image
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer
from collections import OrderedDict
from huggingface_hub import hf_hub_download

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# download the MUG checkpoint; load on CPU first so the demo also runs on machines without a GPU
ckpt = torch.load(hf_hub_download('tennant/MUG', 'mae_bert_vit_b_cc3m.pth'), map_location='cpu')

# strip the 'image_encoder.model.' prefix so the keys match MaskedAutoencoderViT's state dict
new_dict = OrderedDict()
for k, v in ckpt.items():
    k = k[len('image_encoder.model.'):]
    new_dict.update({k: v})

model = mae_vit_base_patch16(uni_dim=768, less_u=True)

model.load_state_dict(new_dict)
if torch.cuda.is_available():
    model.cuda()
model.eval()

@torch.no_grad()
def visual_recon(x, model):
    target = model.patchify(x)
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)

    latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=0.75)
    y, _ = model.forward_decoder(latent, ids_restore)
    y = y * (var + 1.e-6)**.5 + mean
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y)

    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask)

    x = torch.einsum('nchw->nhwc', x)

    return x * (1 - mask), x * (1 - mask) + y * mask, y, latent

@torch.no_grad()
def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
    assert latent.shape[0] == 1, 'can only caption one image at a time'

    x_l = torch.tensor(tokenizer([prefix])['input_ids'])[:, :-1]
    seq = x_l.shape[1]
    if torch.cuda.is_available():
        x_l = x_l.cuda()

    cls_mask = rearrange(x_l != 0, 'b j -> b 1 j')
    attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

    x_l = model.embed_text(x_l)

    for cross_attn1, cross_attn2 in model.multimodal_layers:
        x_l = cross_attn1(x_l, latent)
        x_l = cross_attn2(x_l, latent)

    pred = model.to_logits(x_l)
    next_word = pred.argmax(dim=-1)[0, -1]
    next_word = tokenizer.decode(next_word)

    return next_word

def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
    # greedy decoding: repeatedly ask for the next word until max_len words are produced
    words = prefix.split()
    while len(words) < max_len:
        next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        words.append(next_word)
    return ' '.join(words)


def gr_caption(x):
    # normalize the input image with ImageNet statistics, as expected by the encoder
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    x = np.array(x) / 255.
    x = x - imagenet_mean
    x = x / imagenet_std

    x = torch.tensor(x).float()
    x = x.unsqueeze(0)
    x = torch.einsum('nhwc->nchw', x)
    if torch.cuda.is_available():
        x = x.cuda()

    def unnorm_pix(img):
        img = img.squeeze(0).cpu().detach().numpy()
        img = img * imagenet_std + imagenet_mean
        return np.clip(img, a_min=0., a_max=1.)

    masked, masked_recon, recon, latent = visual_recon(x, model)
    caption_from_model = caption(10, latent, model, tokenizer)

    masked, masked_recon, recon = map(unnorm_pix, (masked, masked_recon, recon))

    return masked, masked_recon, recon, caption_from_model

import gradio as gr

demo = gr.Interface(gr_caption,
                    inputs=[gr.Image(shape=(224, 224))],
                    outputs=[gr.Image(shape=(224, 224)), gr.Image(shape=(224, 224)), gr.Image(shape=(224, 224)), 'text'])
demo.launch()
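# Hypothetical local check (not part of the original app): bypass the Gradio UI and call
# gr_caption directly on a 224x224 RGB image; 'example.jpg' is a placeholder path.
#   img = Image.open('example.jpg').convert('RGB').resize((224, 224))
#   masked, masked_recon, recon, text = gr_caption(img)
#   print(text)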
model.py ADDED
@@ -0,0 +1,846 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from timm.models.vision_transformer import PatchEmbed, Block
from transformers import EncoderDecoderModel, BertTokenizer, AutoTokenizer

from torch import einsum
from einops import rearrange, repeat

class FocalLoss(nn.CrossEntropyLoss):
    ''' Focal loss for classification tasks on imbalanced datasets '''

    def __init__(self, gamma=1.0, alpha=None, ignore_index=-100, reduction='none'):
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none')
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input_, target):
        cross_entropy = super().forward(input_, target)
        # Temporarily mask out ignore index to '0' for valid gather-indices input.
        # This won't contribute to the final loss as the cross_entropy contribution
        # for these would be zero.
        target = target * (target != self.ignore_index).long()
        input_prob = torch.gather(F.softmax(input_, 1), 1, target.unsqueeze(1)).squeeze(1)
        loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy
        return torch.mean(loss) if self.reduction == 'mean' \
            else torch.sum(loss) if self.reduction == 'sum' \
            else loss

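# Illustrative use (not from the original source): scale the per-token cross entropy by (1 - p_t)**gamma
# so confidently-classified tokens contribute less; vocab size 30522 matches bert-base-uncased.
#   fl = FocalLoss(gamma=2.0, ignore_index=0, reduction='mean')
#   logits = torch.randn(4, 30522)          # (batch, vocab)
#   target = torch.tensor([5, 9, 0, 7])     # index 0 is the padding id and is ignored
#   loss = fl(logits, target)               # scalar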
# helper functions

import math
from functools import reduce

def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()


def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# normalization
# they use layernorm without bias, something that pytorch does not offer


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# residual
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

# rotary positional embedding
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())

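# Shape note (illustrative, not from the original source): RotaryEmbedding(dim_head) produces a
# (seq_len, dim_head) table of angles, and apply_rotary_pos_emb broadcasts it over batch/head dims.
#   pos = RotaryEmbedding(64)(16, device=torch.device('cpu'))   # (16, 64)
#   q = torch.randn(2, 8, 16, 64)                               # (b, h, n, d)
#   q = apply_rotary_pos_emb(pos, q)                            # still (2, 8, 16, 64)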
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GELU for gating the feedforward
# https://arxiv.org/abs/2002.05202
class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame
class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4, attn_drop_rate=0.0):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        self.attn_drop_rate = attn_drop_rate

        # for caching causal mask and rotary embeddings
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x, attn_mask=None):
        """
        Performs self attention and feedforward
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale
        q = q * self.scale

        # similarity
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # causal mask
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # extra attention mask - for masking out attention from text CLS token to padding
        if exists(attn_mask):
            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

        if self.attn_drop_rate != 0.:
            # import ipdb; ipdb.set_trace()
            drop_ind = sim != -torch.finfo(sim.dtype).max
            dropout_mask = torch.cuda.FloatTensor(*sim[drop_ind].shape).uniform_() > self.attn_drop_rate
            sim[drop_ind] = sim[drop_ind].masked_fill(~dropout_mask, -torch.finfo(sim.dtype).max)

        # attention
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)

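# Illustrative shapes (assumed values, not from the original source): the block maps (b, n, dim) -> (b, n, dim);
# wrap it in Residual(...) as done below to add the skip connection.
#   blk = ParallelTransformerBlock(dim=768, dim_head=64, heads=8)
#   tokens = torch.randn(2, 10, 768)
#   out = blk(tokens)                                           # (2, 10, 768)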
# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward
class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False,
        dropout=0.0,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        self.dropout = dropout

        # whether to have parallel feedforward
        ff_inner_dim = ff_mult * dim

        self.ff = nn.Sequential(
            nn.Linear(dim, ff_inner_dim * 2, bias=False),
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        ) if parallel_ff else None

    def forward(self, x, context):
        """
        Use text as the query, and image as key / value
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # pre-layernorm, for queries and context
        x = self.norm(x)
        context = self.context_norm(context)

        # get queries
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)

        # scale
        q = q * self.scale

        # get key / values
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # query / key similarity
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # dropout
        if self.training:
            dropout_mask = torch.cuda.FloatTensor(*sim.shape).uniform_() > self.dropout
            sim = sim.masked_fill(~dropout_mask, -torch.finfo(sim.dtype).max)

        # attention
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)

        # aggregate
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge and combine heads
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add parallel feedforward (for multimodal layers)
        if exists(self.ff):
            out = out + self.ff(x)
        return out

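# Illustrative shapes (assumed values, not from the original source): n text tokens attend over m context tokens.
#   xattn = CrossAttention(dim=768, dim_head=64, heads=8, parallel_ff=True)
#   xattn.eval()                                   # eval() skips the torch.cuda.FloatTensor dropout branch, so this also runs on CPU
#   text, image = torch.randn(2, 10, 768), torch.randn(2, 50, 768)
#   out = xattn(text, image)                       # (2, 10, 768)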
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)  # float64, since np.float no longer exists in recent NumPy
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

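# Shape check (illustrative): the ViT-B/16 setting used by app.py has a 14x14 patch grid at 224x224, so
#   get_2d_sincos_pos_embed(768, 14, cls_token=True).shape == (1 + 14 * 14, 768)   # (197, 768)
# which is exactly the fixed pos_embed buffer created in MaskedAutoencoderViT below.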
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=True,
                 unimodal_depth=2, multimodal_depth=8, dim_head=64, heads=8,
                 ff_mult=4, extract_multi_level=False, use_focal_loss=False, focal_gamma=1.0,
                 less_u=False, use_weak_negative=False, use_label_smooth=False, ls_coef=0.1,
                 use_maximum_entropy=False, ce_additional=False, use_word_weights=False, use_token_pos=False,
                 use_expect_k=False, use_top_k=False, mae_decoder_caption=False, decoder_slot_depth=2, disable_decoder_vis_token_grad=False,
                 cross_attn_dropout=0.0, predict_next_k_words=False, next_k=3, masked_text=False, masked_text_ratio=0.25, text_length=70,
                 projector_layer=0, uni_dim=1024, uni_dim_head=64, uni_heads=8, uni_ff_mult=4, text_drop_attn=0.):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.mae_decoder_depth = decoder_depth
        self.mae_decoder_caption = mae_decoder_caption
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        if self.mae_decoder_caption:
            self.decoder_slot_layers = nn.ModuleList([])
            for _ in range(decoder_slot_depth):
                self.decoder_slot_layers.append(
                    Residual(CrossAttention(dim=decoder_embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),
                    # Residual(CrossAttention(dim=decoder_embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult))
                )
            self.decoder_caption_proj = nn.Linear(decoder_embed_dim, embed_dim)
            self.disable_decoder_vis_token_grad = disable_decoder_vis_token_grad

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        # --------------------------------------------------------------------------
        # captioner specifics
        # unimodal layer is for text tokens.
        # multimodal layer is for text to query from image.
        # load the tokenizer from the Hub (no machine-specific cache path, same tokenizer as app.py)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        # token embeddings
        # NOTE: +1 for mask token used by MLM objective
        # self.token_emb = nn.Embedding(len(self.tokenizer.vocab) + 1, uni_dim)

        self.token_emb = nn.Embedding(len(self.tokenizer.vocab), uni_dim)
        self.text_cls_token = nn.Parameter(torch.randn(uni_dim))

        self.embed_dim = embed_dim
        self.uni_dim = uni_dim

        # import ipdb; ipdb.set_trace()
        # unimodal layers
        # TODO: search on the four parameters
        # uni_dim=1024, uni_dim_head=64, uni_heads=8, uni_ff_mult=4
        self.text_drop_attn = text_drop_attn
        self.unimodal_layers = nn.ModuleList([])
        for _ in range(unimodal_depth):
            self.unimodal_layers.append(
                Residual(ParallelTransformerBlock(dim=uni_dim, dim_head=uni_dim_head,
                                                  heads=uni_heads, ff_mult=uni_ff_mult, attn_drop_rate=self.text_drop_attn)),
            )

        self.need_uni_2_mul_proj = False
        if uni_dim != embed_dim:
            self.need_uni_2_mul_proj = True
            self.uni_2_mul_proj = nn.Linear(uni_dim, embed_dim)

        # multimodal layers
        self.multimodal_layers = nn.ModuleList([])
        self.less_u = less_u
        if less_u:
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))
        else:
            for _ in range(multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(ParallelTransformerBlock(dim=embed_dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)),
                    Residual(CrossAttention(dim=embed_dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult, dropout=cross_attn_dropout))
                ]))

        # to logits: for softmax caption loss
        self.to_logits = nn.Sequential(
            LayerNorm(embed_dim),
            nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
        )

        self.ce_additional = ce_additional
        if ce_additional:
            # to logits: for other losses
            self.to_logits_1 = nn.Sequential(
                LayerNorm(embed_dim),
                nn.Linear(embed_dim, len(self.tokenizer.vocab), bias=False)
            )

        nn.init.normal_(self.token_emb.weight, std=0.02)

        self.pad_id = 0
        self.cls_id = 101
        self.sep_id = 102

        self.logsoftmax = nn.LogSoftmax(dim=1)

        self.extract_multi_level = extract_multi_level
        if self.extract_multi_level:
            self.projectors = nn.ModuleList([nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.GELU(),
                nn.Linear(embed_dim // 2, embed_dim),
                norm_layer(embed_dim)
            ) for _ in [2, 5, 8]])
        # --------------------------------------------------------------------------

        self.use_focal_loss = use_focal_loss

        self.use_weak_negative = use_weak_negative
        self.use_label_smooth = use_label_smooth
        self.ls_coef = ls_coef
        self.use_entropy = use_maximum_entropy
        self.use_word_weights = use_word_weights
        self.use_token_pos = use_token_pos

        self.predict_next_k_words = predict_next_k_words
        self.next_k = next_k
        self.pad = torch.nn.ReplicationPad1d((0, self.next_k - 1))

        self.use_expect_k = use_expect_k
        self.use_top_k = use_top_k

        if self.use_word_weights or self.use_token_pos:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='none')
        else:
            self.focal_loss = FocalLoss(ignore_index=self.pad_id, gamma=focal_gamma, reduction='mean')

        self.masked_text = masked_text
        self.masked_text_ratio = masked_text_ratio
        # self.text_mask_token = nn.Parameter(torch.randn(embed_dim))
        self.mask_token_id = len(self.tokenizer.vocab)

        # self.text_position_embed = nn.Parameter(torch.zeros(1, text_length, embed_dim), requires_grad=False)
        self.text_length = text_length

        self.latent_projector_layer = projector_layer
        if self.latent_projector_layer != 0:
            self.latent_projector = [
                nn.Linear(embed_dim, embed_dim),
                nn.ReLU()
            ] * (self.latent_projector_layer - 1)
            self.latent_projector.append(nn.Linear(embed_dim, embed_dim))

            self.latent_projector = nn.Sequential(*self.latent_projector)

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # text_pos_embed = get_1d_sincos_pos_embed_from_grid(self.embed_dim, )
        # torch.nn.init.xavier_normal_(self.text_position_embed) # learnable text position embedding

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        # torch.nn.init.normal_(self.text_mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore, ids_keep

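    # Shape walkthrough (illustrative, with the app.py defaults of 224x224 inputs and patch size 16):
    # patchify maps (N, 3, 224, 224) -> (N, 196, 768); random_masking with mask_ratio=0.75 keeps
    # int(196 * 0.25) = 49 visible tokens, so x_masked is (N, 49, D) and mask is (N, 196) with 147 ones.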
    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.extract_multi_level:
            multi_level_feats = []
            # apply Transformer blocks
            for blk_idx, blk in enumerate(self.blocks):
                x = blk(x)
                if blk_idx in [2, 5, 8]:
                    multi_level_feats.append(self.projectors[[2, 5, 8].index(blk_idx)](x))
            x = self.norm(x)
            multi_level_feats.append(x)

            return multi_level_feats, mask, ids_restore

        # apply Transformer blocks
        for blk_idx, blk in enumerate(self.blocks):
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore, ids_keep

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)
        # non_mask_token = x

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        decoder_feat = []
        for idx, blk in enumerate(self.decoder_blocks):
            x = blk(x)
            if idx == self.mae_decoder_depth // 2:
                decoder_feat.append(x)

        x = self.decoder_norm(x)

        # use the output from decoder to do captioning

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x, decoder_feat

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def embed_text(self, text):
        batch, device = text.shape[0], text.device

        seq = text.shape[1]

        text_tokens = self.token_emb(text)

        # append text cls tokens
        text_cls_tokens = repeat(self.text_cls_token, 'd -> b 1 d', b=batch)
        text_tokens = torch.cat((text_tokens, text_cls_tokens), dim=-2)

        # create specific mask for text cls token at the end
        # to prevent it from attending to padding
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

        # go through unimodal layers
        for attn_ff in self.unimodal_layers:
            text_tokens = attn_ff(text_tokens, attn_mask=attn_mask)

        if self.need_uni_2_mul_proj:
            text_tokens = self.uni_2_mul_proj(text_tokens)

        # get text cls token
        text_tokens, text_cls_tokens = text_tokens[:, :-1], text_tokens[:, -1]
        return text_tokens

    def forward(self, imgs, caption_ids=None, attention_mask=None, mask_ratio=0.75,
                freeze_bert=False, teacher_forcing=False, caption_only=False,
                encoder_only=False, word_weights=None, syn_count=None):
        latent, mask, ids_restore, ids_keep = self.forward_encoder(imgs, mask_ratio)

        if not caption_only:
            pred, decoder_feat = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
            mae_loss = self.forward_loss(imgs, pred, mask)
        else:
            mae_loss = 0.

        if self.latent_projector_layer != 0:
            latent = self.latent_projector(latent)

        # latent: visual info: N, L, C
        # caption_ids: N, Len
        text, labels = caption_ids[:, :-1], caption_ids[:, 1:]

        seq = text.shape[1]
        text_tokens = self.embed_text(text)  # N, Len, C

        # create specific mask for text cls token at the end
        # to prevent it from attending to padding
        cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
        unimodal_text_tokens = text_tokens
        if not self.less_u:
            for attn_ff, cross_attn in self.multimodal_layers:
                text_tokens = attn_ff(text_tokens, attn_mask=attn_mask[:, :-1, :-1])
                text_tokens = cross_attn(text_tokens, latent)
        else:
            for cross_attn1, cross_attn2 in self.multimodal_layers:
                text_tokens = cross_attn1(text_tokens, latent)
                text_tokens = cross_attn2(text_tokens, latent)

        logits = self.to_logits(text_tokens)  # N, Len, NVocab
        logits = logits.reshape(-1, len(self.tokenizer.vocab))
        labels = labels.reshape(-1)

        caption_loss = F.cross_entropy(logits, labels, ignore_index=self.pad_id)

        return mae_loss, caption_loss, None

def mae_vit_small_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=384, depth=12, num_heads=6,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# set recommended archs
mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
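# Minimal usage sketch (illustrative only; the tensors and captions below are made up, not from the repo):
#   model = mae_vit_base_patch16(uni_dim=768, less_u=True)   # the configuration app.py loads
#   model.eval()
#   imgs = torch.randn(2, 3, 224, 224)
#   ids = model.tokenizer(['a photo of a cat', 'a photo of a dog'],
#                         padding=True, return_tensors='pt')['input_ids']
#   with torch.no_grad():
#       mae_loss, caption_loss, _ = model(imgs, caption_ids=ids, mask_ratio=0.75)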