""" Vision Transformer (ViT) in PyTorch A PyTorch implement of Vision Transformers as described in: 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - https://arxiv.org/abs/2106.10270 `FlexiViT: One Model for All Patch Sizes` - https://arxiv.org/abs/2212.08013 The official jax code is released and available at * https://github.com/google-research/vision_transformer * https://github.com/google-research/big_vision Acknowledgments: * The paper authors for releasing code and weights, thanks! * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT * Bert reference code checks against Huggingface Transformers and Tensorflow Bert Hacked together by / Copyright 2020, Ross Wightman """ import logging import math from collections import OrderedDict from functools import partial from typing import Optional, List import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ resample_abs_pos_embed from timm.models._builder import build_model_with_cfg from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv from timm.models._pretrained import generate_default_cfgs from timm.models._registry import register_model import math from functools import partial from typing import Optional, Tuple import argparse import json import logging import os import numpy as np import torch import torch.utils.checkpoint import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from torch.jit import Final from quantization.utils import BaseEnumOptions from transformers_language.models.softmax import clipped_softmax, clipped_softmax1 __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) # import torch.nn.Function as F # Set to True if exporting a model with Same padding via ONNX _EXPORTABLE = False # Set to True if wanting to use torch.jit.script on a model _SCRIPTABLE = False # use torch.scaled_dot_product_attention where possible _HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention') if 'TIMM_FUSED_ATTN' in os.environ: _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN']) else: _USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use) def logit(p, eps=1e-16): p = np.clip(p, eps, 1 - eps) return -np.log(1 / p - 1) class AttentionGateType(BaseEnumOptions): none = 0 unconditional_per_head = 1 conditional_per_head = 2 conditional_per_token = 3 def use_fused_attn(experimental: bool = False) -> bool: # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0 if not _HAS_FUSED_ATTN or _EXPORTABLE: return False if experimental: return _USE_FUSED_ATTN > 1 return _USE_FUSED_ATTN > 0 def scaled_dot_product_attention(query, key, value, softmax_fn, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: # Efficient implementation equivalent to the following: device = "cuda" if torch.cuda.is_available() else "cpu" L, S = query.size(-2), 
key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # fill the bias (not the caller's mask) at masked positions
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = softmax_fn(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            softmax_fn=torch.nn.functional.softmax,
            gamma=None,
            ssm_eps=None,
            tau=None,
            skip_attn=False,
            attn_gate_type=AttentionGateType.none,
            attn_gate_init=None,
            attn_gate_mlp=False,
            attn_gate_mlp2=False,
            attn_gate_linear_all_features=False,
            fine_tuning=False,
            max_seq_length=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_attention_heads = num_heads
        self.attention_head_size = dim // num_heads
        self.all_head_size = dim  # full embedding dim, used by the all-features gate predictor
        self.scale = self.attention_head_size ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.attention_head_size) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.attention_head_size) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_scores = nn.Identity()  # before attention mask
        self.attn_probs_before_dropout = nn.Identity()
        self.attn_probs_after_dropout = nn.Identity()

        self.gamma = gamma
        self.ssm_eps = ssm_eps
        self.tau = tau
        self.max_seq_length = max_seq_length

        # define softmax function
        self.softmax_fn = softmax_fn

        self.skip_attn = skip_attn

        # attention gating
        self.last_gate_avg_prob = None
        self.last_gate_all_probs = None
        self.attn_gate_type = attn_gate_type
        self.attn_gate_init = attn_gate_init
        self.attn_gate_mlp = attn_gate_mlp
        self.attn_gate_mlp2 = attn_gate_mlp2
        self.attn_gate_linear_all_features = attn_gate_linear_all_features

        self.alpha = None
        self.gate_fn = torch.sigmoid
        self.pooling_fn = partial(torch.mean, dim=1, keepdim=True)

        self.fine_tuning = fine_tuning

        # gate scaling factor
        self.gate_scaling_factor = 1.0
        if self.fine_tuning and self.attn_gate_init is not None:
            self.gate_scaling_factor = 1.0 / self.attn_gate_init

        # define gate
        if self.attn_gate_type == AttentionGateType.unconditional_per_head:
            init_alpha = torch.zeros(size=(self.num_attention_heads,))
            self.alpha = nn.Parameter(init_alpha, requires_grad=True)
        elif self.attn_gate_type in (
            AttentionGateType.conditional_per_head,
            AttentionGateType.conditional_per_token,
        ):
            if self.attn_gate_linear_all_features:
                self.alpha = nn.Linear(self.all_head_size, self.num_attention_heads, bias=True)
            else:  # separate predictors for each head
                module_list = []
                for _ in range(self.num_attention_heads):
                    if self.attn_gate_mlp:
                        fc = nn.Sequential(
                            nn.Linear(
                                self.attention_head_size, self.attention_head_size // 4, bias=True
                            ),
                            nn.ReLU(),
                            nn.Linear(self.attention_head_size // 4, 1, bias=True),
                        )
                    elif self.attn_gate_mlp2:
                        fc = nn.Sequential(
                            nn.Linear(
                                self.attention_head_size,
self.attention_head_size, bias=True ), nn.ReLU(), nn.Linear(self.attention_head_size, 1, bias=True), ) else: fc = nn.Linear(self.attention_head_size, 1, bias=True) if self.attn_gate_init is not None: init_bias = logit(self.attn_gate_init) torch.nn.init.constant_(fc.bias, init_bias) if self.fine_tuning: # init to a very small values torch.nn.init.normal_(fc.weight, mean=0.0, std=0.01) module_list.append(fc) self.alpha = nn.ModuleList(module_list) def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states = x B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_attention_heads, self.attention_head_size).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) q, k = self.q_norm(q), self.k_norm(k) if self.fused_attn: context_layer = scaled_dot_product_attention( q, k, v, self.softmax_fn, dropout_p=self.attn_drop.p if self.training else 0., ) else: q = q * self.scale attn = q @ k.transpose(-2, -1) attn = self.softmax_fn(attn, dim=-1) attn = self.attn_probs_before_dropout(attn) attn = self.attn_drop(attn) attn = self.attn_probs_after_dropout(attn) context_layer = attn @ v # *** Gating *** if self.attn_gate_type == AttentionGateType.unconditional_per_head: gate = self.gate_fn(self.alpha) # (H,) context_layer *= gate.view(-1, 1, 1) # (B, H, T, d_head) self.last_gate_avg_prob = gate.view(-1) elif self.attn_gate_type in ( AttentionGateType.conditional_per_head, AttentionGateType.conditional_per_token, ): x = hidden_states if self.attn_gate_linear_all_features: # assume per_token alpha = self.alpha(x) # (B, T, H) gate = self.gate_fn(alpha) gate = gate.permute(0, 2, 1).contiguous() # (B, H, T) gate = gate.unsqueeze(3) # (B, H, T, 1) else: x = self.transpose_for_scores(x) # (B, H, T, d_head) alpha = [] for head_idx in range(self.num_attention_heads): x_head = x[:, head_idx, ...] # (B, T, d_head) fc_head = self.alpha[head_idx] alpha_head = fc_head(x_head) # (B, T, 1) if self.attn_gate_type == AttentionGateType.conditional_per_head: alpha_head = self.pooling_fn(alpha_head) # (B, 1, 1) alpha.append(alpha_head) alpha = torch.stack(alpha, dim=1) # (B, H, *, 1) gate = self.gate_fn(alpha) context_layer *= gate * self.gate_scaling_factor self.last_gate_all_probs = gate # all gates to see the distributions avg_gate = gate.mean(dim=0) self.last_gate_avg_prob = avg_gate.view(self.num_attention_heads, -1).mean(dim=1) x = context_layer.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class LayerScale(nn.Module): def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): return x.mul_(self.gamma) if self.inplace else x * self.gamma class Block(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = norm_layer(dim) self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x class ResPostBlock(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm ): super().__init__() self.init_values = init_values self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.init_weights() def init_weights(self): # NOTE this init overrides that base model init with specific changes for the block type if self.init_values is not None: nn.init.constant_(self.norm1.weight, self.init_values) nn.init.constant_(self.norm2.weight, self.init_values) def forward(self, x): x = x + self.drop_path1(self.norm1(self.attn(x))) x = x + self.drop_path2(self.norm2(self.mlp(x))) return x class ParallelBlock(nn.Module): def __init__( self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm ): super().__init__() self.num_parallel = num_parallel self.attns = nn.ModuleList() self.ffns = nn.ModuleList() for _ in range(num_parallel): self.attns.append(nn.Sequential(OrderedDict([ ('norm', norm_layer(dim)), ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) ]))) self.ffns.append(nn.Sequential(OrderedDict([ ('norm', norm_layer(dim)), ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity()) ]))) def _forward_jit(self, x): x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) return x @torch.jit.ignore def _forward(self, x): x = x + sum(attn(x) for attn in self.attns) x = x + sum(ffn(x) for ffn in self.ffns) return x def forward(self, x): if torch.jit.is_scripting() or torch.jit.is_tracing(): return self._forward_jit(x) else: return self._forward(x) class VisionTransformer(nn.Module): """ Vision Transformer A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, class_token=True, no_embed_class=False, pre_norm=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, ): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels num_classes (int): number of classes for classification head global_pool (str): type of global pooling for final sequence (default: 'token') embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True init_values: (float): layer-scale init values class_token (bool): use class token fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate weight_init (str): weight init scheme embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer act_layer: (nn.Module): MLP activation layer """ super().__init__() assert global_pool in ('', 'avg', 'token') assert class_token or global_pool != 'token' use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 self.no_embed_class = no_embed_class self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer ) for i in range(depth)]) self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if weight_init != 'skip': self.init_weights(weight_init) def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m): # this fn left here for compat with downstream users init_weights_vit_timm(m) @torch.jit.ignore() def load_pretrained(self, checkpoint_path, prefix=''): _load_weights(self, checkpoint_path, prefix) @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token', 'dist_token'} @torch.jit.ignore def group_matcher(self, coarse=False): return dict( stem=r'^cls_token|pos_embed|patch_embed', # stem and embed blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] ) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self): return self.head def reset_classifier(self, num_classes: int, global_pool=None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token') self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def _pos_embed(self, x): if self.no_embed_class: # deit-3, updated JAX (big vision) # position embedding does not overlap with class token, add then concat x = x + self.pos_embed if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) else: # original timm, JAX, and deit vit impl # pos_embed has entry for class token, concat then add if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.pos_embed return self.pos_drop(x) def forward_features(self, x): x = self.patch_embed(x) x = self._pos_embed(x) x = self.norm_pre(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) x = self.norm(x) return x def forward_head(self, x, pre_logits: bool = False): if self.global_pool: x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) return x if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) return x def init_weights_vit_timm(module: nn.Module, name: str = ''): """ ViT weight initialization, original timm impl (for 
reproducibility) """ if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=.02) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights() def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): """ ViT weight initialization, matching JAX (Flax) impl """ if isinstance(module, nn.Linear): if name.startswith('head'): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights() def init_weights_vit_moco(module: nn.Module, name: str = ''): """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ if isinstance(module, nn.Linear): if 'qkv' in name: # treat the weights of Q, K, V separately val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) nn.init.uniform_(module.weight, -val, val) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, 'init_weights'): module.init_weights() def get_init_weights_vit(mode='jax', head_bias: float = 0.): if 'jax' in mode: return partial(init_weights_vit_jax, head_bias=head_bias) elif 'moco' in mode: return init_weights_vit_moco else: return init_weights_vit_timm def resize_pos_embed( posemb, posemb_new, num_prefix_tokens=1, gs_new=(), interpolation='bicubic', antialias=False, ): """ Rescale the grid of position embeddings when loading from state_dict. *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed Adapted from: https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 """ ntok_new = posemb_new.shape[1] if num_prefix_tokens: posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] ntok_new -= num_prefix_tokens else: posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 assert len(gs_new) >= 2 _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).') posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) return posemb @torch.no_grad() def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): """ Load weights from .npz checkpoints for official Google Brain Flax implementation """ import numpy as np def _n2p(w, t=True): if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: w = w.flatten() if t: if w.ndim == 4: w = w.transpose([3, 2, 0, 1]) elif w.ndim == 3: w = w.transpose([2, 0, 1]) elif w.ndim == 2: w = w.transpose([1, 0]) return torch.from_numpy(w) w = np.load(checkpoint_path) interpolation = 'bilinear' antialias = False big_vision = False if not prefix: if 'opt/target/embedding/kernel' in w: prefix = 'opt/target/' elif 'params/embedding/kernel' in w: prefix = 'params/' 
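            # 'params/...' keys indicate a big_vision (https://github.com/google-research/big_vision)
            # style .npz export; the pos_embed key and encoder block sub-module indices differ below.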
big_vision = True if hasattr(model.patch_embed, 'backbone'): # hybrid backbone = model.patch_embed.backbone stem_only = not hasattr(backbone, 'stem') stem = backbone if stem_only else backbone.stem stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) if not stem_only: for i, stage in enumerate(backbone.stages): for j, block in enumerate(stage.blocks): bp = f'{prefix}block{i + 1}/unit{j + 1}/' for r in range(3): getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) if block.downsample is not None: block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) else: embed_conv_w = adapt_input_conv( model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]: embed_conv_w = resample_patch_embed( embed_conv_w, model.patch_embed.proj.weight.shape[-2:], interpolation=interpolation, antialias=antialias, verbose=True, ) model.patch_embed.proj.weight.copy_(embed_conv_w) model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) if model.cls_token is not None: model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) if big_vision: pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) else: pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) if pos_embed_w.shape != model.pos_embed.shape: old_shape = pos_embed_w.shape num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights pos_embed_w, new_size=model.patch_embed.grid_size, num_prefix_tokens=num_prefix_tokens, interpolation=interpolation, antialias=antialias, verbose=True, ) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) block.attn.qkv.weight.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) block.attn.qkv.bias.copy_(torch.cat([ 
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) for r in range(2): getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) def _convert_openai_clip(state_dict, model): out_dict = {} swaps = [ ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'), ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'), ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'), ] for k, v in state_dict.items(): if not k.startswith('visual.'): continue for sp in swaps: k = k.replace(sp[0], sp[1]) if k == 'proj': k = 'head.weight' v = v.transpose(0, 1) out_dict['head.bias'] = torch.zeros(v.shape[0]) elif k == 'class_embedding': k = 'cls_token' v = v.unsqueeze(0).unsqueeze(1) elif k == 'pos_embed': v = v.unsqueeze(0) if v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( v, model.pos_embed, 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), model.patch_embed.grid_size ) out_dict[k] = v return out_dict def checkpoint_filter_fn( state_dict, model, adapt_layer_scale=False, interpolation='bicubic', antialias=True, ): """ convert patch embedding weight from manual patchify + linear proj to conv""" import re out_dict = {} if 'model' in state_dict: # For deit models state_dict = state_dict['model'] if 'visual.class_embedding' in state_dict: return _convert_openai_clip(state_dict, model) for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k: O, I, H, W = model.patch_embed.proj.weight.shape if len(v.shape) < 4: # For old models that I trained prior to conv based patchification O, I, H, W = model.patch_embed.proj.weight.shape v = v.reshape(O, -1, H, W) if v.shape[-1] != W or v.shape[-2] != H: v = resample_patch_embed( v, (H, W), interpolation=interpolation, antialias=antialias, verbose=True, ) elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) v = resample_abs_pos_embed( v, new_size=model.patch_embed.grid_size, num_prefix_tokens=num_prefix_tokens, interpolation=interpolation, antialias=antialias, verbose=True, ) elif adapt_layer_scale and 'gamma_' in k: # remap layer-scale gamma into sub-module (deit3 models) k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) elif 'pre_logits' in k: # NOTE representation layer removed as not used in latest 21k/1k pretrained weights continue out_dict[k] = v return out_dict def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } default_cfgs = generate_default_cfgs({ # re-finetuned augreg 21k FT on in1k 
weights 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( hf_hub_id='timm/'), 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(), 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg( hf_hub_id='timm/'), # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg( 
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', hf_hub_id='timm/'), 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', hf_hub_id='timm/', input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', hf_hub_id='timm/', input_size=(3, 384, 384), crop_pct=1.0), # How to train your ViT (augreg) weights trained on in1k only 'vit_small_patch16_224.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_small_patch16_384.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch32_224.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_base_patch32_384.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch16_224.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', hf_hub_id='timm/', custom_load=True), 'vit_base_patch16_384.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch14_224.untrained': _cfg(url=''), 'vit_huge_patch14_224.untrained': _cfg(url=''), 'vit_giant_patch14_224.untrained': _cfg(url=''), 'vit_gigantic_patch14_224.untrained': _cfg(url=''), # patch models, imagenet21k (weights from official Google JAX impl) 'vit_large_patch32_224.orig_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', hf_hub_id='timm/', num_classes=21843), 'vit_huge_patch14_224.orig_in21k': _cfg( url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', hf_hub_id='timm/', custom_load=True, num_classes=21843), # How to train your ViT (augreg) weights, pretrained on in21k 'vit_tiny_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_small_patch32_224.augreg_in21k': _cfg( 
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_small_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_base_patch32_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_base_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_base_patch8_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_large_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', hf_hub_id='timm/', custom_load=True, num_classes=21843), # SAM trained models (https://arxiv.org/abs/2106.01548) 'vit_base_patch32_224.sam': _cfg( url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True, hf_hub_id='timm/'), 'vit_base_patch16_224.sam': _cfg( url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True, hf_hub_id='timm/'), # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) 'vit_small_patch16_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 'vit_small_patch8_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 'vit_base_patch16_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 'vit_base_patch8_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil.in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', hf_hub_id='timm/', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', hf_hub_id='timm/', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), # Custom timm variants 'vit_base_patch16_rpn_224.in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth', hf_hub_id='timm/'), 'vit_medium_patch16_gap_240.in12k': _cfg( hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821), 
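    # NOTE (illustrative sketch): each '<arch>.<tag>' key in this dict selects pretrained weights
    # and preprocessing for the matching @register_model factory defined further below, e.g.:
    #   model = vit_medium_patch16_gap_240(pretrained=False)  # architecture only
    #   # pretrained=True resolves this tag's url / hf_hub_id plus its mean/std/input_size settings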
'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), 'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'), 'vit_base_patch16_gap_224': _cfg(), # CLIP pretrained image tower and related fine-tuned weights 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)), 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg( # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( 
hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( hf_hub_id='', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg( #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821), 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), 'vit_base_patch32_clip_224.openai_ft_in12k': _cfg( # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), 'vit_base_patch32_clip_224.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_base_patch16_clip_224.laion2b': _cfg( # hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), 'vit_large_patch14_clip_224.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', hf_hub_filename='open_clip_pytorch_model.bin', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768), 'vit_huge_patch14_clip_224.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), 'vit_giant_patch14_clip_224.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), 'vit_base_patch32_clip_224.openai': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_base_patch16_clip_224.openai': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_large_patch14_clip_224.openai': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), # experimental (may be removed) 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), 
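    # NOTE (illustrative): entries in this dict are `_cfg(...)` dicts; with no overrides an entry
    # expands to roughly:
    #   {'url': '', 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
    #    'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
    #    'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
    #    'first_conv': 'patch_embed.proj', 'classifier': 'head'}
    # and keyword arguments (input_size, crop_pct, num_classes, ...) override those defaults.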
'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 196, 196), crop_pct=1.0), 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), 'eva_large_patch14_196.in22k_ft_in1k': _cfg( # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 196, 196), crop_pct=1.0), 'eva_large_patch14_336.in22k_ft_in1k': _cfg( # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), 'flexivit_small.1200ep_in1k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95), 'flexivit_small.600ep_in1k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95), 'flexivit_small.300ep_in1k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95), 'flexivit_base.1200ep_in1k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95), 'flexivit_base.600ep_in1k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95), 'flexivit_base.300ep_in1k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95), 'flexivit_base.1000ep_in21k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), 'flexivit_base.300ep_in21k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), 'flexivit_large.1200ep_in1k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95), 'flexivit_large.600ep_in1k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95), 'flexivit_large.300ep_in1k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95), 'flexivit_base.patch16_in21k': _cfg( 
url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), 'flexivit_base.patch30_in21k': _cfg( url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True, hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), }) def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') if 'flexi' in variant: # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) else: _filter_fn = checkpoint_filter_fn return build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_filter_fn=_filter_fn, **kwargs, ) @register_model def vit_tiny_patch16_224(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16) """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_tiny_patch16_384(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16) @ 384x384. """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_small_patch32_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) """ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_small_patch32_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) at 384x384. """ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_small_patch16_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_small_patch16_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_small_patch8_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/8) """ model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_base_patch32_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. 
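
    Example (an illustrative sketch; shapes assume the default 224x224 input and a custom head):
        model = vit_base_patch32_224(pretrained=False, num_classes=10)
        logits = model(torch.randn(2, 3, 224, 224))  # -> shape (2, 10)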
""" model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_base_patch32_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_base_patch16_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_base_patch8_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_large_patch32_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. """ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_large_patch32_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_large_patch16_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
""" model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_large_patch16_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_large_patch14_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) """ model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_huge_patch14_224(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). """ model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16) model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_giant_patch14_224(pretrained=False, **kwargs): """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16) model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_gigantic_patch14_224(pretrained=False, **kwargs): """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16) model = _create_vision_transformer( 'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_base_patch16_224_miil(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
@register_model
def vit_base_patch16_224_miil(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False)
    model = _create_vision_transformer(
        'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
    model = _create_vision_transformer(
        'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
    model = _create_vision_transformer(
        'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
    model = _create_vision_transformer(
        'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_base_patch16_gap_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
    model = _create_vision_transformer(
        'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_base_patch32_clip_224(pretrained=False, **kwargs):
    """ ViT-B/32 CLIP image tower @ 224x224
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_base_patch32_clip_384(pretrained=False, **kwargs):
    """ ViT-B/32 CLIP image tower @ 384x384
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_base_patch32_clip_448(pretrained=False, **kwargs):
    """ ViT-B/32 CLIP image tower @ 448x448
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_base_patch16_clip_224(pretrained=False, **kwargs):
    """ ViT-B/16 CLIP image tower
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model
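# Illustrative note (not in the original file): the *_gap_* variants above drop the class token
# (class_token=False) and mean-pool the patch tokens (global_pool='avg') before the classifier head,
# so no [CLS] position is prepended to the sequence. A minimal sketch of the pooling they rely on,
# assuming a token tensor of shape (batch, num_patch_tokens, embed_dim):
def _example_global_avg_pool(tokens: torch.Tensor) -> torch.Tensor:
    # (B, N, C) -> (B, C): average over the patch-token dimension
    return tokens.mean(dim=1)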
@register_model
def vit_base_patch16_clip_384(pretrained=False, **kwargs):
    """ ViT-B/16 CLIP image tower @ 384x384
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_large_patch14_clip_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/14) CLIP image tower
    """
    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_large_patch14_clip_336(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336
    """
    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
    """ ViT-Huge model (ViT-H/14) CLIP image tower.
    """
    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def vit_huge_patch14_clip_336(pretrained=False, **kwargs):
    """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336
    """
    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
    model = _create_vision_transformer(
        'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model
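# Illustrative sketch (not in the original file): the *_clip_* towers above set pre_norm=True, i.e. a
# LayerNorm is applied to the patch embeddings before the transformer blocks, matching CLIP's ln_pre,
# and inputs are normally normalized with OPENAI_CLIP_MEAN / OPENAI_CLIP_STD rather than the ImageNet
# defaults. Assuming the usual timm behaviour that num_classes=0 replaces the head with Identity and
# returns pooled features, a feature-extraction call might look like:
def _example_clip_tower_features(images: torch.Tensor) -> torch.Tensor:
    # images: (B, 3, 224, 224), already normalized with OPENAI_CLIP_MEAN / OPENAI_CLIP_STD
    tower = vit_base_patch16_clip_224(pretrained=False, num_classes=0)
    return tower(images)  # (B, 768) pooled image embeddings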
""" model_kwargs = dict( patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) model = _create_vision_transformer( 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model # Experimental models below @register_model def vit_base_patch32_plus_256(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32+) """ model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) model = _create_vision_transformer( 'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_base_patch16_plus_240(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16+) """ model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) model = _create_vision_transformer( 'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_base_patch16_rpn_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ residual post-norm """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, block_fn=ResPostBlock, global_pool='avg') model = _create_vision_transformer( 'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_small_patch16_36x1_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5) model = _create_vision_transformer( 'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_small_patch16_18x2_224(pretrained=False, **kwargs): """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ model_kwargs = dict( patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock) model = _create_vision_transformer( 'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def vit_base_patch16_18x2_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. 
@register_model
def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
    """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
    model = _create_vision_transformer(
        'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def eva_large_patch14_196(pretrained=False, **kwargs):
    """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
    model = _create_vision_transformer(
        'eva_large_patch14_196', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def eva_large_patch14_336(pretrained=False, **kwargs):
    """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
    model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def flexivit_small(pretrained=False, **kwargs):
    """ FlexiViT-Small
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True)
    model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def flexivit_base(pretrained=False, **kwargs):
    """ FlexiViT-Base
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
    model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model


@register_model
def flexivit_large(pretrained=False, **kwargs):
    """ FlexiViT-Large
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
    model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
    return model
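# Illustrative sketch (not in the original file): FlexiViT (https://arxiv.org/abs/2212.08013) trains a
# single model to work across patch sizes by resampling the patch embedding kernel (the imported
# resample_patch_embed helper, used with bilinear / no antialias per the note in
# _create_vision_transformer). The flexivit_* entrypoints above default to patch_size=16, but because
# kwargs override model_kwargs, a different patch size can be requested at construction time. The
# override below is an assumed usage, shown without pretrained weight loading:
def _example_flexivit_patch30():
    # build the FlexiViT-Base architecture with a 30x30 patch on 240x240 inputs (an 8x8 token grid)
    model = flexivit_base(pretrained=False, patch_size=30, img_size=240)
    return model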