|
""" ViTamin |
|
|
|
Paper: Designing Scalable Vision Models in the Vision-Language Era
|
|
|
@misc{chen2023designing, |
|
title={Designing Scalable Vision Models in the Vision-Language Era},
|
author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Chieh Chen},
|
year={2023}, |
|
archivePrefix={arXiv}, |
|
primaryClass={cs.CV} |
|
} |
|
|
|
Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin |
|
|
|
Modifications and timm support by Jieneng Chen 2023 |
|
|
|
Adapted from timm codebase, thanks! |
|
""" |
|
|
|
import logging
import math
import time
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field, replace
from functools import partial
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final
from torch.utils.checkpoint import checkpoint

import timm
from timm.layers import to_2tuple, DropPath, Format, PatchDropout, AttentionPoolLatent
from timm.layers.norm_act import _create_act
from timm.models._builder import build_model_with_cfg
from timm.models._manipulate import named_apply, checkpoint_seq
from timm.models._registry import register_model
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
from timm.models.vision_transformer import get_act_layer, Type, LayerType, Mlp, Block, PatchEmbed, VisionTransformer, checkpoint_filter_fn, get_init_weights_vit, init_weights_vit_timm, _load_weights
|
|
|
|
|
|
|
@dataclass
class VitConvCfg:
    """Options for ViTamin's convolutional (MBConv) stages.

    ``__post_init__`` fills unset normalisation options with the MBConv defaults
    (``'batchnorm2d'``, eps ``1e-5``) and falls back to ``pool_type`` when
    ``downsample_pool_type`` is empty.
    """
    expand_ratio: float = 4.0
    expand_output: bool = True
    kernel_size: int = 3
    group_size: int = 1
    pre_norm_act: bool = False
    stride_mode: str = 'dw'
    pool_type: str = 'avg2'
    downsample_pool_type: str = 'avg2'
    act_layer: str = 'gelu'
    norm_layer: str = ''  # empty -> filled in by __post_init__
    norm_layer_cl: str = ''  # channels-last norm; only defaulted for non-MBConv blocks
    norm_eps: Optional[float] = None  # None -> filled in by __post_init__
    down_shortcut: Optional[bool] = True
    mlp: str = 'mlp'

    def __post_init__(self):
        # The conv stages are always MBConv here; the non-MBConv defaults are kept
        # only for parity and are never selected.
        is_mbconv = True
        if is_mbconv:
            default_norm, default_eps = 'batchnorm2d', 1e-5
        else:
            default_norm, default_eps = 'layernorm2d', 1e-6

        if not self.norm_layer:
            self.norm_layer = default_norm
        if not (self.norm_layer_cl or is_mbconv):
            self.norm_layer_cl = 'layernorm'
        if self.norm_eps is None:
            self.norm_eps = default_eps
        if not self.downsample_pool_type:
            self.downsample_pool_type = self.pool_type
|
|
|
@dataclass
class VitCfg:
    """Architecture config for the ViTamin conv stages.

    In this module only ``embed_dim``, ``depths`` and ``stem_width`` are read (by
    ``MbConvStages``); the remaining fields are carried along for config parity.
    """
    embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
    depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
    stem_width: int = 64
    # default_factory gives every VitCfg its own VitConvCfg. A shared instance default is
    # rejected by dataclasses on Python 3.11+ (unhashable default) and would alias mutations
    # across configs on older versions.
    conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
    weight_init: str = 'vit_eff'
    head_type: str = ""
    stem_type: str = "stem"
