# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                      normal_init)
from mmcv.utils import digit_version
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.models.utils.ops import resize

from ..backbones.resnet import BasicBlock, Bottleneck
from ..builder import NECKS

try:
    from mmcv.ops import DeformConv2d
    has_mmcv_full = True
except (ImportError, ModuleNotFoundError):
    has_mmcv_full = False


@NECKS.register_module()
class PoseWarperNeck(nn.Module):
    """PoseWarper neck.

    `"Learning temporal pose estimation from sparsely-labeled videos"
    <https://arxiv.org/abs/1906.04016>`_.

    Args:
        in_channels (int): Number of input channels from the backbone.
        out_channels (int): Number of output channels.
        inner_channels (int): Number of intermediate channels of the res
            block.
        deform_groups (int): Number of groups in the deformable conv.
        dilations (list|tuple): Different dilations of the offset conv layers.
        trans_conv_kernel (int): The kernel of the trans conv layer, which is
            used to get the heatmap from the output of the backbone.
            Default: 1.
        res_blocks_cfg (dict|None): Config of the residual blocks. If None,
            use the default values. If not None, it should contain the
            following keys:

            - block (str): the type of residual block. Default: 'BASIC'.
            - num_blocks (int): the number of blocks. Default: 20.

        offsets_kernel (int): The kernel of the offset conv layer.
        deform_conv_kernel (int): The kernel of the deformable conv layer.
        in_index (int|Sequence[int]): Input feature index. Default: 0.
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            Default: None.

            - 'resize_concat': Multiple feature maps will be resized to
              the same size as the first one and then concatenated together.
              Usually used in the FCN head of HRNet.
            - 'multiple_select': Multiple feature maps will be bundled into
              a list and passed into the decode head.
            - None: Only one selected feature map is allowed.

        freeze_trans_layer (bool): Whether to freeze the transition layer
            (stop grad and set eval mode). Default: True.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        im2col_step (int): The argument `im2col_step` in deformable conv.
            Default: 80.
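
    Example:
        A minimal usage sketch. The channel numbers, spatial size and frame
        weights below are illustrative only, and running it requires an
        mmcv-full build that provides ``DeformConv2d`` (typically CUDA)::

            import torch

            neck = PoseWarperNeck(
                in_channels=48,   # e.g. the highest-resolution HRNet branch
                out_channels=17,  # number of keypoint heatmaps
                inner_channels=128)
            neck.init_weights()

            num_frames, batch_size = 3, 2
            # all frames stacked along the batch dimension; the reference
            # frame occupies the first `batch_size` entries
            feats = [torch.randn(num_frames * batch_size, 48, 64, 48)]
            frame_weight = [0.4, 0.3, 0.3]

            heatmaps = neck(feats, frame_weight)  # (batch_size, 17, 64, 48)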
""" | |
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} | |
minimum_mmcv_version = '1.3.17' | |

    def __init__(self,
                 in_channels,
                 out_channels,
                 inner_channels,
                 deform_groups=17,
                 dilations=(3, 6, 12, 18, 24),
                 trans_conv_kernel=1,
                 res_blocks_cfg=None,
                 offsets_kernel=3,
                 deform_conv_kernel=3,
                 in_index=0,
                 input_transform=None,
                 freeze_trans_layer=True,
                 norm_eval=False,
                 im2col_step=80):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inner_channels = inner_channels
        self.deform_groups = deform_groups
        self.dilations = dilations
        self.trans_conv_kernel = trans_conv_kernel
        self.res_blocks_cfg = res_blocks_cfg
        self.offsets_kernel = offsets_kernel
        self.deform_conv_kernel = deform_conv_kernel
        self.in_index = in_index
        self.input_transform = input_transform
        self.freeze_trans_layer = freeze_trans_layer
        self.norm_eval = norm_eval
        self.im2col_step = im2col_step
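
        # the transition layer turns backbone features into `out_channels`
        # heatmap channels; a kernel size of 0 stands for no transition layer
        # (identity mapping)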
        identity_trans_layer = False

        assert trans_conv_kernel in [0, 1, 3]
        kernel_size = trans_conv_kernel
        if kernel_size == 3:
            padding = 1
        elif kernel_size == 1:
            padding = 0
        else:
            # 0 for Identity mapping.
            identity_trans_layer = True

        if identity_trans_layer:
            self.trans_layer = nn.Identity()
        else:
            self.trans_layer = build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding)

        # build chain of residual blocks
        if res_blocks_cfg is not None and not isinstance(res_blocks_cfg, dict):
            raise TypeError('res_blocks_cfg should be dict or None.')

        if res_blocks_cfg is None:
            block_type = 'BASIC'
            num_blocks = 20
        else:
            block_type = res_blocks_cfg.get('block', 'BASIC')
            num_blocks = res_blocks_cfg.get('num_blocks', 20)

        block = self.blocks_dict[block_type]

        res_layers = []
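        # the first residual block changes the channel number from
        # `out_channels` to `inner_channels`, so its shortcut needs a
        # 1x1 conv + BN downsample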
        downsample = nn.Sequential(
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=out_channels,
                out_channels=inner_channels,
                kernel_size=1,
                stride=1,
                bias=False),
            build_norm_layer(dict(type='BN'), inner_channels)[1])
        res_layers.append(
            block(
                in_channels=out_channels,
                out_channels=inner_channels,
                downsample=downsample))

        for _ in range(1, num_blocks):
            res_layers.append(block(inner_channels, inner_channels))
        self.offset_feats = nn.Sequential(*res_layers)

        # build offset layers
        self.num_offset_layers = len(dilations)
        assert self.num_offset_layers > 0, 'Number of offset layers ' \
            'should be larger than 0.'
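
        # each offset conv predicts a 2-D offset for every position of the
        # offsets_kernel x offsets_kernel sampling grid in every deformable
        # group, hence 2 * kernel^2 * deform_groups output channels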
        target_offset_channels = 2 * offsets_kernel**2 * deform_groups

        offset_layers = [
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=inner_channels,
                out_channels=target_offset_channels,
                kernel_size=offsets_kernel,
                stride=1,
                dilation=dilations[i],
                padding=dilations[i],
                bias=False,
            ) for i in range(self.num_offset_layers)
        ]
        self.offset_layers = nn.ModuleList(offset_layers)

        # build deformable conv layers
        assert digit_version(mmcv.__version__) >= \
            digit_version(self.minimum_mmcv_version), \
            f'Current MMCV version: {mmcv.__version__}, ' \
            f'but MMCV >= {self.minimum_mmcv_version} is required; see ' \
            'https://github.com/open-mmlab/mmcv/issues/1440. ' \
            'Please install the latest MMCV.'
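
        # pair each offset layer with a deformable conv; the padding is scaled
        # by the dilation so the spatial size of the feature map is preserved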
        if has_mmcv_full:
            deform_conv_layers = [
                DeformConv2d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=deform_conv_kernel,
                    stride=1,
                    padding=int(deform_conv_kernel / 2) * dilations[i],
                    dilation=dilations[i],
                    deform_groups=deform_groups,
                    im2col_step=self.im2col_step,
                ) for i in range(self.num_offset_layers)
            ]
        else:
            raise ImportError('Please install the full version of mmcv '
                              'to use `DeformConv2d`.')

        self.deform_conv_layers = nn.ModuleList(deform_conv_layers)

        self.freeze_layers()

    def freeze_layers(self):
        if self.freeze_trans_layer:
            self.trans_layer.eval()

            for param in self.trans_layer.parameters():
                param.requires_grad = False

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)
            elif isinstance(m, DeformConv2d):
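                # initialize each deformable conv as an identity mapping:
                # output channel k copies input channel k at the kernel
                # center, all other weights are zero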
                filler = torch.zeros([
                    m.weight.size(0),
                    m.weight.size(1),
                    m.weight.size(2),
                    m.weight.size(3)
                ],
                                     dtype=torch.float32,
                                     device=m.weight.device)
                for k in range(m.weight.size(0)):
                    filler[k, k,
                           int(m.weight.size(2) / 2),
                           int(m.weight.size(3) / 2)] = 1.0
                m.weight = torch.nn.Parameter(filler)
                m.weight.requires_grad = True

        # posewarper offset layer weight initialization
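        # with zero offsets and the identity kernels above, every deformable
        # conv initially passes its input through unchanged, so the neck
        # starts out as a plain frame-weighted average of per-frame heatmaps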
        for m in self.offset_layers.modules():
            constant_init(m, 0)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor] | Tensor): multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """
        if not isinstance(inputs, list):
            return inputs

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs, frame_weight):
        assert isinstance(inputs, (list, tuple)), 'PoseWarperNeck inputs ' \
            'should be list or tuple, even if the length is 1, ' \
            'for unified processing.'
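
        # Two input layouts are supported:
        #   1. len(inputs) > 1: one feature map (or multi-level feature list)
        #      per frame; every frame is warped towards inputs[0], which
        #      serves as the reference frame.
        #   2. len(inputs) == 1: the features of all frames are stacked along
        #      the batch dimension, with the reference frame occupying the
        #      first `batch_size` entries.
        # In both cases the warped heatmaps are fused by a weighted sum over
        # frames using `frame_weight`.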
        output_heatmap = 0
        if len(inputs) > 1:
            inputs = [self._transform_inputs(input) for input in inputs]
            inputs = [self.trans_layer(input) for input in inputs]

            # calculate difference features
            diff_features = [
                self.offset_feats(inputs[0] - input) for input in inputs
            ]

            for i in range(len(inputs)):
                if frame_weight[i] == 0:
                    continue
                warped_heatmap = 0
                for j in range(self.num_offset_layers):
                    offset = self.offset_layers[j](diff_features[i])
                    warped_heatmap_tmp = self.deform_conv_layers[j](inputs[i],
                                                                    offset)
                    warped_heatmap += warped_heatmap_tmp / \
                        self.num_offset_layers

                output_heatmap += warped_heatmap * frame_weight[i]
        else:
            inputs = inputs[0]
            inputs = self._transform_inputs(inputs)
            inputs = self.trans_layer(inputs)

            num_frames = len(frame_weight)
            batch_size = inputs.size(0) // num_frames
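            # the first `batch_size` entries hold the reference frame; tile it
            # so the differences to all frames can be computed in one pass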
            ref_x = inputs[:batch_size]
            ref_x_tiled = ref_x.repeat(num_frames, 1, 1, 1)

            offset_features = self.offset_feats(ref_x_tiled - inputs)

            warped_heatmap = 0
            for j in range(self.num_offset_layers):
                offset = self.offset_layers[j](offset_features)

                warped_heatmap_tmp = self.deform_conv_layers[j](inputs, offset)
                warped_heatmap += warped_heatmap_tmp / self.num_offset_layers

            for i in range(num_frames):
                if frame_weight[i] == 0:
                    continue
                output_heatmap += warped_heatmap[i * batch_size:(i + 1) *
                                                 batch_size] * frame_weight[i]

        return output_heatmap

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        self.freeze_layers()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()