# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
from mmcv.cnn import (build_norm_layer, constant_init, normal_init,
trunc_normal_init)
from mmcv.runner import _load_checkpoint, load_state_dict

from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils import (PatchEmbed, TCFormerDynamicBlock, TCFormerRegularBlock,
TokenConv, cluster_dpc_knn, merge_tokens,
tcformer_convert, token2map)


class CTM(nn.Module):
    """Clustering-based Token Merging module in TCFormer.

    Args:
        sample_ratio (float): The sample ratio of tokens.
        embed_dim (int): Input token feature dimension.
        dim_out (int): Output token feature dimension.
        k (int): Number of the nearest neighbors used in the DPC-KNN
            algorithm. Default: 5.
    """

    def __init__(self, sample_ratio, embed_dim, dim_out, k=5):
super().__init__()
self.sample_ratio = sample_ratio
self.dim_out = dim_out
self.conv = TokenConv(
in_channels=embed_dim,
out_channels=dim_out,
kernel_size=3,
stride=2,
padding=1)
self.norm = nn.LayerNorm(self.dim_out)
self.score = nn.Linear(self.dim_out, 1)
self.k = k

    def forward(self, token_dict):
token_dict = token_dict.copy()
x = self.conv(token_dict)
x = self.norm(x)
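        # per-token importance score; exp() turns it into a positive weight
        # that is used when merging token features within each cluster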
token_score = self.score(x)
token_weight = token_score.exp()
token_dict['x'] = x
B, N, C = x.shape
token_dict['token_score'] = token_score
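        # cluster tokens with DPC-KNN and merge every cluster into a single
        # token, weighted by the scores computed above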
cluster_num = max(math.ceil(N * self.sample_ratio), 1)
idx_cluster, cluster_num = cluster_dpc_knn(token_dict, cluster_num,
self.k)
down_dict = merge_tokens(token_dict, idx_cluster, cluster_num,
token_weight)
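        # the stride-2 TokenConv halves the token map resolution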
H, W = token_dict['map_size']
H = math.floor((H - 1) / 2 + 1)
W = math.floor((W - 1) / 2 + 1)
down_dict['map_size'] = [H, W]
return down_dict, token_dict
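
# A minimal usage sketch for CTM, kept as a comment. The token dict layout
# mirrors the one built in TCFormer.forward below; the sizes are illustrative
# assumptions:
#
#     x = torch.randn(2, 56 * 56, 64)
#     token_dict = dict(
#         x=x,
#         token_num=x.shape[1],
#         map_size=[56, 56],
#         init_grid_size=[56, 56],
#         idx_token=torch.arange(x.shape[1])[None].repeat(2, 1),
#         agg_weight=x.new_ones(2, x.shape[1], 1))
#     down_dict, token_dict = CTM(sample_ratio=0.25, embed_dim=64,
#                                 dim_out=128, k=5)(token_dict)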


@BACKBONES.register_module()
class TCFormer(nn.Module):
    """Token Clustering Transformer (TCFormer).

    Implementation of `Not All Tokens Are Equal: Human-centric Visual
    Analysis via Token Clustering Transformer
    <https://arxiv.org/abs/2204.08680>`_.

    Args:
in_channels (int): Number of input channels. Default: 3.
embed_dims (list[int]): Embedding dimension. Default:
[64, 128, 256, 512].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 4, 8].
mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
embedding dim of each transformer block.
qkv_bias (bool): Enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        attn_drop_rate (float): The dropout rate for the attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN', eps=1e-6).
        num_layers (Sequence[int]): The number of layers in each transformer
            encode stage. Default: [3, 4, 6, 3].
sr_ratios (Sequence[int]): The spatial reduction rate of each
transformer block. Default: [8, 4, 2, 1].
        num_stages (int): The number of stages. Default: 4.
        pretrained (str, optional): Path to a pre-trained model.
            Default: None.
        k (int): Number of the nearest neighbors used for local density
            estimation in DPC-KNN. Default: 5.
        sample_ratios (list[float]): The sample ratios of the CTM modules.
            Default: [0.25, 0.25, 0.25].
        return_map (bool): If True, convert the dynamic tokens back to a
            feature map at the end. Default: False.
        convert_weights (bool): Whether the pre-trained model comes from the
            original repo, in which case some keys are converted to make the
            weights compatible with this implementation. Default: True.
"""

    def __init__(self,
in_channels=3,
embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8],
mlp_ratios=[4, 4, 4, 4],
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN', eps=1e-6),
num_layers=[3, 4, 6, 3],
sr_ratios=[8, 4, 2, 1],
num_stages=4,
pretrained=None,
k=5,
sample_ratios=[0.25, 0.25, 0.25],
return_map=False,
convert_weights=True):
super().__init__()
self.num_layers = num_layers
self.num_stages = num_stages
self.grid_stride = sr_ratios[0]
self.embed_dims = embed_dims
self.sr_ratios = sr_ratios
self.mlp_ratios = mlp_ratios
self.sample_ratios = sample_ratios
self.return_map = return_map
self.convert_weights = convert_weights
# stochastic depth decay rule
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, sum(num_layers))
]
cur = 0
# In stage 1, use the standard transformer blocks
for i in range(1):
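            # stride-4 overlapping patch embedding: stage-1 tokens form a
            # quarter-resolution regular grid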
patch_embed = PatchEmbed(
in_channels=in_channels if i == 0 else embed_dims[i - 1],
embed_dims=embed_dims[i],
kernel_size=7,
stride=4,
padding=3,
bias=True,
norm_cfg=dict(type='LN', eps=1e-6))
block = nn.ModuleList([
TCFormerRegularBlock(
dim=embed_dims[i],
num_heads=num_heads[i],
mlp_ratio=mlp_ratios[i],
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + j],
norm_cfg=norm_cfg,
sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
])
norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
cur += num_layers[i]
setattr(self, f'patch_embed{i + 1}', patch_embed)
setattr(self, f'block{i + 1}', block)
setattr(self, f'norm{i + 1}', norm)
# In stage 2~4, use TCFormerDynamicBlock for dynamic tokens
for i in range(1, num_stages):
ctm = CTM(sample_ratios[i - 1], embed_dims[i - 1], embed_dims[i],
k)
block = nn.ModuleList([
TCFormerDynamicBlock(
dim=embed_dims[i],
num_heads=num_heads[i],
mlp_ratio=mlp_ratios[i],
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + j],
norm_cfg=norm_cfg,
sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
])
norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
cur += num_layers[i]
setattr(self, f'ctm{i}', ctm)
setattr(self, f'block{i + 1}', block)
setattr(self, f'norm{i + 1}', norm)
self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
checkpoint = _load_checkpoint(
pretrained, logger=logger, map_location='cpu')
logger.warning(f'Load pre-trained model for '
f'{self.__class__.__name__} from original repo')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
if self.convert_weights:
# We need to convert pre-trained weights to match this
# implementation.
state_dict = tcformer_convert(state_dict)
load_state_dict(self, state_dict, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m, 1.0)
elif isinstance(m, nn.Conv2d):
                    fan_out = (m.kernel_size[0] * m.kernel_size[1] *
                               m.out_channels)
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
else:
raise TypeError('pretrained must be a str or None')

    def forward(self, x):
outs = []
i = 0
patch_embed = getattr(self, f'patch_embed{i + 1}')
block = getattr(self, f'block{i + 1}')
norm = getattr(self, f'norm{i + 1}')
x, (H, W) = patch_embed(x)
for blk in block:
x = blk(x, H, W)
x = norm(x)
# init token dict
B, N, _ = x.shape
device = x.device
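        # every token starts out bound to exactly one grid position and has a
        # uniform aggregation weight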
idx_token = torch.arange(N)[None, :].repeat(B, 1).to(device)
agg_weight = x.new_ones(B, N, 1)
token_dict = {
'x': x,
'token_num': N,
'map_size': [H, W],
'init_grid_size': [H, W],
'idx_token': idx_token,
'agg_weight': agg_weight
}
outs.append(token_dict.copy())
# stage 2~4
for i in range(1, self.num_stages):
ctm = getattr(self, f'ctm{i}')
block = getattr(self, f'block{i + 1}')
norm = getattr(self, f'norm{i + 1}')
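            # the CTM returns (down_dict, token_dict); the first dynamic block
            # is expected to consume this pair (queries from the merged
            # tokens, keys/values from the pre-merge tokens) and to return a
            # single token dict again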
token_dict = ctm(token_dict) # down sample
for j, blk in enumerate(block):
token_dict = blk(token_dict)
token_dict['x'] = norm(token_dict['x'])
outs.append(token_dict)
if self.return_map:
outs = [token2map(token_dict) for token_dict in outs]
return outs
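
# A registry-style usage sketch, kept as a comment. The field values are
# illustrative; the class is selected by name thanks to the BACKBONES
# registration above:
#
#     backbone = dict(
#         type='TCFormer',
#         embed_dims=[64, 128, 256, 512],
#         num_heads=[1, 2, 4, 8],
#         num_layers=[3, 4, 6, 3],
#         sr_ratios=[8, 4, 2, 1],
#         sample_ratios=[0.25, 0.25, 0.25],
#         return_map=True)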