|
|
|
def _init_conv(module, name, scheme=''): |
|
if isinstance(module, nn.Conv2d): |
|
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels |
|
fan_out //= module.groups |
|
nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) |
|
if module.bias is not None: |
|
nn.init.zeros_(module.bias) |
|
|
|
class Stem(nn.Module):
    """Convolutional stem: 3x3 stride-2 conv -> norm + act -> 3x3 stride-1 conv.

    Downsamples by 2 (``conv1`` has stride 2) and maps ``in_chs`` to ``out_chs`` channels.
    All convs are re-initialised with ``_init_conv`` after construction. When
    ``grad_checkpointing`` is set, the two convs run under ``torch.utils.checkpoint``.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            act_layer: str = 'gelu',
            norm_layer: str = 'layernorm2d',
            norm_eps: float = 1e-6,
            bias: bool = True,
    ):
        super().__init__()
        self.grad_checkpointing = False
        norm_act_cls = get_norm_act_layer(norm_layer, act_layer)
        self.out_chs = out_chs
        self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
        self.norm1 = norm_act_cls(out_chs, eps=norm_eps)
        self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)
        named_apply(_init_conv, self)

    def forward(self, x):
        use_ckpt = self.grad_checkpointing
        x = checkpoint(self.conv1, x) if use_ckpt else self.conv1(x)
        x = self.norm1(x)
        x = checkpoint(self.conv2, x) if use_ckpt else self.conv2(x)
        return x
|
|
|
class Downsample2d(nn.Module):
    """Residual-path downsample: 3x3 stride-2 average pool, then an optional 1x1 projection.

    Padding is excluded from the average (``count_include_pad=False``). The 1x1 conv
    ``expand`` is only created when ``dim != dim_out``; otherwise it is an identity.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            bias: bool = True,
    ):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
        needs_projection = dim != dim_out
        self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) if needs_projection else nn.Identity()

    def forward(self, x):
        return self.expand(self.pool(x))
|
|
|
|
|
class StridedConv(nn.Module):
    """Pre-norm strided convolution (another 2d downsample).

    Normalises the input with ``layernorm2d`` (eps 1e-6) over ``in_chans`` channels, then
    applies a ``kernel_size`` conv with the given ``stride``/``padding`` mapping
    ``in_chans`` -> ``embed_dim`` channels.
    """

    def __init__(
            self,
            kernel_size=3,
            stride=2,
            padding=1,
            in_chans=3,
            embed_dim=768,
    ):
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        layernorm2d = get_norm_layer('layernorm2d')
        self.norm = layernorm2d(in_chans, eps=1e-6)

    def forward(self, x):
        return self.proj(self.norm(x))
