Hila committed on
Commit
9f7f854
1 Parent(s): 05d981d

Upload ViT_new.py

Files changed (1)
  1. ViT_new.py +975 -0
ViT_new.py ADDED
@@ -0,0 +1,975 @@
""" Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in:

'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
    - https://arxiv.org/abs/2010.11929

`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
    - https://arxiv.org/abs/2106.10270

The official jax code is released and available at https://github.com/google-research/vision_transformer

DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
  for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020, Ross Wightman
"""
import math
import logging
from functools import partial
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from timm.models.registry import register_model

_logger = logging.getLogger(__name__)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # patch models (weights from official Google JAX impl)
    'vit_tiny_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_tiny_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_small_patch32_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_small_patch32_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_small_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_small_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch32_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_base_patch32_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_base_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch8_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_large_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
    ),
    'vit_large_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_large_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_large_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),

    'vit_huge_patch14_224': _cfg(url=''),
    'vit_giant_patch14_224': _cfg(url=''),
    'vit_gigantic_patch14_224': _cfg(url=''),

    # patch models, imagenet21k (weights from official Google JAX impl)
    'vit_tiny_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_small_patch32_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_small_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_base_patch32_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_base_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_base_patch8_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_large_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
        num_classes=21843),
    'vit_large_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
        num_classes=21843),
    'vit_huge_patch14_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
        hf_hub='timm/vit_huge_patch14_224_in21k',
        num_classes=21843),

    # SAM trained models (https://arxiv.org/abs/2106.01548)
    'vit_base_patch32_sam_224': _cfg(
        url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
    'vit_base_patch16_sam_224': _cfg(
        url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),

    # deit models (FB weights)
    'deit_tiny_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'deit_small_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'deit_base_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'deit_base_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
    'deit_tiny_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
    'deit_small_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
        classifier=('head', 'head_dist')),

    # ViT ImageNet-21K-P pretraining by MIIL
    'vit_base_patch16_224_miil_in21k': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
    ),
    'vit_base_patch16_224_miil': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
            '/vit_base_patch16_224_1k_miil_84_4.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
    ),
}


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        self.save_attention_map(attn)
        if register_hook:
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


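# Illustrative sketch (not part of the original file): the helper below shows how the
# `register_hook` path above is typically exercised. Run a forward pass with `register_hook=True`,
# backprop any scalar, then read the cached attention map and its gradient off the module.
# The function name and tensor sizes are arbitrary examples.
def _example_attention_hook_usage():
    attn_layer = Attention(dim=64, num_heads=8, qkv_bias=True)
    tokens = torch.randn(2, 17, 64)  # (batch, tokens, dim)
    out = attn_layer(tokens, register_hook=True)
    out.sum().backward()  # any scalar objective works; the hook fires during this backward pass
    attn_map = attn_layer.get_attention_map()    # shape (2, 8, 17, 17), saved in forward()
    attn_grad = attn_layer.get_attn_gradients()  # same shape, captured by the registered hook
    return attn_map, attn_grad

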
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, register_hook=False):
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            weight_init (str): weight init scheme
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
            attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if mode.startswith('jax'):
            # leave cls token as zeros to match jax impl
            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, register_hook=False):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        # x = self.blocks(x)
        for blk in self.blocks:
            x = blk(x, register_hook=register_hook)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x, register_hook=False):
        x = self.forward_features(x, register_hook=register_hook)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions for the distillation loss
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x


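# Illustrative sketch (not part of the original file): one way the `register_hook` plumbing is
# typically used for attention-based explanations. Forward with hooks enabled, backprop the logit
# of a target class, then read each block's cached attention map and its gradient. The function
# name, the assumption of a non-distilled model, and the gradient-weighting scheme are examples.
def _example_gradient_weighted_attention(model, image, class_idx=None):
    model.eval()
    logits = model(image, register_hook=True)  # assumes a non-distilled model returning logits
    if class_idx is None:
        class_idx = logits.argmax(dim=-1).item()
    model.zero_grad()
    logits[0, class_idx].backward()
    per_block = []
    for blk in model.blocks:
        attn = blk.attn.get_attention_map()    # (B, heads, tokens, tokens), cached in Attention.forward
        grad = blk.attn.get_attn_gradients()   # same shape, captured by the registered hook
        per_block.append((grad * attn).clamp(min=0).mean(dim=1))  # head-averaged, positive part
    return per_block

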
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ ViT weight initialization
    * When called without name, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid name (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
            else:
                trunc_normal_(module.weight, std=.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    elif jax_impl and isinstance(module, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)


@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
    """
    import numpy as np

    def _n2p(w, t=True):
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    if not prefix and 'opt/target/embedding/kernel' in w:
        prefix = 'opt/target/'

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))


def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2
    assert len(gs_new) >= 2
    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb


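# Illustrative sketch (not part of the original file): `resize_pos_embed` is what lets a checkpoint
# trained at one resolution be loaded at another. The shapes below correspond to a ViT-B/16 going
# from 224x224 (14x14 patch grid plus cls token) to 384x384 (24x24 grid) and are just an example.
def _example_resize_pos_embed():
    posemb_224 = torch.randn(1, 1 + 14 * 14, 768)  # pretrained pos_embed: cls token + 196 patches
    posemb_384 = torch.zeros(1, 1 + 24 * 24, 768)  # target pos_embed of the higher-resolution model
    resized = resize_pos_embed(posemb_224, posemb_384, num_tokens=1, gs_new=(24, 24))
    assert resized.shape == posemb_384.shape  # token embeddings kept, grid bicubically interpolated
    return resized

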
def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
            # For old models that I trained prior to conv based patchification
            O, I, H, W = model.patch_embed.proj.weight.shape
            v = v.reshape(O, -1, H, W)
        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(
                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
        out_dict[k] = v
    return out_dict


def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
    default_cfg = default_cfg or default_cfgs[variant]
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    # NOTE this extra code to support handling of repr size for in21k pretrained models
    default_num_classes = default_cfg['num_classes']
    num_classes = kwargs.get('num_classes', default_num_classes)
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and num_classes != default_num_classes:
        # Remove representation layer if fine-tuning. This may not always be the desired action,
        # but it feels better than doing nothing by default for fine-tuning. Perhaps a better interface?
        _logger.warning("Removing representation layer for fine-tuning.")
        repr_size = None

    model = build_model_with_cfg(
        VisionTransformer, variant, pretrained,
        default_cfg=default_cfg,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load='npz' in default_cfg['url'],
        **kwargs)
    return model


@register_model
def vit_tiny_patch16_224(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16)
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_tiny_patch16_384(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16) @ 384x384.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch32_224(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/32)
    """
    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch32_384(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/32) at 384x384.
    """
    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch16_384(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch8_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
    """
    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_384(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_sam_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
    """
    # NOTE original SAM weights release worked with representation_size=768
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_sam_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
    """
    # NOTE original SAM weights release worked with representation_size=768
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_huge_patch14_224(pretrained=False, **kwargs):
    """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    """
    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_giant_patch14_224(pretrained=False, **kwargs):
    """ ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
    """
    model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_gigantic_patch14_224(pretrained=False, **kwargs):
    """ ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
    """
    model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/32)
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(
        patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
    """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
    """
    model_kwargs = dict(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
    model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer(
        'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer(
        'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224_miil(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
    return model
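

# Illustrative sketch (not part of the original file): a minimal smoke test, assuming a timm
# version compatible with the helpers imported above. It builds the registered tiny variant
# without pretrained weights and checks the output shape for a default-sized input.
if __name__ == '__main__':
    vit = vit_tiny_patch16_224(pretrained=False)
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = vit(dummy)
    print('logits:', out.shape)  # expected: torch.Size([1, 1000])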