File size: 16,098 Bytes
80187e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 |
""" Hybrid Vision Transformer (ViT) in PyTorch
A PyTorch implement of the Hybrid Vision Transformers as described in:
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
- https://arxiv.org/abs/2010.11929
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.TODO
NOTE These hybrid model definitions depend on code in vision_transformer.py.
They were moved here to keep file sizes sane.
Hacked together by / Copyright 2021 Ross Wightman
"""
from copy import deepcopy
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import StdConv2dSame, StdConv2d, to_2tuple
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
from .registry import register_model
from timm.models.vision_transformer import _create_vision_transformer
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs
}
default_cfgs = {
# hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r26_s32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
),
'vit_small_r26_s32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r26_s32_224': _cfg(),
'vit_base_r50_s16_224': _cfg(),
'vit_base_r50_s16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'
),
'vit_large_r50_s32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0
),
# hybrid in-21k models (weights from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'),
'vit_small_r26_s32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9),
'vit_base_r50_s16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9),
# hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_small_resnet50d_s16_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet26d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
}
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# NOTE Most reliable way of determining output dims is to run forward pass
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x).flatten(2).transpose(1, 2)
return x
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
embed_layer = partial(HybridEmbed, backbone=backbone)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return _create_vision_transformer(
variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs)
def _resnetv2(layers=(3, 4, 9), **kwargs):
""" ResNet-V2 backbone helper"""
padding_same = kwargs.get('padding_same', True)
stem_type = 'same' if padding_same else ''
conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8)
if len(layers):
backbone = ResNetV2(
layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
preact=False, stem_type=stem_type, conv_layer=conv_layer)
else:
backbone = create_resnetv2_stem(
kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer)
return backbone
@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_224(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_384(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r26_s32_224(pretrained=False, **kwargs):
""" R26+ViT-B/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r50_s16_224(pretrained=False, **kwargs):
""" R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
"""
backbone = _resnetv2((3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2((3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
# DEPRECATED this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
@register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_r50_s32_384(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
# DEPRECATED this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
@register_model
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
"""
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_resnet50d_s16_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
"""
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
"""
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50d_224(pretrained=False, **kwargs):
""" Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
"""
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model |