|
|
|
|
|
class MbConvLNBlock(nn.Module):
    """Pre-norm MBConv block: norm -> 1x1 expand -> act -> depthwise kxk -> act -> 1x1 project.

    The residual branch (``shortcut``) is a ``Downsample2d`` when ``stride == 2``, a 1x1 conv
    when only the channel count changes, and an identity otherwise. ``drop_path`` applies
    stochastic depth to the main branch before the residual add. The hidden width is
    ``make_divisible(out_chs * expand_ratio)``.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            drop_path: float = 0.,
            kernel_size: int = 3,
            norm_layer: str = 'layernorm2d',
            norm_eps: float = 1e-6,
            act_layer: str = 'gelu',
            expand_ratio: float = 4.0,
    ):
        super().__init__()
        self.stride = stride
        self.in_chs = in_chs
        self.out_chs = out_chs
        hidden_chs = make_divisible(out_chs * expand_ratio)
        norm_act_cls = get_norm_act_layer(norm_layer, act_layer)

        if stride == 2:
            residual_path = Downsample2d(in_chs, out_chs, bias=True)
        elif in_chs != out_chs:
            residual_path = nn.Conv2d(in_chs, out_chs, 1, bias=True)
        else:
            residual_path = nn.Identity()
        self.shortcut = residual_path

        # Normalisation only: the activation of the norm-act layer is disabled here.
        self.pre_norm = norm_act_cls(in_chs, eps=norm_eps, apply_act=False)
        self.down = nn.Identity()
        self.conv1_1x1 = create_conv2d(in_chs, hidden_chs, 1, stride=1, bias=True)
        self.act1 = _create_act(act_layer, inplace=True)
        self.act2 = _create_act(act_layer, inplace=True)
        # Depthwise conv (groups == channels) carries the block's stride.
        self.conv2_kxk = create_conv2d(
            hidden_chs, hidden_chs, kernel_size,
            stride=stride, dilation=1, groups=hidden_chs, bias=True,
        )
        self.conv3_1x1 = create_conv2d(hidden_chs, out_chs, 1, bias=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme=''):
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x):
        residual = self.shortcut(x)
        x = self.down(self.pre_norm(x))
        x = self.act1(self.conv1_1x1(x))
        x = self.act2(self.conv2_kxk(x))
        x = self.conv3_1x1(x)
        return self.drop_path(x) + residual
|
|
|
|
|
class MbConvStages(nn.Module):
    """ViTamin stages 1 and 2 (MBConv-LN blocks) plus the downsample feeding stage 3.

    Pipeline: ``Stem`` -> one ``nn.Sequential`` of ``MbConvLNBlock`` per entry of
    ``cfg.embed_dim[:2]`` (the first block of each stage has stride 2) -> ``StridedConv``
    with stride 2 mapping ``cfg.embed_dim[1]`` to ``cfg.embed_dim[2]`` channels.

    ``img_size`` is accepted for interface compatibility but is not used.
    """

    def __init__(
            self,
            cfg: VitCfg,
            img_size: Union[int, Tuple[int, int]] = 224,
            in_chans: int = 3,
    ):
        super().__init__()
        self.grad_checkpointing = False
        self.stem = Stem(in_chs=in_chans, out_chs=cfg.stem_width)
        self.num_stages = len(cfg.embed_dim)

        stages = []
        for s, dim in enumerate(cfg.embed_dim[:2]):
            first_in_chs = cfg.stem_width if s == 0 else cfg.embed_dim[s - 1]
            stage_blocks = [
                MbConvLNBlock(
                    in_chs=first_in_chs if d == 0 else dim,
                    out_chs=dim,
                    stride=2 if d == 0 else 1,
                )
                for d in range(cfg.depths[s])
            ]
            stages.append(nn.Sequential(*stage_blocks))
        self.stages = nn.ModuleList(stages)

        self.pool = StridedConv(stride=2, in_chans=cfg.embed_dim[1], embed_dim=cfg.embed_dim[2])

    def forward(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for stage in self.stages:
                x = checkpoint_seq(stage, x)
            return checkpoint(self.pool, x)
        for stage in self.stages:
            x = stage(x)
        return self.pool(x)
|
|
|
class GeGluMlp(nn.Module):
    """Pre-norm gated MLP: ``w2(drop(act(w0(n)) * w1(n)))`` with ``n = norm(x)``.

    Used as the ``mlp_layer`` of the ViTamin transformer stage. The block carries its own
    pre-norm (eps 1e-6) in front of the gated projection.

    Args:
        in_features: Input (and output) feature dimension.
        hidden_features: Width of the gated hidden layer.
        act_layer: Gate activation (class or timm activation name); ``None`` selects
            ``nn.GELU``, i.e. GeGLU.
        drop: Dropout probability applied after gating and after the output projection.
        norm_layer: Pre-norm layer (class or timm norm name); ``None`` selects ``'layernorm'``.
        bias: Whether the three linear projections have a bias.
    """

    def __init__(
            self,
            in_features,
            hidden_features,
            act_layer=None,
            drop=0.0,
            norm_layer=None,
            bias=True,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer(norm_layer or 'layernorm'), eps=1e-6)
        # Resolve the activation the same way ViTamin does; previously this was ignored.
        act_layer = get_act_layer(act_layer) or nn.GELU

        self.norm = norm_layer(in_features)
        self.act = act_layer()
        self.w0 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w2 = nn.Linear(hidden_features, in_features, bias=bias)
        # Parameter-free, so state_dict keys are unchanged; an identity when drop == 0.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.norm(x)
        x = self.act(self.w0(x)) * self.w1(x)
        x = self.drop(x)
        x = self.w2(x)
        x = self.drop(x)
        return x
|
|
|
class HybridEmbed(nn.Module):
    """Use a conv backbone (ViTamin stages 1-2) as the ViT patch embedding.

    Runs ``backbone``, keeps the last output if it returns a list/tuple, and flattens the
    spatial dimensions of the feature map (assumed NCHW) into a ``(B, tokens, C)`` sequence.
    ``proj`` is an identity, so the backbone is expected to already emit the transformer's
    channel count.

    ``in_chans``, ``embed_dim``, ``bias`` and ``dynamic_img_pad`` are accepted for
    compatibility with the ViT ``embed_layer`` call signature but are not used.
    """

    def __init__(
            self,
            backbone,
            img_size=224,
            patch_size=1,
            feature_size=None,
            in_chans=3,
            embed_dim=1024,
            bias=True,
            dynamic_img_pad=False,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            # Assumes a total backbone stride of 16 (in MbConvStages the stem, the first block
            # of each MBConv stage and the final StridedConv each downsample by 2). Both
            # dimensions are used so non-square inputs get the correct grid.
            feature_size = (img_size[0] // 16, img_size[1] // 16)
        feature_size = to_2tuple(feature_size)
        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Identity()

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature map
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x
|
|
|
def _trunc_normal_(tensor, mean, std, a, b): |
|
|
|
def norm_cdf(x): |
|
|
|
return (1. + math.erf(x / math.sqrt(2.))) / 2. |
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std): |
|
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
|
"The distribution of values may be incorrect.", |
|
stacklevel=2) |
|
|
|
l = norm_cdf((a - mean) / std) |
|
u = norm_cdf((b - mean) / std) |
|
|
|
|
|
|
|
tensor.uniform_(2 * l - 1, 2 * u - 1) |
|
|
|
|
|
|
|
|
|
|
|
tensor.mul_(std * math.sqrt(2.)) |
|
tensor.add_(mean) |
|
|
|
|
|
tensor.clamp_(min=a, max=b) |
|
return tensor |
|
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Truncated-normal initialisation of ``tensor`` in place, with autograd disabled.

    A thin ``torch.no_grad()`` wrapper around ``_trunc_normal_``; returns ``tensor``.
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled
|
|
|
class ViTamin(nn.Module):
    """ViTamin model: timm's VisionTransformer layout with a hybrid conv patch embedding.

    For the registered ``vitamin_*`` variants, ``embed_layer`` is ``HybridEmbed`` wrapping
    ``MbConvStages`` (stages 1-2), and ``self.blocks`` is stage 3, a stack of ``block_fn``
    transformer blocks. The following mirror timm's VisionTransformer:
    - class/register token handling;
    - the optional position embedding;
    - pooling and the classifier head.
    """
    dynamic_img_size: Final[bool]

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool='token',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            init_values: Optional[float] = None,
            class_token: bool = True,
            no_embed_class: bool = False,
            reg_tokens: int = 0,
            pre_norm: bool = False,
            fc_norm: Optional[bool] = None,
            dynamic_img_size: bool = False,
            dynamic_img_pad: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init='',
            fix_init: bool = False,
            embed_layer: Callable = PatchEmbed,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
            block_fn: Type[nn.Module] = Block,
            mlp_layer: Type[nn.Module] = Mlp,
            is_pos_embed: bool = True
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            qk_norm: Enable normalization of q and k in the transformer blocks.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            pre_norm: Apply a norm before the transformer blocks (and drop the embed bias).
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            dynamic_img_size: Stored for API parity; not supported by this embed path.
            dynamic_img_pad: Accepted for API parity; not forwarded to the embed layer.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            patch_drop_rate: Patch (token) dropout rate.
            proj_drop_rate: Projection dropout rate inside the transformer blocks.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme ('skip' disables it).
            fix_init: Apply weight initialization fix (scaling w/ layer index).
            embed_layer: Patch embedding layer.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
            block_fn: Transformer block layer.
            mlp_layer: MLP layer used inside each transformer block.
            is_pos_embed: Create and add a learned position embedding.
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token', 'map')
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        act_layer = get_act_layer(act_layer) or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = no_embed_class
        # NOTE: kept for API parity with timm's VisionTransformer. The embed layer is always
        # called with the fixed ViT signature below and emits flattened tokens.
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False
        self.is_pos_embed = is_pos_embed

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None

        if self.is_pos_embed:
            embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
            self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        else:
            self.pos_embed = None

        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        if global_pool == 'map':
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None

        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)
        if fix_init:
            self.fix_init_weight()

    def init_weights(self, mode=''):
        """Initialise position embedding and prefix tokens, then apply timm's ViT init scheme."""
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        if self.reg_token is not None:
            nn.init.normal_(self.reg_token, std=1e-6)
        named_apply(get_init_weights_vit(mode, head_bias), self)

    def fix_init_weight(self):
        """Rescale each block's attention and MLP output projections by ``1/sqrt(2 * layer_id)``."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks, start=1):
            rescale(layer.attn.proj.weight.data, layer_id)
            # timm's Mlp names its output projection ``fc2``; GeGluMlp uses ``w2``.
            mlp_out = getattr(layer.mlp, 'fc2', None)
            if mlp_out is None:
                mlp_out = getattr(layer.mlp, 'w2', None)
            if mlp_out is not None:
                rescale(mlp_out.weight.data, layer_id)

    def _init_weights(self, m):
        # Per-module init hook, delegating to timm's default ViT init.
        init_weights_vit_timm(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay."""
        if self.is_pos_embed:
            return {'pos_embed', 'cls_token', 'dist_token'}
        else:
            return {'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing for the transformer blocks and, if present, the conv backbone."""
        self.grad_checkpointing = enable
        backbone = getattr(self.patch_embed, 'backbone', None)
        if backbone is not None:
            stem = getattr(backbone, 'stem', None)
            if stem is not None:
                stem.grad_checkpointing = enable
            backbone.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None):
        """Replace the classifier head and optionally the pooling mode.

        Attention pooling cannot be created here, so switching to ``'map'`` requires a model
        built with it. Switching to any other mode drops ``attn_pool``, so that
        ``forward_head`` actually uses the new mode.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token', 'map')
            assert self.has_class_token or global_pool != 'token'
            if global_pool == 'map':
                assert self.attn_pool is not None, 'Cannot add attention pooling in reset_classifier().'
            else:
                self.attn_pool = None
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _pos_embed(self, x):
        """Prepend class/register tokens and add the position embedding (when one exists)."""
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.pos_embed is None:
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        elif self.no_embed_class:
            # Position embedding covers only the patch tokens: add first, then prepend tokens.
            x = x + self.pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # Position embedding also covers the prefix tokens: prepend first, then add.
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        # Always run: prefix tokens are needed even when there is no position embedding.
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        if self.attn_pool is not None:
            x = self.attn_pool(x)
        elif self.global_pool == 'avg':
            x = x[:, self.num_prefix_tokens:].mean(dim=1)
        elif self.global_pool:
            x = x[:, 0]  # class token
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
|
|
|
def _create_vision_transformer(variant, pretrained=False, **kwargs): |
|
if kwargs.get('features_only', None): |
|
raise RuntimeError('features_only not implemented for Vision Transformer models.') |
|
|
|
return build_model_with_cfg( |
|
ViTamin, |
|
variant, |
|
pretrained, |
|
pretrained_filter_fn=checkpoint_filter_fn, |
|
**kwargs, |
|
) |
|
|
|
|
|
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
    """Build a ViTamin whose patch embedding is ``HybridEmbed`` around ``backbone``.

    ``patch_size`` defaults to 1 (each backbone output pixel becomes one token) unless the
    caller overrides it.
    """
    kwargs.setdefault('patch_size', 1)
    hybrid_embed = partial(HybridEmbed, backbone=backbone)
    return _create_vision_transformer(
        variant,
        pretrained=pretrained,
        embed_layer=hybrid_embed,
        **kwargs,
    )
|
|
|
|
|
@register_model
def vitamin_small(pretrained=False, **kwargs) -> ViTamin:
    """ViTamin-S: MBConv stages (64, 128) + 14-block, 384-dim GeGLU transformer (6 heads).

    Args:
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Overrides merged on top of the stage-3 ``ViTamin`` arguments.
    """
    stage_1_2 = MbConvStages(cfg=VitCfg(
        embed_dim=(64, 128, 384),
        depths=(2, 4, 1),
        stem_width=64,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    ))
    stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
    return model
|
|
|
|
|
@register_model
def vitamin_base(pretrained=False, **kwargs) -> ViTamin:
    """ViTamin-B: MBConv stages (128, 256) + 14-block, 768-dim GeGLU transformer (12 heads).

    Args:
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Overrides merged on top of the stage-3 ``ViTamin`` arguments.
    """
    stage_1_2 = MbConvStages(cfg=VitCfg(
        embed_dim=(128, 256, 768),
        depths=(2, 4, 1),
        stem_width=128,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    ))
    stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
    return model
|
|
|
|
|
@register_model
def vitamin_large(pretrained=False, **kwargs) -> ViTamin:
    """ViTamin-L: MBConv stages (160, 320) + 31-block, 1024-dim GeGLU transformer (16 heads).

    Args:
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Overrides merged on top of the stage-3 ``ViTamin`` arguments.
    """
    stage_1_2 = MbConvStages(cfg=VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    ))
    stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
    return model
|
|
|
|
|
@register_model
def vitamin_large_256(pretrained=False, **kwargs) -> ViTamin:
    """ViTamin-L at 256x256 input: MBConv stages (160, 320) + 31-block, 1024-dim transformer.

    Args:
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Overrides merged on top of the ``ViTamin`` arguments (including ``img_size``).
    """
    backbone = MbConvStages(cfg=VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    ))
    model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model
|
|
|
|
|
@register_model
def vitamin_large_336(pretrained=False, **kwargs) -> ViTamin:
    """ViTamin-L at 336x336 input: MBConv stages (160, 320) + 31-block, 1024-dim transformer.

    Args:
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Overrides merged on top of the ``ViTamin`` arguments (including ``img_size``).
    """
    backbone = MbConvStages(cfg=VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    ))
    model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model
|
|
|
|
|
@register_model
def vitamin_large_384(pretrained=False, **kwargs) -> ViTamin:
    """ViTamin-L at 384x384 input: MBConv stages (160, 320) + 31-block, 1024-dim transformer.

    Args:
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Overrides merged on top of the ``ViTamin`` arguments (including ``img_size``).
    """
    backbone = MbConvStages(cfg=VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    ))
    model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model
|
|
|
|
|
@register_model
def vitamin_xlarge_256(pretrained=False, **kwargs) -> ViTamin:
    """ViTamin-XL at 256x256 input: MBConv stages (192, 384) + 32-block, 1152-dim transformer.

    This variant is built without a learned position embedding (``is_pos_embed=False``).

    Args:
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Overrides merged on top of the ``ViTamin`` arguments (including ``img_size``).
    """
    backbone = MbConvStages(cfg=VitCfg(
        embed_dim=(192, 384, 1152),
        depths=(2, 4, 1),
        stem_width=192,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    ))
    model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, is_pos_embed=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model
|
|
|
|
|
@register_model
def vitamin_xlarge_384(pretrained=False, **kwargs) -> ViTamin:
    """ViTamin-XL at 384x384 input: MBConv stages (192, 384) + 32-block, 1152-dim transformer.

    This variant is built without a learned position embedding (``is_pos_embed=False``).

    Args:
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Overrides merged on top of the ``ViTamin`` arguments (including ``img_size``).
    """
    backbone = MbConvStages(cfg=VitCfg(
        embed_dim=(192, 384, 1152),
        depths=(2, 4, 1),
        stem_width=192,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    ))
    model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, is_pos_embed=False, global_pool='avg')
    model = _create_vision_transformer_hybrid(
        'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
    return model
|
|
|
|
|
def count_params(model: nn.Module):
    """Return the total number of scalar elements across all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
|
def count_stage_params(model: nn.Module, prefix='none'):
    """Count parameter elements whose qualified name starts with ``prefix``.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        prefix: Name prefix (or tuple of prefixes, as accepted by ``str.startswith``),
            e.g. ``'patch_embed.backbone.stages.0'``.

    Returns:
        Total number of scalar elements in the matching parameters.
    """
    # The previous version printed every parameter name (leftover debug output).
    return sum(p.numel() for name, p in model.named_parameters() if name.startswith(prefix))
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a registered variant with a 10-class head and report its size.
    model = timm.create_model('vitamin_large', num_classes=10)
    if torch.cuda.is_available():
        model = model.cuda()
    print(f"vitamin_large: {count_params(model):,} parameters")
|
|