# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, constant_init
from mmengine.model.weight_init import trunc_normal_

from mmpose.registry import MODELS
from .base_backbone import BaseBackbone


class Attention(BaseModule):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 mode='spatial'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.mode = mode
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_count_s = None
        self.attn_count_t = None

    def forward(self, x, seq_len=1):
        B, N, C = x.shape

        if self.mode == 'temporal':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
                                      self.num_heads).permute(2, 0, 3, 1, 4)
            # make torchscript happy (cannot use tensor as tuple)
            q, k, v = qkv[0], qkv[1], qkv[2]
            x = self.forward_temporal(q, k, v, seq_len=seq_len)
        elif self.mode == 'spatial':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
                                      self.num_heads).permute(2, 0, 3, 1, 4)
            # make torchscript happy (cannot use tensor as tuple)
            q, k, v = qkv[0], qkv[1], qkv[2]
            x = self.forward_spatial(q, k, v)
        else:
            raise NotImplementedError(self.mode)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def forward_spatial(self, q, k, v):
        B, _, N, C = q.shape
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C * self.num_heads)
        return x

    def forward_temporal(self, q, k, v, seq_len=8):
        B, _, N, C = q.shape
        qt = q.reshape(-1, seq_len, self.num_heads, N,
                       C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)
        kt = k.reshape(-1, seq_len, self.num_heads, N,
                       C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)
        vt = v.reshape(-1, seq_len, self.num_heads, N,
                       C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)

        attn = (qt @ kt.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ vt  # (B, H, N, T, C)
        x = x.permute(0, 3, 2, 1, 4).reshape(B, N, C * self.num_heads)
        return x
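
# Minimal usage sketch for ``Attention`` (the sizes below -- 17 keypoints,
# a 256-dim embedding, 8 heads, 9 frames -- are illustrative assumptions,
# not requirements). Spatial mode attends across the keypoints of each
# frame; temporal mode regroups the flattened (batch * frames) axis with
# ``seq_len`` and attends across frames for each keypoint:
#
#   >>> attn_s = Attention(dim=256, num_heads=8, mode='spatial')
#   >>> attn_t = Attention(dim=256, num_heads=8, mode='temporal')
#   >>> x = torch.rand(2 * 9, 17, 256)   # (batch * frames, keypoints, dim)
#   >>> attn_s(x).shape
#   torch.Size([18, 17, 256])
#   >>> attn_t(x, seq_len=9).shape
#   torch.Size([18, 17, 256])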


class AttentionBlock(BaseModule):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 mlp_out_ratio=1.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 st_mode='st'):
        super().__init__()

        self.st_mode = st_mode
        self.norm1_s = nn.LayerNorm(dim, eps=1e-06)
        self.norm1_t = nn.LayerNorm(dim, eps=1e-06)

        self.attn_s = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            mode='spatial')
        self.attn_t = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            mode='temporal')

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2_s = nn.LayerNorm(dim, eps=1e-06)
        self.norm2_t = nn.LayerNorm(dim, eps=1e-06)

        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_out_dim = int(dim * mlp_out_ratio)
        self.mlp_s = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim), nn.GELU(),
            nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(drop))
        self.mlp_t = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim), nn.GELU(),
            nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(drop))

    def forward(self, x, seq_len=1):
        if self.st_mode == 'st':
            x = x + self.drop_path(self.attn_s(self.norm1_s(x), seq_len))
            x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
            x = x + self.drop_path(self.attn_t(self.norm1_t(x), seq_len))
            x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
        elif self.st_mode == 'ts':
            x = x + self.drop_path(self.attn_t(self.norm1_t(x), seq_len))
            x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
            x = x + self.drop_path(self.attn_s(self.norm1_s(x), seq_len))
            x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
        else:
            raise NotImplementedError(self.st_mode)
        return x
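
# Minimal usage sketch for ``AttentionBlock`` (same assumed sizes as above).
# ``st_mode='st'`` applies spatial attention + MLP, then temporal attention
# + MLP; ``'ts'`` runs the two streams in the opposite order. With the
# default ``mlp_out_ratio=1.`` the residual connections preserve the
# feature dimension:
#
#   >>> blk = AttentionBlock(dim=256, num_heads=8, st_mode='st')
#   >>> x = torch.rand(2 * 9, 17, 256)   # (batch * frames, keypoints, dim)
#   >>> blk(x, seq_len=9).shape
#   torch.Size([18, 17, 256])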


@MODELS.register_module()
class DSTFormer(BaseBackbone):
    """Dual-stream Spatio-temporal Transformer Module.

    Args:
        in_channels (int): Number of input channels.
        feat_size (int): Number of feature channels. Default: 256.
        depth (int): The network depth. Default: 5.
        num_heads (int): Number of heads in multi-head self-attention blocks.
            Default: 8.
        mlp_ratio (int, optional): The expansion ratio of FFN. Default: 4.
        num_keypoints (int): Number of keypoints. Default: 17.
        seq_len (int): The sequence length. Default: 243.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout ratio of input. Default: 0.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.
        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
        att_fuse (bool): Whether to fuse the results of attention blocks.
            Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> from mmpose.models import DSTFormer
        >>> import torch
        >>> self = DSTFormer(in_channels=3)
        >>> self.eval()
        >>> inputs = torch.rand(1, 2, 17, 3)
        >>> level_outputs = self.forward(inputs)
        >>> print(tuple(level_outputs.shape))
        (1, 2, 17, 256)
    """

    def __init__(self,
                 in_channels,
                 feat_size=256,
                 depth=5,
                 num_heads=8,
                 mlp_ratio=4,
                 num_keypoints=17,
                 seq_len=243,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 att_fuse=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.feat_size = feat_size

        self.joints_embed = nn.Linear(in_channels, feat_size)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks_st = nn.ModuleList([
            AttentionBlock(
                dim=feat_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                st_mode='st') for i in range(depth)
        ])
        self.blocks_ts = nn.ModuleList([
            AttentionBlock(
                dim=feat_size,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                st_mode='ts') for i in range(depth)
        ])
        self.norm = nn.LayerNorm(feat_size, eps=1e-06)

        self.temp_embed = nn.Parameter(torch.zeros(1, seq_len, 1, feat_size))
        self.spat_embed = nn.Parameter(
            torch.zeros(1, num_keypoints, feat_size))
        trunc_normal_(self.temp_embed, std=.02)
        trunc_normal_(self.spat_embed, std=.02)

        self.att_fuse = att_fuse
        if self.att_fuse:
            self.attn_regress = nn.ModuleList(
                [nn.Linear(feat_size * 2, 2) for i in range(depth)])
            for i in range(depth):
                self.attn_regress[i].weight.data.fill_(0)
                self.attn_regress[i].bias.data.fill_(0.5)

    def forward(self, x):
        if len(x.shape) == 3:
            x = x[None, :]
        assert len(x.shape) == 4

        B, F, K, C = x.shape
        x = x.reshape(-1, K, C)
        BF = x.shape[0]
        x = self.joints_embed(x)  # (BF, K, feat_size)
        x = x + self.spat_embed
        _, K, C = x.shape
        x = x.reshape(-1, F, K, C) + self.temp_embed[:, :F, :, :]
        x = x.reshape(BF, K, C)  # (BF, K, feat_size)
        x = self.pos_drop(x)

        for idx, (blk_st,
                  blk_ts) in enumerate(zip(self.blocks_st, self.blocks_ts)):
            x_st = blk_st(x, F)
            x_ts = blk_ts(x, F)
            if self.att_fuse:
                # Fuse the two streams with per-token convex weights
                # predicted from the concatenated features. With the
                # zero-weight / 0.5-bias initialization above, the softmax
                # yields alpha = (0.5, 0.5), so fusion starts as a plain
                # average of x_st and x_ts.
                att = self.attn_regress[idx]
                alpha = torch.cat([x_st, x_ts], dim=-1)
                BF, K = alpha.shape[:2]
                alpha = att(alpha)
                alpha = alpha.softmax(dim=-1)
                x = x_st * alpha[:, :, 0:1] + x_ts * alpha[:, :, 1:2]
            else:
                x = (x_st + x_ts) * 0.5
        x = self.norm(x)  # (BF, K, feat_size)
        x = x.reshape(B, F, K, -1)
        return x

    def init_weights(self):
        """Initialize the weights in backbone."""
        super(DSTFormer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            return

        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                # constant_init expects a module, not a tensor; this sets
                # LayerNorm weight to 1.0 and bias to 0.
                constant_init(m, val=1.0, bias=0.)
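
# Minimal end-to-end sketch: building the backbone through the mmpose MODELS
# registry (assumes this module has been imported so the
# @MODELS.register_module() decorator has registered 'DSTFormer'; the config
# values below are illustrative defaults, not requirements):
#
#   >>> from mmpose.registry import MODELS
#   >>> cfg = dict(type='DSTFormer', in_channels=3, feat_size=256, depth=5,
#   ...            num_heads=8, num_keypoints=17, seq_len=243)
#   >>> backbone = MODELS.build(cfg)
#   >>> backbone(torch.rand(1, 2, 17, 3)).shape   # (B, frames, keypoints, C)
#   torch.Size([1, 2, 17, 256])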