|
""" Vision OutLOoker (VOLO) implementation |
|
|
|
Paper: `VOLO: Vision Outlooker for Visual Recognition` - https://arxiv.org/abs/2106.13112 |
|
|
|
Code adapted from official impl at https://github.com/sail-sg/volo, original copyright in comment below |
|
|
|
Modifications and additions for timm by / Copyright 2022, Ross Wightman |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn |
|
from ._builder import build_model_with_cfg |
|
from ._features import feature_take_indices |
|
from ._manipulate import checkpoint |
|
from ._registry import register_model, generate_default_cfgs |
|
|
|
# Names exported by `from <module> import *`; only the model class is public.
__all__ = ['VOLO']
|
|
|
|
|
class OutlookAttention(nn.Module):
    """Outlook attention over a channels-last feature map.

    For every (strided) spatial location a linear layer emits a
    ``(k*k) x (k*k)`` attention matrix per head, where ``k`` is ``kernel_size``.
    The matrix is applied to the unfolded ``k x k`` neighbourhood of projected
    values, and the weighted neighbourhoods are folded back onto the input grid.
    """

    def __init__(
            self,
            dim,
            num_heads,
            kernel_size=3,
            padding=1,
            stride=1,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
    ):
        """
        Args:
            dim: Channel dimension of the input and output.
            num_heads: Number of attention heads; channels are split across heads.
            kernel_size: Side length of the local window.
            padding: Padding used by both the unfold and the fold.
            stride: Stride of the unfold/fold and of the average pooling.
            qkv_bias: Whether the value projection has a bias.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability after the output projection.
        """
        super().__init__()
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        # scale by the per-head channel count
        self.scale = (dim // num_heads) ** -0.5

        # NOTE: submodules are created in this order to keep parameter order stable.
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        # one (k*k x k*k) attention matrix per head -> k**4 * num_heads outputs
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

    def forward(self, x):
        """Apply outlook attention to ``x`` of shape (B, H, W, C); output has the same shape."""
        B, H, W, C = x.shape
        heads = self.num_heads
        window = self.kernel_size * self.kernel_size
        out_h = math.ceil(H / self.stride)
        out_w = math.ceil(W / self.stride)
        locations = out_h * out_w

        # values: project, go channels-first, then gather every k x k neighbourhood
        values = self.v(x).permute(0, 3, 1, 2)
        values = self.unfold(values)
        values = values.reshape(B, heads, C // heads, window, locations)
        values = values.permute(0, 1, 4, 3, 2)  # B, heads, locations, k*k, C/heads

        # attention weights come straight from the pooled features
        pooled = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        weights = self.attn(pooled)
        weights = weights.reshape(B, locations, heads, window, window)
        weights = weights.permute(0, 2, 1, 3, 4)  # B, heads, locations, k*k, k*k
        weights = weights * self.scale
        weights = weights.softmax(dim=-1)
        weights = self.attn_drop(weights)

        # weight the neighbourhoods and fold them back onto the H x W grid
        out = (weights @ values).permute(0, 1, 4, 3, 2)
        out = out.reshape(B, C * window, locations)
        out = F.fold(
            out,
            output_size=(H, W),
            kernel_size=self.kernel_size,
            padding=self.padding,
            stride=self.stride,
        )

        out = self.proj(out.permute(0, 2, 3, 1))
        return self.proj_drop(out)
|
|
|
|
|
class Outlooker(nn.Module):
    """Residual block pairing outlook attention with an MLP, each preceded by a norm layer."""

    def __init__(
            self,
            dim,
            kernel_size,
            padding,
            stride=1,
            num_heads=1,
            mlp_ratio=3.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            qkv_bias=False,
    ):
        """
        Args:
            dim: Channel dimension of the block.
            kernel_size: Window size forwarded to the outlook attention.
            padding: Padding forwarded to the outlook attention.
            stride: Stride forwarded to the outlook attention.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
            attn_drop: Dropout probability on the attention weights.
            drop_path: Drop-path rate for both residual branches (disabled when not > 0).
            act_layer: Activation class used inside the MLP.
            norm_layer: Normalization class applied before each branch.
            qkv_bias: Whether the attention value projection has a bias.
        """
        super().__init__()
        # attention branch
        self.norm1 = norm_layer(dim)
        self.attn = OutlookAttention(
            dim,
            num_heads,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # feed-forward branch
        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_features, act_layer=act_layer)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Run the attention branch then the MLP branch, adding each result to its input."""
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path1(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(mlp_out)
        return x
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention over all positions of a channels-last (B, H, W, C) feature map."""
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
    ):
        """
        Args:
            dim: Channel dimension of the input and output.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused q/k/v projection has a bias.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability after the output projection.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        # whether forward() takes the F.scaled_dot_product_attention path
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend across the flattened H*W positions and return a tensor shaped like ``x``."""
        B, H, W, C = x.shape

        # split the fused projection into per-head q, k, v of shape (B, heads, H*W, C/heads)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            scores = self.attn_drop(scores.softmax(dim=-1))
            x = scores @ v

        # merge heads and restore the spatial layout
        x = x.transpose(1, 2).reshape(B, H, W, C)
        return self.proj_drop(self.proj(x))
|
|
|
|
|
class Transformer(nn.Module):
    """Residual block pairing multi-head self-attention with an MLP, each preceded by a norm layer."""

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
    ):
        """
        Args:
            dim: Channel dimension of the block.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
            qkv_bias: Whether the q/k/v projection has a bias.
            attn_drop: Dropout probability on the attention weights.
            drop_path: Drop-path rate for both residual branches (disabled when not > 0).
            act_layer: Activation class used inside the MLP.
            norm_layer: Normalization class applied before each branch.
        """
        super().__init__()
        # attention branch
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # feed-forward branch
        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_features, act_layer=act_layer)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Run the attention branch then the MLP branch, adding each result to its input."""
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path1(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(mlp_out)
        return x
|
|
|
|
|
class ClassAttention(nn.Module):
    """Attention in which only the first token queries the full sequence.

    Keys and values are computed from every token of the (B, N, C) input, while the
    query is computed from token 0 alone, so the output is a single (B, 1, C) embedding.
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            head_dim=None,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
    ):
        """
        Args:
            dim: Channel dimension of the input and output.
            num_heads: Number of attention heads.
            head_dim: Channels per head; defaults to ``dim // num_heads`` when None.
            qkv_bias: Whether the q and kv projections have a bias.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability after the output projection.
        """
        super().__init__()
        self.num_heads = num_heads
        if head_dim is None:
            head_dim = dim // num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        inner_dim = self.head_dim * self.num_heads
        self.kv = nn.Linear(dim, inner_dim * 2, bias=qkv_bias)
        self.q = nn.Linear(dim, inner_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended embedding of the first token, shape (B, 1, C)."""
        B, N, C = x.shape
        heads, head_dim = self.num_heads, self.head_dim

        # keys/values from all tokens: each (B, heads, N, head_dim)
        kv = self.kv(x).reshape(B, N, 2, heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        # query from the first token only, pre-scaled
        q = self.q(x[:, :1, :]).reshape(B, heads, 1, head_dim) * self.scale

        weights = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        cls_embed = (weights @ v).transpose(1, 2).reshape(B, 1, head_dim * heads)
        return self.proj_drop(self.proj(cls_embed))
|
|
|
|
|
class ClassBlock(nn.Module):
    """Block that refines only the first (class) token of a sequence.

    The first token is updated with class attention over the whole sequence and then
    with an MLP; all remaining tokens are passed through unchanged.
    """

    def __init__(
            self,
            dim,
            num_heads,
            head_dim=None,
            mlp_ratio=4.,
            qkv_bias=False,
            drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
    ):
        """
        Args:
            dim: Channel dimension of the tokens.
            num_heads: Number of attention heads.
            head_dim: Channels per head, forwarded to the class attention.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
            qkv_bias: Whether the attention projections have a bias.
            drop: Dropout probability for the attention projection and the MLP.
            attn_drop: Dropout probability on the attention weights.
            drop_path: Drop-path rate for both residual branches (disabled when not > 0).
            act_layer: Activation class used inside the MLP.
            norm_layer: Normalization class applied before each branch.
        """
        super().__init__()
        # class-attention branch
        self.norm1 = norm_layer(dim)
        self.attn = ClassAttention(
            dim,
            num_heads=num_heads,
            head_dim=head_dim,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # feed-forward branch
        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=hidden_features,
            act_layer=act_layer,
            drop=drop,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Update token 0 of ``x`` (B, N, C) and return it re-joined with the untouched tokens."""
        cls_token = x[:, :1]
        other_tokens = x[:, 1:]
        # attention reads the whole normed sequence but yields only the class embedding
        cls_token = cls_token + self.drop_path1(self.attn(self.norm1(x)))
        cls_token = cls_token + self.drop_path2(self.mlp(self.norm2(cls_token)))
        return torch.cat([cls_token, other_tokens], dim=1)
|
|
|
|
|
def get_block(block_type, **kargs):
    """Build a post-network block by its short type name.

    Args:
        block_type: Block identifier; ``'ca'`` selects ``ClassBlock``.
        **kargs: Keyword arguments forwarded to the block constructor.

    Returns:
        The constructed block module.

    Raises:
        ValueError: If ``block_type`` is not a supported identifier. Previously the
            function fell through and returned ``None`` in that case.
    """
    if block_type == 'ca':
        return ClassBlock(**kargs)
    raise ValueError(f"Unknown block type {block_type!r}; supported block types: 'ca'")
|
|
|
|
|
def rand_bbox(size, lam, scale=1):
    """Sample a random box for token mixing (see https://github.com/zihangJiang/TokenLabeling).

    The box is drawn on a grid of ``size[1] // scale`` by ``size[2] // scale`` cells.
    Its side lengths are the grid extents multiplied by ``sqrt(1 - lam)``, it is centred
    on a uniformly sampled cell, and it is clipped to the grid.

    Args:
        size: Sequence whose entries 1 and 2 give the two extents to sample over.
        lam: Mix ratio; larger values give a smaller box.
        scale: Integer factor by which the extents are divided first.

    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)`` of clipped start/end coordinates along
        ``size[1]`` (the ``x`` values) and ``size[2]`` (the ``y`` values).
    """
    extent_x = size[1] // scale
    extent_y = size[2] // scale

    side_ratio = np.sqrt(1. - lam)
    half_x = (extent_x * side_ratio).astype(int) // 2
    half_y = (extent_y * side_ratio).astype(int) // 2

    # draw the centre: x first, then y (the order matters for reproducibility)
    center_x = np.random.randint(extent_x)
    center_y = np.random.randint(extent_y)

    bbx1 = np.clip(center_x - half_x, 0, extent_x)
    bby1 = np.clip(center_y - half_y, 0, extent_y)
    bbx2 = np.clip(center_x + half_x, 0, extent_x)
    bby2 = np.clip(center_y + half_y, 0, extent_y)

    return bbx1, bby1, bbx2, bby2
|
|
|
|
|
class PatchEmbed(nn.Module):
    """Image to patch embedding.

    Unlike a ViT-style single convolution, an optional three-layer conv/BN/ReLU stem
    runs first, followed by a strided projection convolution (four conv layers total
    when the stem is enabled).
    """

    def __init__(
            self,
            img_size=224,
            stem_conv=False,
            stem_stride=1,
            patch_size=8,
            in_chans=3,
            hidden_dim=64,
            embed_dim=384,
    ):
        """
        Args:
            img_size: Input image side length, used only to compute ``num_patches``.
            stem_conv: Whether to build the convolutional stem.
            stem_stride: Stride of the first stem convolution; the projection kernel
                and stride are ``patch_size // stem_stride``.
            patch_size: Overall patch size; must be 4, 8 or 16.
            in_chans: Number of input image channels.
            hidden_dim: Channel width of the stem.
            embed_dim: Output embedding dimension.
        """
        super().__init__()
        assert patch_size in [4, 8, 16]
        if stem_conv:
            self.conv = nn.Sequential(
                nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv = None

        # Without a stem the projection sees the raw image, so it must accept in_chans
        # channels; it previously always expected hidden_dim and failed on such input.
        proj_in_chans = hidden_dim if stem_conv else in_chans
        proj_size = patch_size // stem_stride
        self.proj = nn.Conv2d(proj_in_chans, embed_dim, kernel_size=proj_size, stride=proj_size)
        self.num_patches = (img_size // patch_size) * (img_size // patch_size)

    def forward(self, x):
        """Embed an image batch (B, in_chans, H, W) into a channels-first feature map."""
        if self.conv is not None:
            x = self.conv(x)
        x = self.proj(x)
        return x
|
|
|
|
|
class Downsample(nn.Module):
    """Strided-convolution downsampling between stages, operating on channels-last maps."""

    def __init__(self, in_embed_dim, out_embed_dim, patch_size=2):
        """
        Args:
            in_embed_dim: Number of input channels.
            out_embed_dim: Number of output channels.
            patch_size: Kernel size and stride of the convolution.
        """
        super().__init__()
        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map (B, H, W, C_in) to a (B, H', W', C_out) tensor."""
        # the conv needs channels-first input; convert there and back
        channels_first = x.permute(0, 3, 1, 2)
        return self.proj(channels_first).permute(0, 2, 3, 1)
|
|
|
|
|
def outlooker_blocks(
        block_fn,
        index,
        dim,
        layers,
        num_heads=1,
        kernel_size=3,
        padding=1,
        stride=2,
        mlp_ratio=3.,
        qkv_bias=False,
        attn_drop=0,
        drop_path_rate=0.,
        **kwargs,
):
    """Build the outlooker blocks of one stage.

    Each block's drop-path rate grows linearly with its global depth position, from 0
    for the first block of the network to ``drop_path_rate`` for the last.

    Args:
        block_fn: Callable creating one block; called as ``block_fn(dim, **options)``.
        index: Index of the stage to build within ``layers``.
        dim: Channel dimension of the stage.
        layers: Number of blocks in every stage of the network.
        num_heads: Number of attention heads per block.
        kernel_size: Outlook attention window size.
        padding: Outlook attention padding.
        stride: Outlook attention stride.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: Whether the attention projections have a bias.
        attn_drop: Dropout probability on the attention weights.
        drop_path_rate: Maximum drop-path rate, reached at the last block of the network.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        ``nn.Sequential`` holding ``layers[index]`` blocks.
    """
    blocks_before = sum(layers[:index])
    # A network with a single block has no depth range to spread the rate over;
    # clamp the denominator so that case yields 0 instead of dividing by zero.
    depth_span = max(sum(layers) - 1, 1)

    blocks = [
        block_fn(
            dim,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            drop_path=drop_path_rate * (block_idx + blocks_before) / depth_span,
        )
        for block_idx in range(layers[index])
    ]
    return nn.Sequential(*blocks)
|
|
|
|
|
def transformer_blocks(
        block_fn,
        index,
        dim,
        layers,
        num_heads,
        mlp_ratio=3.,
        qkv_bias=False,
        attn_drop=0,
        drop_path_rate=0.,
        **kwargs,
):
    """Build the transformer blocks of one stage.

    Each block's drop-path rate grows linearly with its global depth position, from 0
    for the first block of the network to ``drop_path_rate`` for the last.

    Args:
        block_fn: Callable creating one block; called as
            ``block_fn(dim, num_heads, **options)``.
        index: Index of the stage to build within ``layers``.
        dim: Channel dimension of the stage.
        layers: Number of blocks in every stage of the network.
        num_heads: Number of attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: Whether the attention projections have a bias.
        attn_drop: Dropout probability on the attention weights.
        drop_path_rate: Maximum drop-path rate, reached at the last block of the network.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        ``nn.Sequential`` holding ``layers[index]`` blocks.
    """
    blocks_before = sum(layers[:index])
    # A network with a single block has no depth range to spread the rate over;
    # clamp the denominator so that case yields 0 instead of dividing by zero.
    depth_span = max(sum(layers) - 1, 1)

    blocks = [
        block_fn(
            dim,
            num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            drop_path=drop_path_rate * (block_idx + blocks_before) / depth_span,
        )
        for block_idx in range(layers[index])
    ]
    return nn.Sequential(*blocks)
|
|
|
|
|
class VOLO(nn.Module):
    """Vision Outlooker (VOLO) image classification model.

    The model is a convolutional patch embedding followed by a list of stages (outlooker
    or transformer blocks, optionally followed by a ``Downsample``), optional post layers
    that operate on a prepended class token, a final norm, and a main classifier head
    plus an optional auxiliary head applied to the non-class tokens.
    """

    def __init__(
            self,
            layers,
            img_size=224,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            patch_size=8,
            stem_hidden_dim=64,
            embed_dims=None,
            num_heads=None,
            downsamples=(True, False, False, False),
            outlook_attention=(True, False, False, False),
            mlp_ratio=3.0,
            qkv_bias=False,
            drop_rate=0.,
            pos_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=nn.LayerNorm,
            post_layers=('ca', 'ca'),
            use_aux_head=True,
            use_mix_token=False,
            pooling_scale=2,
    ):
        """
        Args:
            layers: Number of blocks in each stage.
            img_size: Input image size (int or pair), used to size the position embedding.
            in_chans: Number of input image channels.
            num_classes: Number of classes; values <= 0 make the heads identities.
            global_pool: ``'avg'``, ``'token'``, or anything else for no pooling.
            patch_size: Patch size of the patch embedding.
            stem_hidden_dim: Channel width of the patch-embedding stem.
            embed_dims: Channel dimension of each stage (indexed per stage; required).
            num_heads: Attention heads of each stage (indexed per stage; required).
            downsamples: Per stage, whether a ``Downsample`` follows it.
            outlook_attention: Per stage, whether it uses outlooker blocks
                (otherwise transformer blocks).
            mlp_ratio: MLP ratio, a scalar or one value per stage.
            qkv_bias: Whether attention projections have a bias.
            drop_rate: Dropout probability of ``head_drop``.
            pos_drop_rate: Dropout probability after adding the position embedding.
            attn_drop_rate: Dropout probability on attention weights.
            drop_path_rate: Maximum drop-path rate, passed to the transformer stages.
            norm_layer: Normalization layer class.
            post_layers: Block type names of the class-token post layers, or None.
            use_aux_head: Whether to create the auxiliary per-token head.
            use_mix_token: Whether ``forward_train`` mixes tokens across the batch.
            pooling_scale: Divisor applied to the token grid for the position
                embedding and for the mix-token box.
        """
        super().__init__()
        num_layers = len(layers)
        mlp_ratio = to_ntuple(num_layers)(mlp_ratio)
        img_size = to_2tuple(img_size)

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.mix_token = use_mix_token
        self.pooling_scale = pooling_scale
        self.num_features = self.head_hidden_size = embed_dims[-1]
        if use_mix_token:
            # parameter of the beta distribution sampled in forward_train()
            self.beta = 1.0
            assert global_pool == 'token', "return all tokens if mix_token is enabled"
        self.grad_checkpointing = False

        self.patch_embed = PatchEmbed(
            stem_conv=True,
            stem_stride=2,
            patch_size=patch_size,
            in_chans=in_chans,
            hidden_dim=stem_hidden_dim,
            embed_dim=embed_dims[0],
        )
        r = patch_size  # running reduction factor recorded in feature_info

        # position embedding, sized for the token grid divided by pooling_scale
        patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale)
        self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1]))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        # build the stages; stage_ends records the index of each stage inside self.network
        self.stage_ends = []
        self.feature_info = []
        network = []
        block_idx = 0
        for i in range(num_layers):
            if outlook_attention[i]:
                # outlooker stage (note: drop_path_rate is not passed here)
                stage = outlooker_blocks(
                    Outlooker,
                    i,
                    embed_dims[i],
                    layers,
                    num_heads[i],
                    mlp_ratio=mlp_ratio[i],
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                )
            else:
                # transformer stage
                stage = transformer_blocks(
                    Transformer,
                    i,
                    embed_dims[i],
                    layers,
                    num_heads[i],
                    mlp_ratio=mlp_ratio[i],
                    qkv_bias=qkv_bias,
                    drop_path_rate=drop_path_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                )
            network.append(stage)
            self.stage_ends.append(block_idx)
            self.feature_info.append(dict(num_chs=embed_dims[i], reduction=r, module=f'network.{block_idx}'))
            block_idx += 1
            if downsamples[i]:
                # 2x downsampling after this stage
                network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2))
                r *= 2
                block_idx += 1

        self.network = nn.ModuleList(network)

        # optional post layers working on a learned class token
        self.post_network = None
        if post_layers is not None:
            self.post_network = nn.ModuleList([
                get_block(
                    post_layers[i],
                    dim=embed_dims[-1],
                    num_heads=num_heads[-1],
                    mlp_ratio=mlp_ratio[-1],
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    norm_layer=norm_layer)
                for i in range(len(post_layers))
            ])
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1]))
            trunc_normal_(self.cls_token, std=.02)

        # auxiliary head classifies every non-class token
        if use_aux_head:
            self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        else:
            self.aux_head = None
        self.norm = norm_layer(self.num_features)

        # main classifier head
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` modules: truncated-normal weights, zero biases."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns that group parameters into stem / blocks / post blocks."""
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',
            blocks=[
                (r'^network\.(\d+)\.(\d+)', None),
                (r'^network\.(\d+)', (0,)),
            ],
            blocks2=[
                (r'^cls_token', (0,)),
                (r'^post_network\.(\d+)', None),
                (r'^norm', (99999,))
            ],
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing of the network and post-network blocks."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the main classifier head."""
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the classifier head (and the auxiliary head, if present).

        Args:
            num_classes: New number of classes; values <= 0 install identity heads.
            global_pool: New pooling mode, or None to keep the current one.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        if self.aux_head is not None:
            self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_tokens(self, x):
        """Run all network entries on a (B, H, W, C) map and flatten to (B, H*W, C)."""
        for idx, block in enumerate(self.network):
            if idx == 2:
                # add the position embedding before the third network entry
                x = x + self.pos_embed
                x = self.pos_drop(x)
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(block, x)
            else:
                x = block(x)

        B, H, W, C = x.shape
        x = x.reshape(B, -1, C)
        return x

    def forward_cls(self, x):
        """Prepend the class token to (B, N, C) tokens and run the post-network blocks."""
        B, N, C = x.shape
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        for block in self.post_network:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(block, x)
            else:
                x = block(x)
        return x

    def forward_train(self, x):
        """ A separate forward fn for training with mix_token (if a train script supports).

        Combining multiple modes in as single forward with different return types is torchscript hell.

        Returns:
            * the main-head output alone when there is no auxiliary head;
            * in eval mode, main-head output plus 0.5 * the per-class max of the
              auxiliary output over tokens;
            * in training mode, ``(x_cls, x_aux, (bbx1, bby1, bbx2, bby2))`` where the
              box is all zeros unless mix_token is enabled.
        """
        x = self.patch_embed(x)
        x = x.permute(0, 2, 3, 1)  # B,C,H,W -> B,H,W,C

        # mix token: paste a random box from the batch-flipped sample into each sample
        if self.mix_token and self.training:
            lam = np.random.beta(self.beta, self.beta)
            patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
            bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
            temp_x = x.clone()
            # box coordinates are on the pooled grid; scale them back up to this grid
            sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1
            sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2
            temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
            x = temp_x
        else:
            bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0

        # token processing in the stages
        x = self.forward_tokens(x)

        # post network on the class token, if configured
        if self.post_network is not None:
            x = self.forward_cls(x)
        x = self.norm(x)

        if self.global_pool == 'avg':
            x_cls = x.mean(dim=1)
        elif self.global_pool == 'token':
            x_cls = x[:, 0]
        else:
            x_cls = x
        # Classify the pooled features, as forward_head() does. Without this x_cls keeps
        # num_features channels and cannot be summed with the num_classes-wide auxiliary
        # output below.
        x_cls = self.head(x_cls)

        if self.aux_head is None:
            return x_cls

        x_aux = self.aux_head(x[:, 1:])  # per-token class predictions for all but the first token
        if not self.training:
            return x_cls + 0.5 * x_aux.max(1)[0]

        if self.mix_token and self.training:
            # undo the token mixing on the auxiliary predictions
            x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
            temp_x = x_aux.clone()
            temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
            x_aux = temp_x
            x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])

        # class prediction, per-token predictions, and the mixing box
        return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            The list of intermediate feature maps (NCHW) when ``intermediates_only`` is
            set, otherwise a tuple of the final normed tokens and that list.
        """
        assert output_fmt in ('NCHW',), 'Output format must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
        # translate stage indices into positions within self.network
        take_indices = [self.stage_ends[i] for i in take_indices]
        max_index = self.stage_ends[max_index]

        # forward pass
        B, _, height, width = x.shape
        x = self.patch_embed(x).permute(0, 2, 3, 1)  # B,C,H,W -> B,H,W,C

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            network = self.network
        else:
            network = self.network[:max_index + 1]
        for idx, block in enumerate(network):
            if idx == 2:
                # add the position embedding before the third network entry
                x = x + self.pos_embed
                x = self.pos_drop(x)
            x = block(x)
            if idx in take_indices:
                # the final norm is only applied from the third network entry onwards
                if norm and idx >= 2:
                    x_inter = self.norm(x)
                else:
                    x_inter = x
                intermediates.append(x_inter.permute(0, 3, 1, 2))

        if intermediates_only:
            return intermediates

        # finish the regular feature path: flatten, class-token post layers, norm
        B, H, W, C = x.shape
        x = x.reshape(B, -1, C)
        if self.post_network is not None:
            x = self.forward_cls(x)
        x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Args:
            indices: Stage indices to keep, in the form accepted by ``forward_intermediates``.
            prune_norm: Replace the final norm with an identity.
            prune_head: Empty the post network and reset the classifier to zero classes.

        Returns:
            The resolved stage indices.
        """
        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
        max_index = self.stage_ends[max_index]
        self.network = self.network[:max_index + 1]  # truncate blocks
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.post_network = nn.ModuleList()  # prune token blocks with head
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        """Compute the normed token sequence (class token first when post layers exist)."""
        x = self.patch_embed(x).permute(0, 2, 3, 1)  # B,C,H,W -> B,H,W,C

        # token processing in the stages
        x = self.forward_tokens(x)

        # post network on the class token, if configured
        if self.post_network is not None:
            x = self.forward_cls(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the tokens and classify them.

        Args:
            x: Token sequence from ``forward_features``.
            pre_logits: Return the pooled features without applying the heads.

        Returns:
            Pooled features, or main-head output plus (when an auxiliary head exists)
            0.5 * the per-class max of the auxiliary output over tokens.
        """
        if self.global_pool == 'avg':
            out = x.mean(dim=1)
        elif self.global_pool == 'token':
            out = x[:, 0]
        else:
            out = x
        # NOTE: dropout is applied to the token tensor after pooling, so it only
        # affects the auxiliary head input below, not `out`.
        x = self.head_drop(x)
        if pre_logits:
            return out
        out = self.head(out)
        if self.aux_head is not None:
            # per-token class predictions for all but the first token
            aux = self.aux_head(x[:, 1:])
            out = out + 0.5 * aux.max(1)[0]
        return out

    def forward(self, x):
        """ simplified forward (without mix token training) """
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
|
|
|
|
|
def _create_volo(variant, pretrained=False, **kwargs):
    """Build a ``VOLO`` model for ``variant`` through ``build_model_with_cfg``.

    ``out_indices`` is removed from ``kwargs`` (default 3) and passed as part of the
    feature configuration; the remaining keyword arguments are forwarded unchanged.
    """
    feature_cfg = dict(
        out_indices=kwargs.pop('out_indices', 3),
        feature_cls='getter',
    )
    return build_model_with_cfg(VOLO, variant, pretrained, feature_cfg=feature_cfg, **kwargs)
|
|
|
|
|
def _cfg(url='', **kwargs):
    """Return the default pretrained config dict, with ``kwargs`` overriding or extending it."""
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': .96,
        'interpolation': 'bicubic',
        'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.conv.0',
        'classifier': ('head', 'aux_head'),
    }
    # caller-supplied entries win over the defaults above
    cfg.update(kwargs)
    return cfg
|
|
|
|
|
# Pretrained configurations keyed by '<model name>.<tag>'. Each entry starts from the
# _cfg() defaults and overrides the checkpoint URL, crop_pct and, for the larger
# resolutions, input_size.
default_cfgs = generate_default_cfgs({
    'volo_d1_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar',
        crop_pct=0.96),
    'volo_d1_384.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar',
        crop_pct=1.0, input_size=(3, 384, 384)),
    'volo_d2_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar',
        crop_pct=0.96),
    'volo_d2_384.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar',
        crop_pct=1.0, input_size=(3, 384, 384)),
    'volo_d3_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar',
        crop_pct=0.96),
    'volo_d3_448.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar',
        crop_pct=1.0, input_size=(3, 448, 448)),
    'volo_d4_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar',
        crop_pct=0.96),
    'volo_d4_448.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar',
        crop_pct=1.15, input_size=(3, 448, 448)),
    'volo_d5_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar',
        crop_pct=0.96),
    'volo_d5_448.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar',
        crop_pct=1.15, input_size=(3, 448, 448)),
    'volo_d5_512.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar',
        crop_pct=1.15, input_size=(3, 512, 512)),
})
|
|
|
|
|
@register_model
def volo_d1_224(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D1 model, Params: 27M

    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(4, 4, 8, 2),
        embed_dims=(192, 384, 384, 384),
        num_heads=(6, 12, 12, 12),
        **kwargs,
    )
    return _create_volo('volo_d1_224', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d1_384(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D1 model, Params: 27M

    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(4, 4, 8, 2),
        embed_dims=(192, 384, 384, 384),
        num_heads=(6, 12, 12, 12),
        **kwargs,
    )
    return _create_volo('volo_d1_384', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d2_224(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D2 model, Params: 59M

    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(6, 4, 10, 4),
        embed_dims=(256, 512, 512, 512),
        num_heads=(8, 16, 16, 16),
        **kwargs,
    )
    return _create_volo('volo_d2_224', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d2_384(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D2 model, Params: 59M

    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(6, 4, 10, 4),
        embed_dims=(256, 512, 512, 512),
        num_heads=(8, 16, 16, 16),
        **kwargs,
    )
    return _create_volo('volo_d2_384', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d3_224(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D3 model, Params: 86M

    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(8, 8, 16, 4),
        embed_dims=(256, 512, 512, 512),
        num_heads=(8, 16, 16, 16),
        **kwargs,
    )
    return _create_volo('volo_d3_224', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d3_448(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D3 model, Params: 86M

    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(8, 8, 16, 4),
        embed_dims=(256, 512, 512, 512),
        num_heads=(8, 16, 16, 16),
        **kwargs,
    )
    return _create_volo('volo_d3_448', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d4_224(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D4 model, Params: 193M

    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(8, 8, 16, 4),
        embed_dims=(384, 768, 768, 768),
        num_heads=(12, 16, 16, 16),
        **kwargs,
    )
    return _create_volo('volo_d4_224', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d4_448(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D4 model, Params: 193M

    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(8, 8, 16, 4),
        embed_dims=(384, 768, 768, 768),
        num_heads=(12, 16, 16, 16),
        **kwargs,
    )
    return _create_volo('volo_d4_448', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d5_224(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D5 model, Params: 296M

    stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5.
    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(12, 12, 20, 4),
        embed_dims=(384, 768, 768, 768),
        num_heads=(12, 16, 16, 16),
        mlp_ratio=4,
        stem_hidden_dim=128,
        **kwargs,
    )
    return _create_volo('volo_d5_224', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d5_448(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D5 model, Params: 296M

    stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5.
    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(12, 12, 20, 4),
        embed_dims=(384, 768, 768, 768),
        num_heads=(12, 16, 16, 16),
        mlp_ratio=4,
        stem_hidden_dim=128,
        **kwargs,
    )
    return _create_volo('volo_d5_448', pretrained=pretrained, **model_args)
|
|
|
|
|
@register_model
def volo_d5_512(pretrained=False, **kwargs) -> VOLO:
    """ VOLO-D5 model, Params: 296M

    stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5.
    Extra keyword arguments are forwarded to the model builder.
    """
    model_args = dict(
        layers=(12, 12, 20, 4),
        embed_dims=(384, 768, 768, 768),
        num_heads=(12, 16, 16, 16),
        mlp_ratio=4,
        stem_hidden_dim=128,
        **kwargs,
    )
    return _create_volo('volo_d5_512', pretrained=pretrained, **model_args)
|
|