robinzixuan committed on
Commit
7859efc
1 Parent(s): a3815b0

Upload vision_transformer.py

Files changed (1)
  1. vision_transformer.py +1853 -0
vision_transformer.py ADDED
@@ -0,0 +1,1853 @@
+ """ Vision Transformer (ViT) in PyTorch
+
+ A PyTorch implementation of Vision Transformers as described in:
+
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
+ - https://arxiv.org/abs/2010.11929
+
+ `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
+ - https://arxiv.org/abs/2106.10270
+
+ `FlexiViT: One Model for All Patch Sizes`
+ - https://arxiv.org/abs/2212.08013
+
+ The official jax code is released and available at
+ * https://github.com/google-research/vision_transformer
+ * https://github.com/google-research/big_vision
+
+ Acknowledgments:
+ * The paper authors for releasing code and weights, thanks!
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+ Hacked together by / Copyright 2020, Ross Wightman
+ """
+ import argparse
+ import json
+ import logging
+ import math
+ import os
+ from collections import OrderedDict
+ from functools import partial
+ from typing import List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch.jit import Final
+
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
+ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
+ resample_abs_pos_embed
+ from timm.models._builder import build_model_with_cfg
+ from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
+ from timm.models._pretrained import generate_default_cfgs
+ from timm.models._registry import register_model
+
+ from quantization.utils import BaseEnumOptions
+ from transformers_language.models.softmax import clipped_softmax, clipped_softmax1
63
+
64
+ __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this
65
+
66
+
67
+ _logger = logging.getLogger(__name__)
68
+
69
+
71
+ # Set to True if exporting a model with Same padding via ONNX
72
+ _EXPORTABLE = False
73
+
74
+ # Set to True if wanting to use torch.jit.script on a model
75
+ _SCRIPTABLE = False
76
+
77
+
78
+ # use torch.scaled_dot_product_attention where possible
79
+ _HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
80
+ if 'TIMM_FUSED_ATTN' in os.environ:
81
+ _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
82
+ else:
83
+ _USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
84
+
85
+ def logit(p, eps=1e-16):
86
+ p = np.clip(p, eps, 1 - eps)
87
+ return -np.log(1 / p - 1)
88
+
89
+
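+ # Illustrative sketch added in editing (not part of the original upload): `logit` above is the
+ # inverse of the sigmoid used by the attention gates below, so a gate whose bias is initialised
+ # to logit(p) starts roughly p-open. The helper is illustrative only and is never called here.
+ def _logit_roundtrip_example(p: float = 0.25) -> float:
+     # sigmoid(logit(p)) recovers p (up to the clipping eps inside `logit`)
+     return float(torch.sigmoid(torch.tensor(logit(p))))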
90
+ class AttentionGateType(BaseEnumOptions):
91
+ none = 0
92
+ unconditional_per_head = 1
93
+ conditional_per_head = 2
94
+ conditional_per_token = 3
95
+
96
+ def use_fused_attn(experimental: bool = False) -> bool:
97
+ # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
98
+ if not _HAS_FUSED_ATTN or _EXPORTABLE:
99
+ return False
100
+ if experimental:
101
+ return _USE_FUSED_ATTN > 1
102
+ return _USE_FUSED_ATTN > 0
103
+
+ def scaled_dot_product_attention(query, key, value, softmax_fn, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
+ # Reference (non-fused) implementation mirroring torch.nn.functional.scaled_dot_product_attention,
+ # but with a pluggable softmax function.
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+ if is_causal:
+ assert attn_mask is None
+ temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+ attn_bias = attn_bias.to(query.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ # a boolean mask marks positions to keep; apply it to the additive bias, not the mask itself
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = softmax_fn(attn_weight, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+ return attn_weight @ value
+
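+ # Editor's sketch (added in editing, not part of the original upload): with the standard softmax,
+ # the reference implementation above should match torch's fused kernel up to numerical noise
+ # (assumes torch >= 2.0). Shapes are (batch, heads, tokens, head_dim). This helper is illustrative
+ # only and is never called in this module.
+ def _sdpa_reference_check(atol: float = 1e-5) -> bool:
+     q, k, v = (torch.randn(2, 3, 5, 8) for _ in range(3))
+     ref = scaled_dot_product_attention(q, k, v, torch.nn.functional.softmax, dropout_p=0.0)
+     fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
+     return torch.allclose(ref, fused, atol=atol)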
127
+ class Attention(nn.Module):
128
+ fused_attn: Final[bool]
129
+
130
+ def __init__(
131
+ self,
132
+ dim: int,
133
+ num_heads: int = 8,
134
+ qkv_bias: bool = False,
135
+ qk_norm: bool = False,
136
+ attn_drop: float = 0.,
137
+ proj_drop: float = 0.,
138
+ norm_layer: nn.Module = nn.LayerNorm,
139
+ softmax_fn=torch.nn.functional.softmax,
140
+ gamma=None,
141
+ ssm_eps=None,
142
+ tau=None,
143
+ skip_attn=False,
144
+ attn_gate_type=AttentionGateType.none,
145
+ attn_gate_init=None,
146
+ attn_gate_mlp=False,
147
+ attn_gate_mlp2=False,
148
+ attn_gate_linear_all_features=False,
149
+ fine_tuning=False,
150
+ max_seq_length=None,
151
+
152
+ ) -> None:
153
+ super().__init__()
154
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
155
+ self.num_attention_heads = num_heads
156
+ self.attention_head_size = dim // num_heads
157
+ self.scale = self.attention_head_size ** -0.5
158
+ self.fused_attn = use_fused_attn()
159
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
160
+ self.q_norm = norm_layer(self.attention_head_size) if qk_norm else nn.Identity()
161
+ self.k_norm = norm_layer(self.attention_head_size) if qk_norm else nn.Identity()
162
+ self.attn_drop = nn.Dropout(attn_drop)
163
+ self.proj = nn.Linear(dim, dim)
164
+ self.proj_drop = nn.Dropout(proj_drop)
165
+
166
+ self.attn_scores = nn.Identity() # before attention mask
167
+ self.attn_probs_before_dropout = nn.Identity()
168
+ self.attn_probs_after_dropout = nn.Identity()
169
+
170
+ self.gamma = gamma
171
+ self.ssm_eps = ssm_eps
172
+ self.tau = tau
173
+ self.max_seq_length = max_seq_length
174
+
175
+ # define softmax function
176
+
177
+ self.softmax_fn = softmax_fn
178
+
179
+ self.skip_attn = skip_attn
180
+
181
+ # attention gating
182
+ self.last_gate_avg_prob = None
183
+ self.last_gate_all_probs = None
184
+
185
+ self.attn_gate_type = attn_gate_type
186
+ self.attn_gate_init = attn_gate_init
187
+ self.attn_gate_mlp = attn_gate_mlp
188
+ self.attn_gate_mlp2 = attn_gate_mlp2
189
+ self.attn_gate_linear_all_features = attn_gate_linear_all_features
190
+
191
+ self.alpha = None
192
+ self.gate_fn = torch.sigmoid
193
+ self.pooling_fn = partial(torch.mean, dim=1, keepdims=True)
194
+
195
+ self.fine_tuning = fine_tuning
196
+
197
+ # gate scaling factor
198
+ self.gate_scaling_factor = 1.0
199
+ if self.fine_tuning and self.attn_gate_init is not None:
200
+ self.gate_scaling_factor = 1.0 / self.attn_gate_init
201
+
202
+ # define gate
203
+ if self.attn_gate_type == AttentionGateType.unconditional_per_head:
204
+ init_alpha = torch.zeros(size=(self.num_attention_heads,))
205
+ self.alpha = nn.Parameter(init_alpha, requires_grad=True)
206
+
207
+ elif self.attn_gate_type in (
208
+ AttentionGateType.conditional_per_head,
209
+ AttentionGateType.conditional_per_token,
210
+ ):
211
+ if self.attn_gate_linear_all_features:
+ self.alpha = nn.Linear(dim, self.num_attention_heads, bias=True)  # dim == num_heads * head_dim (all heads' features)
213
+
214
+ else: # separate predictors for each head
215
+ module_list = []
216
+ for _ in range(self.num_attention_heads):
217
+ if self.attn_gate_mlp:
218
+ fc = nn.Sequential(
219
+ nn.Linear(
220
+ self.attention_head_size, self.attention_head_size // 4, bias=True
221
+ ),
222
+ nn.ReLU(),
223
+ nn.Linear(self.attention_head_size // 4, 1, bias=True),
224
+ )
225
+ elif self.attn_gate_mlp2:
226
+ fc = nn.Sequential(
227
+ nn.Linear(
228
+ self.attention_head_size, self.attention_head_size, bias=True
229
+ ),
230
+ nn.ReLU(),
231
+ nn.Linear(self.attention_head_size, 1, bias=True),
232
+ )
233
+ else:
234
+ fc = nn.Linear(self.attention_head_size, 1, bias=True)
235
+
+ # when the gate is an MLP, its trainable Linear is the last element; otherwise fc is the Linear itself
+ gate_fc = fc[-1] if isinstance(fc, nn.Sequential) else fc
+ if self.attn_gate_init is not None:
+ init_bias = logit(self.attn_gate_init)
+ torch.nn.init.constant_(gate_fc.bias, init_bias)
+
+ if self.fine_tuning:
+ # init to very small values
+ torch.nn.init.normal_(gate_fc.weight, mean=0.0, std=0.01)
243
+
244
+ module_list.append(fc)
245
+ self.alpha = nn.ModuleList(module_list)
246
+
247
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
248
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
249
+ x = x.view(new_x_shape)
250
+ return x.permute(0, 2, 1, 3)
251
+
252
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
253
+ hidden_states = x
254
+ B, N, C = x.shape
255
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_attention_heads, self.attention_head_size).permute(2, 0, 3, 1, 4)
256
+ q, k, v = qkv.unbind(0)
257
+ q, k = self.q_norm(q), self.k_norm(k)
258
+
259
+ if self.fused_attn:
260
+ context_layer = scaled_dot_product_attention(
261
+ q, k, v, self.softmax_fn,
262
+ dropout_p=self.attn_drop.p if self.training else 0.,
263
+ )
264
+ else:
265
+ q = q * self.scale
266
+ attn = q @ k.transpose(-2, -1)
267
+
268
+ attn = self.softmax_fn(attn, dim=-1)
269
+ attn = self.attn_probs_before_dropout(attn)
270
+ attn = self.attn_drop(attn)
271
+ attn = self.attn_probs_after_dropout(attn)
272
+ context_layer = attn @ v
273
+
274
+
275
+ # *** Gating ***
276
+ if self.attn_gate_type == AttentionGateType.unconditional_per_head:
277
+ gate = self.gate_fn(self.alpha) # (H,)
278
+ context_layer *= gate.view(-1, 1, 1) # (B, H, T, d_head)
279
+
280
+ self.last_gate_avg_prob = gate.view(-1)
281
+
282
+ elif self.attn_gate_type in (
283
+ AttentionGateType.conditional_per_head,
284
+ AttentionGateType.conditional_per_token,
285
+ ):
286
+
287
+ x = hidden_states
288
+
289
+ if self.attn_gate_linear_all_features: # assume per_token
290
+ alpha = self.alpha(x) # (B, T, H)
291
+ gate = self.gate_fn(alpha)
292
+ gate = gate.permute(0, 2, 1).contiguous() # (B, H, T)
293
+ gate = gate.unsqueeze(3) # (B, H, T, 1)
294
+
295
+ else:
296
+ x = self.transpose_for_scores(x) # (B, H, T, d_head)
297
+
298
+ alpha = []
299
+ for head_idx in range(self.num_attention_heads):
300
+ x_head = x[:, head_idx, ...] # (B, T, d_head)
301
+ fc_head = self.alpha[head_idx]
302
+ alpha_head = fc_head(x_head) # (B, T, 1)
303
+ if self.attn_gate_type == AttentionGateType.conditional_per_head:
304
+ alpha_head = self.pooling_fn(alpha_head) # (B, 1, 1)
305
+ alpha.append(alpha_head)
306
+ alpha = torch.stack(alpha, dim=1) # (B, H, *, 1)
307
+ gate = self.gate_fn(alpha)
308
+
309
+ context_layer *= gate * self.gate_scaling_factor
310
+
311
+ self.last_gate_all_probs = gate # all gates to see the distributions
312
+ avg_gate = gate.mean(dim=0)
313
+ self.last_gate_avg_prob = avg_gate.view(self.num_attention_heads, -1).mean(dim=1)
314
+
315
+
316
+ x = context_layer.transpose(1, 2).reshape(B, N, C)
317
+ x = self.proj(x)
318
+ x = self.proj_drop(x)
319
+ return x
320
+
321
+
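+ # Editor's sketch (added in editing, not part of the original upload): the extra constructor
+ # arguments on Attention above expose this file's modifications — a pluggable softmax (e.g. the
+ # imported clipped_softmax) and optional attention gating. The helper below only exercises the
+ # unconditional per-head gate with the default softmax; values are arbitrary and it is never called.
+ def _gated_attention_example() -> torch.Size:
+     attn = Attention(
+         dim=64,
+         num_heads=2,
+         attn_gate_type=AttentionGateType.unconditional_per_head,
+     )
+     x = torch.randn(2, 5, 64)  # (batch, tokens, dim)
+     return attn(x).shape       # torch.Size([2, 5, 64])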
322
+ class LayerScale(nn.Module):
323
+ def __init__(self, dim, init_values=1e-5, inplace=False):
324
+ super().__init__()
325
+ self.inplace = inplace
326
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
327
+
328
+ def forward(self, x):
329
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
330
+
331
+
332
+ class Block(nn.Module):
333
+
334
+ def __init__(
335
+ self,
336
+ dim,
337
+ num_heads,
338
+ mlp_ratio=4.,
339
+ qkv_bias=False,
340
+ drop=0.,
341
+ attn_drop=0.,
342
+ init_values=None,
343
+ drop_path=0.,
344
+ act_layer=nn.GELU,
345
+ norm_layer=nn.LayerNorm
346
+ ):
347
+ super().__init__()
348
+ self.norm1 = norm_layer(dim)
349
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
350
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
351
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
352
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
353
+
354
+ self.norm2 = norm_layer(dim)
355
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
356
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
357
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
358
+
359
+ def forward(self, x):
360
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
361
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
362
+ return x
363
+
364
+
365
+ class ResPostBlock(nn.Module):
366
+
367
+ def __init__(
368
+ self,
369
+ dim,
370
+ num_heads,
371
+ mlp_ratio=4.,
372
+ qkv_bias=False,
373
+ drop=0.,
374
+ attn_drop=0.,
375
+ init_values=None,
376
+ drop_path=0.,
377
+ act_layer=nn.GELU,
378
+ norm_layer=nn.LayerNorm
379
+ ):
380
+ super().__init__()
381
+ self.init_values = init_values
382
+
383
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
384
+ self.norm1 = norm_layer(dim)
385
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
386
+
387
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
388
+ self.norm2 = norm_layer(dim)
389
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
390
+
391
+ self.init_weights()
392
+
393
+ def init_weights(self):
+ # NOTE: this init overrides the base model init with block-specific changes
395
+ if self.init_values is not None:
396
+ nn.init.constant_(self.norm1.weight, self.init_values)
397
+ nn.init.constant_(self.norm2.weight, self.init_values)
398
+
399
+ def forward(self, x):
400
+ x = x + self.drop_path1(self.norm1(self.attn(x)))
401
+ x = x + self.drop_path2(self.norm2(self.mlp(x)))
402
+ return x
403
+
404
+
405
+ class ParallelBlock(nn.Module):
406
+
407
+ def __init__(
408
+ self,
409
+ dim,
410
+ num_heads,
411
+ num_parallel=2,
412
+ mlp_ratio=4.,
413
+ qkv_bias=False,
414
+ init_values=None,
415
+ drop=0.,
416
+ attn_drop=0.,
417
+ drop_path=0.,
418
+ act_layer=nn.GELU,
419
+ norm_layer=nn.LayerNorm
420
+ ):
421
+ super().__init__()
422
+ self.num_parallel = num_parallel
423
+ self.attns = nn.ModuleList()
424
+ self.ffns = nn.ModuleList()
425
+ for _ in range(num_parallel):
426
+ self.attns.append(nn.Sequential(OrderedDict([
427
+ ('norm', norm_layer(dim)),
428
+ ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
429
+ ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
430
+ ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
431
+ ])))
432
+ self.ffns.append(nn.Sequential(OrderedDict([
433
+ ('norm', norm_layer(dim)),
434
+ ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
435
+ ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
436
+ ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
437
+ ])))
438
+
439
+ def _forward_jit(self, x):
440
+ x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
441
+ x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
442
+ return x
443
+
444
+ @torch.jit.ignore
445
+ def _forward(self, x):
446
+ x = x + sum(attn(x) for attn in self.attns)
447
+ x = x + sum(ffn(x) for ffn in self.ffns)
448
+ return x
449
+
450
+ def forward(self, x):
451
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
452
+ return self._forward_jit(x)
453
+ else:
454
+ return self._forward(x)
455
+
456
+
457
+ class VisionTransformer(nn.Module):
458
+ """ Vision Transformer
459
+
460
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
461
+ - https://arxiv.org/abs/2010.11929
462
+ """
463
+
464
+ def __init__(
465
+ self,
466
+ img_size=224,
467
+ patch_size=16,
468
+ in_chans=3,
469
+ num_classes=1000,
470
+ global_pool='token',
471
+ embed_dim=768,
472
+ depth=12,
473
+ num_heads=12,
474
+ mlp_ratio=4.,
475
+ qkv_bias=True,
476
+ init_values=None,
477
+ class_token=True,
478
+ no_embed_class=False,
479
+ pre_norm=False,
480
+ fc_norm=None,
481
+ drop_rate=0.,
482
+ attn_drop_rate=0.,
483
+ drop_path_rate=0.,
484
+ weight_init='',
485
+ embed_layer=PatchEmbed,
486
+ norm_layer=None,
487
+ act_layer=None,
488
+ block_fn=Block,
489
+ ):
490
+ """
491
+ Args:
492
+ img_size (int, tuple): input image size
493
+ patch_size (int, tuple): patch size
494
+ in_chans (int): number of input channels
495
+ num_classes (int): number of classes for classification head
496
+ global_pool (str): type of global pooling for final sequence (default: 'token')
497
+ embed_dim (int): embedding dimension
498
+ depth (int): depth of transformer
499
+ num_heads (int): number of attention heads
500
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
501
+ qkv_bias (bool): enable bias for qkv if True
502
+ init_values: (float): layer-scale init values
503
+ class_token (bool): use class token
+ fc_norm (Optional[bool]): apply a norm layer after pooling, before the classifier; if None, enabled when global_pool == 'avg' (default: None)
505
+ drop_rate (float): dropout rate
506
+ attn_drop_rate (float): attention dropout rate
507
+ drop_path_rate (float): stochastic depth rate
508
+ weight_init (str): weight init scheme
509
+ embed_layer (nn.Module): patch embedding layer
510
+ norm_layer: (nn.Module): normalization layer
511
+ act_layer: (nn.Module): MLP activation layer
512
+ """
513
+ super().__init__()
514
+ assert global_pool in ('', 'avg', 'token')
515
+ assert class_token or global_pool != 'token'
516
+ use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
517
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
518
+ act_layer = act_layer or nn.GELU
519
+
520
+ self.num_classes = num_classes
521
+ self.global_pool = global_pool
522
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
523
+ self.num_prefix_tokens = 1 if class_token else 0
524
+ self.no_embed_class = no_embed_class
525
+ self.grad_checkpointing = False
526
+
527
+ self.patch_embed = embed_layer(
528
+ img_size=img_size,
529
+ patch_size=patch_size,
530
+ in_chans=in_chans,
531
+ embed_dim=embed_dim,
532
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
533
+ )
534
+ num_patches = self.patch_embed.num_patches
535
+
536
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
537
+ embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
538
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
539
+ self.pos_drop = nn.Dropout(p=drop_rate)
540
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
541
+
542
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
543
+ self.blocks = nn.Sequential(*[
544
+ block_fn(
545
+ dim=embed_dim,
546
+ num_heads=num_heads,
547
+ mlp_ratio=mlp_ratio,
548
+ qkv_bias=qkv_bias,
549
+ init_values=init_values,
550
+ drop=drop_rate,
551
+ attn_drop=attn_drop_rate,
552
+ drop_path=dpr[i],
553
+ norm_layer=norm_layer,
554
+ act_layer=act_layer
555
+ )
556
+ for i in range(depth)])
557
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
558
+
559
+ # Classifier Head
560
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
561
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
562
+
563
+ if weight_init != 'skip':
564
+ self.init_weights(weight_init)
565
+
566
+ def init_weights(self, mode=''):
567
+ assert mode in ('jax', 'jax_nlhb', 'moco', '')
568
+ head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
569
+ trunc_normal_(self.pos_embed, std=.02)
570
+ if self.cls_token is not None:
571
+ nn.init.normal_(self.cls_token, std=1e-6)
572
+ named_apply(get_init_weights_vit(mode, head_bias), self)
573
+
574
+ def _init_weights(self, m):
575
+ # this fn left here for compat with downstream users
576
+ init_weights_vit_timm(m)
577
+
578
+ @torch.jit.ignore()
579
+ def load_pretrained(self, checkpoint_path, prefix=''):
580
+ _load_weights(self, checkpoint_path, prefix)
581
+
582
+ @torch.jit.ignore
583
+ def no_weight_decay(self):
584
+ return {'pos_embed', 'cls_token', 'dist_token'}
585
+
586
+ @torch.jit.ignore
587
+ def group_matcher(self, coarse=False):
588
+ return dict(
589
+ stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
590
+ blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
591
+ )
592
+
593
+ @torch.jit.ignore
594
+ def set_grad_checkpointing(self, enable=True):
595
+ self.grad_checkpointing = enable
596
+
597
+ @torch.jit.ignore
598
+ def get_classifier(self):
599
+ return self.head
600
+
601
+ def reset_classifier(self, num_classes: int, global_pool=None):
602
+ self.num_classes = num_classes
603
+ if global_pool is not None:
604
+ assert global_pool in ('', 'avg', 'token')
605
+ self.global_pool = global_pool
606
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
607
+
608
+ def _pos_embed(self, x):
609
+ if self.no_embed_class:
610
+ # deit-3, updated JAX (big vision)
611
+ # position embedding does not overlap with class token, add then concat
612
+ x = x + self.pos_embed
613
+ if self.cls_token is not None:
614
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
615
+ else:
616
+ # original timm, JAX, and deit vit impl
617
+ # pos_embed has entry for class token, concat then add
618
+ if self.cls_token is not None:
619
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
620
+ x = x + self.pos_embed
621
+ return self.pos_drop(x)
622
+
623
+ def forward_features(self, x):
624
+ x = self.patch_embed(x)
625
+ x = self._pos_embed(x)
626
+ x = self.norm_pre(x)
627
+ if self.grad_checkpointing and not torch.jit.is_scripting():
628
+ x = checkpoint_seq(self.blocks, x)
629
+ else:
630
+ x = self.blocks(x)
631
+ x = self.norm(x)
632
+ return x
633
+
634
+ def forward_head(self, x, pre_logits: bool = False):
635
+ if self.global_pool:
636
+ x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
637
+ x = self.fc_norm(x)
638
+ return x if pre_logits else self.head(x)
639
+
640
+ def forward(self, x):
641
+ x = self.forward_features(x)
642
+ x = self.forward_head(x)
643
+ return x
644
+
645
+
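+ # Editor's sketch (added in editing, not part of the original upload): minimal end-to-end use of
+ # the VisionTransformer class above with an arbitrary small config; never called in this module.
+ def _vit_forward_example() -> torch.Size:
+     model = VisionTransformer(img_size=32, patch_size=8, embed_dim=64, depth=2, num_heads=2, num_classes=10)
+     x = torch.randn(2, 3, 32, 32)
+     feats = model.forward_features(x)   # (2, 1 + 4*4 patches, 64): class token + patch tokens
+     logits = model.forward_head(feats)  # (2, 10): class-token pooling, then linear head
+     return logits.shape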
646
+ def init_weights_vit_timm(module: nn.Module, name: str = ''):
647
+ """ ViT weight initialization, original timm impl (for reproducibility) """
648
+ if isinstance(module, nn.Linear):
649
+ trunc_normal_(module.weight, std=.02)
650
+ if module.bias is not None:
651
+ nn.init.zeros_(module.bias)
652
+ elif hasattr(module, 'init_weights'):
653
+ module.init_weights()
654
+
655
+
656
+ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
657
+ """ ViT weight initialization, matching JAX (Flax) impl """
658
+ if isinstance(module, nn.Linear):
659
+ if name.startswith('head'):
660
+ nn.init.zeros_(module.weight)
661
+ nn.init.constant_(module.bias, head_bias)
662
+ else:
663
+ nn.init.xavier_uniform_(module.weight)
664
+ if module.bias is not None:
665
+ nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
666
+ elif isinstance(module, nn.Conv2d):
667
+ lecun_normal_(module.weight)
668
+ if module.bias is not None:
669
+ nn.init.zeros_(module.bias)
670
+ elif hasattr(module, 'init_weights'):
671
+ module.init_weights()
672
+
673
+
674
+ def init_weights_vit_moco(module: nn.Module, name: str = ''):
675
+ """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
676
+ if isinstance(module, nn.Linear):
677
+ if 'qkv' in name:
678
+ # treat the weights of Q, K, V separately
679
+ val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
680
+ nn.init.uniform_(module.weight, -val, val)
681
+ else:
682
+ nn.init.xavier_uniform_(module.weight)
683
+ if module.bias is not None:
684
+ nn.init.zeros_(module.bias)
685
+ elif hasattr(module, 'init_weights'):
686
+ module.init_weights()
687
+
688
+
689
+ def get_init_weights_vit(mode='jax', head_bias: float = 0.):
690
+ if 'jax' in mode:
691
+ return partial(init_weights_vit_jax, head_bias=head_bias)
692
+ elif 'moco' in mode:
693
+ return init_weights_vit_moco
694
+ else:
695
+ return init_weights_vit_timm
696
+
697
+
698
+ def resize_pos_embed(
699
+ posemb,
700
+ posemb_new,
701
+ num_prefix_tokens=1,
702
+ gs_new=(),
703
+ interpolation='bicubic',
704
+ antialias=False,
705
+ ):
706
+ """ Rescale the grid of position embeddings when loading from state_dict.
707
+
708
+ *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
709
+
710
+ Adapted from:
711
+ https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
712
+ """
713
+ ntok_new = posemb_new.shape[1]
714
+ if num_prefix_tokens:
715
+ posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
716
+ ntok_new -= num_prefix_tokens
717
+ else:
718
+ posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
719
+ gs_old = int(math.sqrt(len(posemb_grid)))
720
+ if not len(gs_new): # backwards compatibility
721
+ gs_new = [int(math.sqrt(ntok_new))] * 2
722
+ assert len(gs_new) >= 2
723
+ _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).')
724
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
725
+ posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False)
726
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
727
+ posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
728
+ return posemb
729
+
730
+
731
+ @torch.no_grad()
732
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
733
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
734
+ """
735
+ import numpy as np
736
+
737
+ def _n2p(w, t=True):
738
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
739
+ w = w.flatten()
740
+ if t:
741
+ if w.ndim == 4:
742
+ w = w.transpose([3, 2, 0, 1])
743
+ elif w.ndim == 3:
744
+ w = w.transpose([2, 0, 1])
745
+ elif w.ndim == 2:
746
+ w = w.transpose([1, 0])
747
+ return torch.from_numpy(w)
748
+
749
+ w = np.load(checkpoint_path)
750
+ interpolation = 'bilinear'
751
+ antialias = False
752
+ big_vision = False
753
+ if not prefix:
754
+ if 'opt/target/embedding/kernel' in w:
755
+ prefix = 'opt/target/'
756
+ elif 'params/embedding/kernel' in w:
757
+ prefix = 'params/'
758
+ big_vision = True
759
+
760
+ if hasattr(model.patch_embed, 'backbone'):
761
+ # hybrid
762
+ backbone = model.patch_embed.backbone
763
+ stem_only = not hasattr(backbone, 'stem')
764
+ stem = backbone if stem_only else backbone.stem
765
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
766
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
767
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
768
+ if not stem_only:
769
+ for i, stage in enumerate(backbone.stages):
770
+ for j, block in enumerate(stage.blocks):
771
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
772
+ for r in range(3):
773
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
774
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
775
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
776
+ if block.downsample is not None:
777
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
778
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
779
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
780
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
781
+ else:
782
+ embed_conv_w = adapt_input_conv(
783
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
784
+ if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]:
785
+ embed_conv_w = resample_patch_embed(
786
+ embed_conv_w,
787
+ model.patch_embed.proj.weight.shape[-2:],
788
+ interpolation=interpolation,
789
+ antialias=antialias,
790
+ verbose=True,
791
+ )
792
+
793
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
794
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
795
+ if model.cls_token is not None:
796
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
797
+ if big_vision:
798
+ pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
799
+ else:
800
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
801
+ if pos_embed_w.shape != model.pos_embed.shape:
802
+ old_shape = pos_embed_w.shape
803
+ num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
804
+ pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
805
+ pos_embed_w,
806
+ new_size=model.patch_embed.grid_size,
807
+ num_prefix_tokens=num_prefix_tokens,
808
+ interpolation=interpolation,
809
+ antialias=antialias,
810
+ verbose=True,
811
+ )
812
+ model.pos_embed.copy_(pos_embed_w)
813
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
814
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
815
+ if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
816
+ model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
817
+ model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
818
+ # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
819
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
820
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
821
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
822
+ mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
823
+ for i, block in enumerate(model.blocks.children()):
824
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
825
+ mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
826
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
827
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
828
+ block.attn.qkv.weight.copy_(torch.cat([
829
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
830
+ block.attn.qkv.bias.copy_(torch.cat([
831
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
832
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
833
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
834
+ for r in range(2):
835
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
836
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
837
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
838
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
839
+
840
+
841
+ def _convert_openai_clip(state_dict, model):
842
+ out_dict = {}
843
+ swaps = [
844
+ ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
845
+ ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
846
+ ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
847
+ ]
848
+ for k, v in state_dict.items():
849
+ if not k.startswith('visual.'):
850
+ continue
851
+ for sp in swaps:
852
+ k = k.replace(sp[0], sp[1])
853
+
854
+ if k == 'proj':
855
+ k = 'head.weight'
856
+ v = v.transpose(0, 1)
857
+ out_dict['head.bias'] = torch.zeros(v.shape[0])
858
+ elif k == 'class_embedding':
859
+ k = 'cls_token'
860
+ v = v.unsqueeze(0).unsqueeze(1)
861
+ elif k == 'pos_embed':
862
+ v = v.unsqueeze(0)
863
+ if v.shape[1] != model.pos_embed.shape[1]:
864
+ # To resize pos embedding when using model at different size from pretrained weights
865
+ v = resize_pos_embed(
866
+ v,
867
+ model.pos_embed,
868
+ 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
869
+ model.patch_embed.grid_size
870
+ )
871
+ out_dict[k] = v
872
+ return out_dict
873
+
874
+
875
+ def checkpoint_filter_fn(
876
+ state_dict,
877
+ model,
878
+ adapt_layer_scale=False,
879
+ interpolation='bicubic',
880
+ antialias=True,
881
+ ):
882
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
883
+ import re
884
+ out_dict = {}
885
+ if 'model' in state_dict:
886
+ # For deit models
887
+ state_dict = state_dict['model']
888
+
889
+ if 'visual.class_embedding' in state_dict:
890
+ return _convert_openai_clip(state_dict, model)
891
+
892
+ for k, v in state_dict.items():
893
+ if 'patch_embed.proj.weight' in k:
894
+ O, I, H, W = model.patch_embed.proj.weight.shape
895
+ if len(v.shape) < 4:
896
+ # For old models that I trained prior to conv based patchification
897
+ O, I, H, W = model.patch_embed.proj.weight.shape
898
+ v = v.reshape(O, -1, H, W)
899
+ if v.shape[-1] != W or v.shape[-2] != H:
900
+ v = resample_patch_embed(
901
+ v,
902
+ (H, W),
903
+ interpolation=interpolation,
904
+ antialias=antialias,
905
+ verbose=True,
906
+ )
907
+ elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
908
+ # To resize pos embedding when using model at different size from pretrained weights
909
+ num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
910
+ v = resample_abs_pos_embed(
911
+ v,
912
+ new_size=model.patch_embed.grid_size,
913
+ num_prefix_tokens=num_prefix_tokens,
914
+ interpolation=interpolation,
915
+ antialias=antialias,
916
+ verbose=True,
917
+ )
918
+ elif adapt_layer_scale and 'gamma_' in k:
919
+ # remap layer-scale gamma into sub-module (deit3 models)
920
+ k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
921
+ elif 'pre_logits' in k:
922
+ # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
923
+ continue
924
+ out_dict[k] = v
925
+ return out_dict
926
+
927
+
928
+ def _cfg(url='', **kwargs):
929
+ return {
930
+ 'url': url,
931
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
932
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
933
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
934
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
935
+ **kwargs
936
+ }
937
+
938
+
939
+ default_cfgs = generate_default_cfgs({
940
+
941
+ # re-finetuned augreg 21k FT on in1k weights
942
+ 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
943
+ hf_hub_id='timm/'),
944
+ 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(),
945
+ 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
946
+ hf_hub_id='timm/'),
947
+
948
+ # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
949
+ 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
950
+ url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
951
+ hf_hub_id='timm/',
952
+ custom_load=True),
953
+ 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
954
+ url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
955
+ hf_hub_id='timm/',
956
+ custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
957
+ 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
958
+ url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
959
+ hf_hub_id='timm/',
960
+ custom_load=True),
961
+ 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
962
+ url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
963
+ hf_hub_id='timm/',
964
+ custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
965
+ 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
966
+ url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
967
+ hf_hub_id='timm/',
968
+ custom_load=True),
969
+ 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
970
+ url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
971
+ hf_hub_id='timm/',
972
+ custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
973
+ 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
974
+ url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
975
+ hf_hub_id='timm/',
976
+ custom_load=True),
977
+ 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
978
+ url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
979
+ hf_hub_id='timm/',
980
+ custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
981
+ 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
982
+ url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
983
+ hf_hub_id='timm/',
984
+ custom_load=True),
985
+ 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
986
+ url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
987
+ hf_hub_id='timm/',
988
+ custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
989
+ 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
990
+ url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
991
+ hf_hub_id='timm/',
992
+ custom_load=True),
993
+ 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
994
+ url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
995
+ hf_hub_id='timm/',
996
+ custom_load=True),
997
+ 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
998
+ url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
999
+ hf_hub_id='timm/',
1000
+ custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
1001
+
1002
+ # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
1003
+ 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
1004
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
1005
+ hf_hub_id='timm/'),
1006
+ 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
1007
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
1008
+ hf_hub_id='timm/',
1009
+ input_size=(3, 384, 384), crop_pct=1.0),
1010
+ 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
1011
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
1012
+ hf_hub_id='timm/',
1013
+ input_size=(3, 384, 384), crop_pct=1.0),
1014
+
1015
+ # How to train your ViT (augreg) weights trained on in1k only
1016
+ 'vit_small_patch16_224.augreg_in1k': _cfg(
1017
+ url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
1018
+ hf_hub_id='timm/',
1019
+ custom_load=True),
1020
+ 'vit_small_patch16_384.augreg_in1k': _cfg(
1021
+ url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
1022
+ hf_hub_id='timm/',
1023
+ custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
1024
+ 'vit_base_patch32_224.augreg_in1k': _cfg(
1025
+ url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
1026
+ hf_hub_id='timm/',
1027
+ custom_load=True),
1028
+ 'vit_base_patch32_384.augreg_in1k': _cfg(
1029
+ url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
1030
+ hf_hub_id='timm/',
1031
+ custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
1032
+ 'vit_base_patch16_224.augreg_in1k': _cfg(
1033
+ url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
1034
+ hf_hub_id='timm/',
1035
+ custom_load=True),
1036
+ 'vit_base_patch16_384.augreg_in1k': _cfg(
1037
+ url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
1038
+ hf_hub_id='timm/',
1039
+ custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
1040
+
1041
+ 'vit_large_patch14_224.untrained': _cfg(url=''),
1042
+ 'vit_huge_patch14_224.untrained': _cfg(url=''),
1043
+ 'vit_giant_patch14_224.untrained': _cfg(url=''),
1044
+ 'vit_gigantic_patch14_224.untrained': _cfg(url=''),
1045
+
1046
+ # patch models, imagenet21k (weights from official Google JAX impl)
1047
+ 'vit_large_patch32_224.orig_in21k': _cfg(
1048
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
1049
+ hf_hub_id='timm/',
1050
+ num_classes=21843),
1051
+ 'vit_huge_patch14_224.orig_in21k': _cfg(
1052
+ url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
1053
+ hf_hub_id='timm/',
1054
+ custom_load=True, num_classes=21843),
1055
+
1056
+ # How to train your ViT (augreg) weights, pretrained on in21k
1057
+ 'vit_tiny_patch16_224.augreg_in21k': _cfg(
1058
+ url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
1059
+ hf_hub_id='timm/',
1060
+ custom_load=True, num_classes=21843),
1061
+ 'vit_small_patch32_224.augreg_in21k': _cfg(
1062
+ url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
1063
+ hf_hub_id='timm/',
1064
+ custom_load=True, num_classes=21843),
1065
+ 'vit_small_patch16_224.augreg_in21k': _cfg(
1066
+ url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
1067
+ hf_hub_id='timm/',
1068
+ custom_load=True, num_classes=21843),
1069
+ 'vit_base_patch32_224.augreg_in21k': _cfg(
1070
+ url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
1071
+ hf_hub_id='timm/',
1072
+ custom_load=True, num_classes=21843),
1073
+ 'vit_base_patch16_224.augreg_in21k': _cfg(
1074
+ url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
1075
+ hf_hub_id='timm/',
1076
+ custom_load=True, num_classes=21843),
1077
+ 'vit_base_patch8_224.augreg_in21k': _cfg(
1078
+ url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
1079
+ hf_hub_id='timm/',
1080
+ custom_load=True, num_classes=21843),
1081
+ 'vit_large_patch16_224.augreg_in21k': _cfg(
1082
+ url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
1083
+ hf_hub_id='timm/',
1084
+ custom_load=True, num_classes=21843),
1085
+
1086
+ # SAM trained models (https://arxiv.org/abs/2106.01548)
1087
+ 'vit_base_patch32_224.sam': _cfg(
1088
+ url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True,
1089
+ hf_hub_id='timm/'),
1090
+ 'vit_base_patch16_224.sam': _cfg(
1091
+ url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True,
1092
+ hf_hub_id='timm/'),
1093
+
1094
+ # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
1095
+ 'vit_small_patch16_224.dino': _cfg(
1096
+ url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
1097
+ hf_hub_id='timm/',
1098
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
1099
+ 'vit_small_patch8_224.dino': _cfg(
1100
+ url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
1101
+ hf_hub_id='timm/',
1102
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
1103
+ 'vit_base_patch16_224.dino': _cfg(
1104
+ url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
1105
+ hf_hub_id='timm/',
1106
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
1107
+ 'vit_base_patch8_224.dino': _cfg(
1108
+ url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
1109
+ hf_hub_id='timm/',
1110
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
1111
+
1112
+ # ViT ImageNet-21K-P pretraining by MILL
1113
+ 'vit_base_patch16_224_miil.in21k': _cfg(
1114
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
1115
+ hf_hub_id='timm/',
1116
+ mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
1117
+ 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
1118
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
1119
+ hf_hub_id='timm/',
1120
+ mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
1121
+
1122
+ # Custom timm variants
1123
+ 'vit_base_patch16_rpn_224.in1k': _cfg(
1124
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth',
1125
+ hf_hub_id='timm/'),
1126
+ 'vit_medium_patch16_gap_240.in12k': _cfg(
1127
+ hf_hub_id='timm/',
1128
+ input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
1129
+ 'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg(
1130
+ hf_hub_id='timm/',
1131
+ input_size=(3, 256, 256), crop_pct=0.95),
1132
+ 'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg(
1133
+ hf_hub_id='timm/',
1134
+ input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
1135
+ 'vit_base_patch16_gap_224': _cfg(),
1136
+
1137
+ # CLIP pretrained image tower and related fine-tuned weights
1138
+ 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg(
1139
+ hf_hub_id='timm/',
1140
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
1141
+ 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg(
1142
+ hf_hub_id='timm/',
1143
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
1144
+ 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg(
1145
+ hf_hub_id='timm/',
1146
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
1147
+ 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
1148
+ hf_hub_id='timm/',
1149
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
1150
+ 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
1151
+ hf_hub_id='timm/',
1152
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
1153
+ crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
1154
+ 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
1155
+ hf_hub_id='timm/',
1156
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
1157
+ 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
1158
+ hf_hub_id='timm/',
1159
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
1160
+ crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
1161
+ 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
1162
+ hf_hub_id='timm/',
1163
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
1164
+ 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
1165
+ hf_hub_id='timm/',
1166
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
1167
+ crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
1168
+
1169
+ 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
1170
+ # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k',
1171
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
1172
+ 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
1173
+ hf_hub_id='timm/',
1174
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
1175
+ crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
1176
+ 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
1177
+ hf_hub_id='timm/',
1178
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
1179
+ 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
1180
+ hf_hub_id='timm/',
1181
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
1182
+ crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
1183
+ 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
1184
+ hf_hub_id='timm/',
1185
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
1186
+ 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
1187
+ hf_hub_id='timm/',
1188
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
1189
+ crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
1190
+
1191
+ 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
1192
+ hf_hub_id='timm/',
1193
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
1194
+ 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
1195
+ hf_hub_id='timm/',
1196
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
1197
+ 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
1198
+ hf_hub_id='timm/',
1199
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
1200
+ crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
1201
+ 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
1202
+ hf_hub_id='timm/',
1203
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
1204
+ 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
1205
+ hf_hub_id='timm/',
1206
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
1207
+ crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
1208
+ 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
1209
+ hf_hub_id='timm/',
1210
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
1211
+ 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
1212
+ hf_hub_id='',
1213
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
1214
+ crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
1215
+
+ 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
+ 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
+ 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+ crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
+ 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
+
+ 'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg(
+ #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
+ 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
+ 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
+ hf_hub_id='timm/',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
+ 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
+
+ 'vit_base_patch32_clip_224.openai_ft_in12k': _cfg(
+ # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
+ 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
+ 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
+
+ 'vit_base_patch32_clip_224.laion2b': _cfg(
+ hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
+ hf_hub_filename='open_clip_pytorch_model.bin',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+ 'vit_base_patch16_clip_224.laion2b': _cfg(
+ # hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
+ hf_hub_filename='open_clip_pytorch_model.bin',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+ 'vit_large_patch14_clip_224.laion2b': _cfg(
+ hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
+ hf_hub_filename='open_clip_pytorch_model.bin',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
+ 'vit_huge_patch14_clip_224.laion2b': _cfg(
+ hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
+ hf_hub_filename='open_clip_pytorch_model.bin',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+ 'vit_giant_patch14_clip_224.laion2b': _cfg(
+ hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
+ hf_hub_filename='open_clip_pytorch_model.bin',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+
+ 'vit_base_patch32_clip_224.openai': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+ 'vit_base_patch16_clip_224.openai': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+ 'vit_large_patch14_clip_224.openai': _cfg(
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+
+ # experimental (may be removed)
+ 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
+ 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
+ 'vit_small_patch16_36x1_224': _cfg(url=''),
+ 'vit_small_patch16_18x2_224': _cfg(url=''),
+ 'vit_base_patch16_18x2_224': _cfg(url=''),
+
+ # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain
+ # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip
+ 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg(
+ # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+ input_size=(3, 196, 196), crop_pct=1.0),
+ 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg(
+ # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+ input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
+ 'eva_large_patch14_196.in22k_ft_in1k': _cfg(
+ # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+ input_size=(3, 196, 196), crop_pct=1.0),
+ 'eva_large_patch14_336.in22k_ft_in1k': _cfg(
+ # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
+ hf_hub_id='timm/',
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+ input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
+
+ 'flexivit_small.1200ep_in1k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95),
+ 'flexivit_small.600ep_in1k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95),
+ 'flexivit_small.300ep_in1k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95),
+
+ 'flexivit_base.1200ep_in1k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95),
+ 'flexivit_base.600ep_in1k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95),
+ 'flexivit_base.300ep_in1k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95),
+ 'flexivit_base.1000ep_in21k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
+ 'flexivit_base.300ep_in21k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
+
+ 'flexivit_large.1200ep_in1k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95),
+ 'flexivit_large.600ep_in1k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95),
+ 'flexivit_large.300ep_in1k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95),
+
+ 'flexivit_base.patch16_in21k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
+ 'flexivit_base.patch30_in21k': _cfg(
+ url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
+ hf_hub_id='timm/',
+ input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
+ })
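+
+ # The fields in the pretrained cfgs above (mean/std, input_size, crop_pct, crop_mode) describe the
+ # eval-time preprocessing each weight set expects. A minimal sketch of resolving them into a
+ # transform with timm's data helpers (hedged: assumes resolve_data_config/create_transform from
+ # timm.data are available in this timm version; kept as a comment so module behavior is unchanged):
+ #
+ #   import timm
+ #   from timm.data import resolve_data_config, create_transform
+ #   model = timm.create_model('vit_base_patch16_clip_224.openai_ft_in1k', pretrained=True)
+ #   config = resolve_data_config({}, model=model)  # picks up OPENAI_CLIP mean/std, crop_pct, etc.
+ #   transform = create_transform(**config)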
+
+
+ def _create_vision_transformer(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ if 'flexi' in variant:
+ # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
+ # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
+ _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
+ else:
+ _filter_fn = checkpoint_filter_fn
+
+ return build_model_with_cfg(
+ VisionTransformer, variant, pretrained,
+ pretrained_filter_fn=_filter_fn,
+ **kwargs,
+ )
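+
+ # The registered variants below are normally instantiated via timm.create_model, which routes any
+ # extra kwargs through this helper into the VisionTransformer constructor. A hedged usage sketch
+ # (model name taken from the registrations below):
+ #
+ #   import timm
+ #   model = timm.create_model('vit_tiny_patch16_224', pretrained=False, num_classes=10)
+ #
+ # Note that 'flexi' variants resize patch/position embeddings with the bilinear, non-antialiased
+ # filter selected above when loading pretrained checkpoints.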
+
+
+ @register_model
+ def vit_tiny_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Tiny (Vit-Ti/16)
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
+ model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_tiny_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Tiny (Vit-Ti/16) @ 384x384.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
+ model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_small_patch32_224(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/32)
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
+ model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_small_patch32_384(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/32) at 384x384.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
+ model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_small_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/16)
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
+ model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_small_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/16)
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
+ model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_small_patch8_224(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/8)
+ """
+ model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6)
+ model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch32_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
+ model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch32_384(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
+ model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch8_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12)
+ model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_large_patch32_224(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
+ model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_large_patch32_384(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
+ model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_large_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
+ model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_large_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
+ model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_large_patch14_224(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/14)
+ """
+ model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16)
+ model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_huge_patch14_224(pretrained=False, **kwargs):
+ """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
+ """
+ model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16)
+ model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_giant_patch14_224(pretrained=False, **kwargs):
+ """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
+ """
+ model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16)
+ model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_gigantic_patch14_224(pretrained=False, **kwargs):
+ """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
+ """
+ model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
+ model = _create_vision_transformer(
+ 'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False)
+ model = _create_vision_transformer(
+ 'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
+ """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
+ global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
+ model = _create_vision_transformer(
+ 'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
+ """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
+ global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
+ model = _create_vision_transformer(
+ 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
+ """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
+ global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
+ model = _create_vision_transformer(
+ 'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch16_gap_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+ model = _create_vision_transformer(
+ 'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch32_clip_224(pretrained=False, **kwargs):
+ """ ViT-B/32 CLIP image tower @ 224x224
+ """
+ model_kwargs = dict(
+ patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
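+
+ # Note: the raw CLIP tags in default_cfgs (e.g. '.laion2b', '.openai') configure num_classes as the
+ # CLIP image-projection width (512/768/1024) rather than ImageNet classes, while the '_ft_in1k'
+ # tags keep a 1000-class head. A hedged selection sketch, assuming the 'model_name.tag' syntax of
+ # recent timm releases:
+ #
+ #   import timm
+ #   tower = timm.create_model('vit_base_patch32_clip_224.laion2b', pretrained=True)       # 512-d output
+ #   clf = timm.create_model('vit_base_patch32_clip_224.openai_ft_in1k', pretrained=True)  # 1000 classes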
+
+
+ @register_model
+ def vit_base_patch32_clip_384(pretrained=False, **kwargs):
+ """ ViT-B/32 CLIP image tower @ 384x384
+ """
+ model_kwargs = dict(
+ patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch32_clip_448(pretrained=False, **kwargs):
+ """ ViT-B/32 CLIP image tower @ 448x448
+ """
+ model_kwargs = dict(
+ patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch16_clip_224(pretrained=False, **kwargs):
+ """ ViT-B/16 CLIP image tower
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch16_clip_384(pretrained=False, **kwargs):
+ """ ViT-B/16 CLIP image tower @ 384x384
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_large_patch14_clip_224(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/14) CLIP image tower
+ """
+ model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_large_patch14_clip_336(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336
+ """
+ model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
+ """ ViT-Huge model (ViT-H/14) CLIP image tower.
+ """
+ model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_huge_patch14_clip_336(pretrained=False, **kwargs):
+ """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336
+ """
+ model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
+ """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
+ Pretrained weights from CLIP image tower.
+ """
+ model_kwargs = dict(
+ patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+ model = _create_vision_transformer(
+ 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ # Experimental models below
+
+ @register_model
+ def vit_base_patch32_plus_256(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/32+)
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
+ model = _create_vision_transformer(
+ 'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch16_plus_240(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16+)
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
+ model = _create_vision_transformer(
+ 'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch16_rpn_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) w/ residual post-norm
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
+ class_token=False, block_fn=ResPostBlock, global_pool='avg')
+ model = _create_vision_transformer(
+ 'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
+ """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
+ Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+ Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5)
+ model = _create_vision_transformer(
+ 'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
+ """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
+ Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+ Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
+ model = _create_vision_transformer(
+ 'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
+ """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
+ Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
+ model = _create_vision_transformer(
+ 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def eva_large_patch14_196(pretrained=False, **kwargs):
+ """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain"""
+ model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
+ model = _create_vision_transformer(
+ 'eva_large_patch14_196', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def eva_large_patch14_336(pretrained=False, **kwargs):
+ """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
+ model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
+ model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def flexivit_small(pretrained=False, **kwargs):
+ """ FlexiViT-Small
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True)
+ model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def flexivit_base(pretrained=False, **kwargs):
+ """ FlexiViT-Base
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
+ model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
+
+
+ @register_model
+ def flexivit_large(pretrained=False, **kwargs):
+ """ FlexiViT-Large
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
+ model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+ return model
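+
+ # The FlexiViT registrations above use no_embed_class=True and 240x240 pretrained cfgs intended to
+ # tolerate patch-size changes. A hedged loading sketch using one of the tagged weights listed in
+ # default_cfgs:
+ #
+ #   import timm
+ #   model = timm.create_model('flexivit_base.1200ep_in1k', pretrained=True)
+ #   # patch/pos-embed resizing for these checkpoints goes through the bilinear, non-antialiased
+ #   # filter chosen in _create_vision_transformer.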