# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init,
                      normal_init)

from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class RSB(nn.Module):
    """Residual Steps block for RSN. Paper ref: Cai et al. "Learning Delicate
    Local Representations for Multi-Person Pose Estimation" (ECCV 2020).

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        num_steps (int): Number of steps in RSB. Default: 4
        stride (int): Stride of the block. Default: 1
        downsample (nn.Module): Downsample operation on the identity branch.
            Default: None
        with_cp (bool): Use checkpoint or not (kept for interface
            consistency; this block does not checkpoint). Default: False
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        expand_times (int): Times by which the in_channels are expanded.
            Default: 26.
        res_top_channels (int): Number of channels of feature output by
            ResNet_top. Default: 64.
    """

    expansion = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_steps=4,
                 stride=1,
                 downsample=None,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 expand_times=26,
                 res_top_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        assert num_steps > 1
        self.in_channels = in_channels
        self.branch_channels = self.in_channels * expand_times
        self.branch_channels //= res_top_channels
        self.out_channels = out_channels
        self.stride = stride
        self.downsample = downsample
        self.with_cp = with_cp
        self.norm_cfg = norm_cfg
        self.num_steps = num_steps
        self.conv_bn_relu1 = ConvModule(
            self.in_channels,
            self.num_steps * self.branch_channels,
            kernel_size=1,
            stride=self.stride,
            padding=0,
            norm_cfg=self.norm_cfg,
            inplace=False)
        for i in range(self.num_steps):
            for j in range(i + 1):
                module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
                self.add_module(
                    module_name,
                    ConvModule(
                        self.branch_channels,
                        self.branch_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        inplace=False))
        self.conv_bn3 = ConvModule(
            self.num_steps * self.branch_channels,
            self.out_channels * self.expansion,
            kernel_size=1,
            stride=1,
            padding=0,
            act_cfg=None,
            norm_cfg=self.norm_cfg,
            inplace=False)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Forward function."""
        identity = x
        x = self.conv_bn_relu1(x)
        # Split the expanded feature into ``num_steps`` branches.
        spx = torch.split(x, self.branch_channels, 1)
        outputs = list()
        outs = list()
        for i in range(self.num_steps):
            outputs_i = list()
            outputs.append(outputs_i)
            for j in range(i + 1):
                if j == 0:
                    inputs = spx[i]
                else:
                    inputs = outputs[i][j - 1]
                # Fuse the same-step output of the previous branch, so
                # deeper branches see gradually larger receptive fields.
                if i > j:
                    inputs = inputs + outputs[i - 1][j]
                module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
                module_i_j = getattr(self, module_name)
                outputs[i].append(module_i_j(inputs))
            # Keep the last output of each branch.
            outs.append(outputs[i][i])
        out = torch.cat(tuple(outs), 1)
        out = self.conv_bn3(out)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out = out + identity
        out = self.relu(out)

        return out


class Downsample_module(nn.Module):
    """Downsample module for RSN.

    Args:
        block (nn.Module): Downsample block.
        num_blocks (list): Number of blocks in each downsample unit.
        num_steps (int): Number of steps in a block. Default: 4
        num_units (int): Number of downsample units. Default: 4
        has_skip (bool): Whether to have skip connections from the prior
            upsample module. Default: False
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        in_channels (int): Number of channels of the input feature to
            downsample module. Default: 64
        expand_times (int): Times by which the in_channels are expanded.
            Default: 26.
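
    Example:
        >>> # A minimal sketch, assuming the default 64-channel feature
        >>> # from ResNet_top; the second of the two units halves the
        >>> # resolution and doubles the channels.
        >>> import torch
        >>> downsample = Downsample_module(RSB, [2, 2], num_units=2)
        >>> feat = torch.rand(1, 64, 64, 64)
        >>> outs = downsample(feat, None, None)
        >>> for out in outs:  # ordered from the deepest unit up
        ...     print(tuple(out.shape))
        (1, 128, 32, 32)
        (1, 64, 64, 64)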
    """

    def __init__(self,
                 block,
                 num_blocks,
                 num_steps=4,
                 num_units=4,
                 has_skip=False,
                 norm_cfg=dict(type='BN'),
                 in_channels=64,
                 expand_times=26):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.has_skip = has_skip
        self.in_channels = in_channels
        assert len(num_blocks) == num_units
        self.num_blocks = num_blocks
        self.num_units = num_units
        self.num_steps = num_steps
        self.norm_cfg = norm_cfg
        self.layer1 = self._make_layer(
            block,
            in_channels,
            num_blocks[0],
            expand_times=expand_times,
            res_top_channels=in_channels)
        for i in range(1, num_units):
            module_name = f'layer{i + 1}'
            self.add_module(
                module_name,
                self._make_layer(
                    block,
                    in_channels * pow(2, i),
                    num_blocks[i],
                    stride=2,
                    expand_times=expand_times,
                    res_top_channels=in_channels))

    def _make_layer(self,
                    block,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_times=26,
                    res_top_channels=64):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = ConvModule(
                self.in_channels,
                out_channels * block.expansion,
                kernel_size=1,
                stride=stride,
                padding=0,
                norm_cfg=self.norm_cfg,
                act_cfg=None,
                inplace=True)

        units = list()
        units.append(
            block(
                self.in_channels,
                out_channels,
                num_steps=self.num_steps,
                stride=stride,
                downsample=downsample,
                norm_cfg=self.norm_cfg,
                expand_times=expand_times,
                res_top_channels=res_top_channels))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            units.append(
                block(
                    self.in_channels,
                    out_channels,
                    num_steps=self.num_steps,
                    expand_times=expand_times,
                    res_top_channels=res_top_channels))

        return nn.Sequential(*units)

    def forward(self, x, skip1, skip2):
        out = list()
        for i in range(self.num_units):
            module_name = f'layer{i + 1}'
            module_i = getattr(self, module_name)
            x = module_i(x)
            if self.has_skip:
                x = x + skip1[i] + skip2[i]
            out.append(x)
        # Deepest feature first, matching the order Upsample_module expects.
        out.reverse()

        return tuple(out)


class Upsample_unit(nn.Module):
    """Upsample unit for upsample module.

    Args:
        ind (int): Index of the unit. Units with ind > 0 interpolate and
            fuse the output of the previous unit.
        num_units (int): Number of units that form an upsample module.
            Along with ind and gen_cross_conv, num_units is used to decide
            whether to generate a feature map for the next hourglass-like
            module.
        in_channels (int): Channel number of the skip-in feature maps from
            the corresponding downsample unit.
        unit_channels (int): Channel number in this unit. Default: 256.
        gen_skip (bool): Whether to generate skips for the posterior
            downsample module. Default: False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default: False
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of feature output by upsample
            module. Must equal in_channels of downsample module. Default: 64
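
    Example:
        >>> # A minimal sketch of the deepest unit (ind=0), assuming
        >>> # num_units=4 and a 512-channel skip-in feature (64 * 2**3).
        >>> import torch
        >>> unit = Upsample_unit(0, 4, 512, unit_channels=256)
        >>> skip_feat = torch.rand(1, 512, 8, 8)
        >>> out, skip1, skip2, cross_conv = unit(skip_feat, None)
        >>> tuple(out.shape)
        (1, 256, 8, 8)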
    """

    def __init__(self,
                 ind,
                 num_units,
                 in_channels,
                 unit_channels=256,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.num_units = num_units
        self.norm_cfg = norm_cfg
        self.in_skip = ConvModule(
            in_channels,
            unit_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            norm_cfg=self.norm_cfg,
            act_cfg=None,
            inplace=True)
        self.relu = nn.ReLU(inplace=True)

        self.ind = ind
        if self.ind > 0:
            self.up_conv = ConvModule(
                unit_channels,
                unit_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                act_cfg=None,
                inplace=True)

        self.gen_skip = gen_skip
        if self.gen_skip:
            self.out_skip1 = ConvModule(
                in_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True)

            self.out_skip2 = ConvModule(
                unit_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True)

        self.gen_cross_conv = gen_cross_conv
        if self.ind == num_units - 1 and self.gen_cross_conv:
            self.cross_conv = ConvModule(
                unit_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True)

    def forward(self, x, up_x):
        out = self.in_skip(x)

        if self.ind > 0:
            # Upsample the previous unit's output to this unit's resolution
            # before fusing it with the skip-in feature.
            up_x = F.interpolate(
                up_x,
                size=(x.size(2), x.size(3)),
                mode='bilinear',
                align_corners=True)
            up_x = self.up_conv(up_x)
            out = out + up_x
        out = self.relu(out)

        skip1 = None
        skip2 = None
        if self.gen_skip:
            skip1 = self.out_skip1(x)
            skip2 = self.out_skip2(out)

        cross_conv = None
        if self.ind == self.num_units - 1 and self.gen_cross_conv:
            cross_conv = self.cross_conv(out)

        return out, skip1, skip2, cross_conv


class Upsample_module(nn.Module):
    """Upsample module for RSN.

    Args:
        unit_channels (int): Channel number in the upsample units.
            Default: 256.
        num_units (int): Number of upsample units. Default: 4
        gen_skip (bool): Whether to generate skips for the posterior
            downsample module. Default: False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default: False
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of feature output by upsample
            module. Must equal in_channels of downsample module. Default: 64
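
    Example:
        >>> # A minimal sketch, fed with the (deepest-first) outputs of a
        >>> # matching two-unit Downsample_module.
        >>> import torch
        >>> upsample = Upsample_module(num_units=2)
        >>> feats = (torch.rand(1, 128, 32, 32), torch.rand(1, 64, 64, 64))
        >>> out, skip1, skip2, cross_conv = upsample(feats)
        >>> for o in out:
        ...     print(tuple(o.shape))
        (1, 256, 32, 32)
        (1, 256, 64, 64)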
    """

    def __init__(self,
                 unit_channels=256,
                 num_units=4,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = list()
        for i in range(num_units):
            self.in_channels.append(RSB.expansion * out_channels * pow(2, i))
        # Skip-in channels, ordered from the deepest unit to the shallowest.
        self.in_channels.reverse()
        self.num_units = num_units
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.norm_cfg = norm_cfg
        for i in range(num_units):
            module_name = f'up{i + 1}'
            self.add_module(
                module_name,
                Upsample_unit(
                    i,
                    self.num_units,
                    self.in_channels[i],
                    unit_channels,
                    self.gen_skip,
                    self.gen_cross_conv,
                    norm_cfg=self.norm_cfg,
                    # Pass out_channels through (instead of a hard-coded 64)
                    # so the cross conv matches the next stage's input.
                    out_channels=out_channels))

    def forward(self, x):
        out = list()
        skip1 = list()
        skip2 = list()
        cross_conv = None
        for i in range(self.num_units):
            module_i = getattr(self, f'up{i + 1}')
            if i == 0:
                outi, skip1_i, skip2_i, _ = module_i(x[i], None)
            elif i == self.num_units - 1:
                outi, skip1_i, skip2_i, cross_conv = module_i(
                    x[i], out[i - 1])
            else:
                outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1])
            out.append(outi)
            skip1.append(skip1_i)
            skip2.append(skip2_i)
        skip1.reverse()
        skip2.reverse()

        return out, skip1, skip2, cross_conv


class Single_stage_RSN(nn.Module):
    """Single_stage Residual Steps Network.

    Args:
        has_skip (bool): Whether to have skip connections from the prior
            upsample module. Default: False
        gen_skip (bool): Whether to generate skips for the posterior
            downsample module. Default: False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default: False
        unit_channels (int): Channel number in the upsample units.
            Default: 256.
        num_units (int): Number of downsample/upsample units. Default: 4
        num_steps (int): Number of steps in RSB. Default: 4
        num_blocks (list): Number of blocks in each downsample unit.
            Default: [2, 2, 2, 2]
            Note: Make sure num_units == len(num_blocks)
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        in_channels (int): Number of channels of the feature from ResNet_top.
            Default: 64.
        expand_times (int): Times by which the in_channels are expanded in
            RSB. Default: 26.
    """

    def __init__(self,
                 has_skip=False,
                 gen_skip=False,
                 gen_cross_conv=False,
                 unit_channels=256,
                 num_units=4,
                 num_steps=4,
                 num_blocks=[2, 2, 2, 2],
                 norm_cfg=dict(type='BN'),
                 in_channels=64,
                 expand_times=26):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__()
        assert len(num_blocks) == num_units
        self.has_skip = has_skip
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.num_units = num_units
        self.num_steps = num_steps
        self.unit_channels = unit_channels
        self.num_blocks = num_blocks
        self.norm_cfg = norm_cfg

        self.downsample = Downsample_module(RSB, num_blocks, num_steps,
                                            num_units, has_skip, norm_cfg,
                                            in_channels, expand_times)
        self.upsample = Upsample_module(unit_channels, num_units, gen_skip,
                                        gen_cross_conv, norm_cfg, in_channels)

    def forward(self, x, skip1, skip2):
        mid = self.downsample(x, skip1, skip2)
        out, skip1, skip2, cross_conv = self.upsample(mid)

        return out, skip1, skip2, cross_conv


class ResNet_top(nn.Module):
    """ResNet top for RSN.

    Args:
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN')
        channels (int): Number of channels of the feature output by
            ResNet_top. Default: 64
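
    Example:
        >>> # A minimal sketch: the stem reduces the input resolution by 4x
        >>> # (stride-2 conv followed by a stride-2 max pool).
        >>> import torch
        >>> top = ResNet_top()
        >>> img = torch.rand(1, 3, 256, 192)
        >>> tuple(top(img).shape)
        (1, 64, 64, 48)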
""" def __init__(self, norm_cfg=dict(type='BN'), channels=64): # Protect mutable default arguments norm_cfg = cp.deepcopy(norm_cfg) super().__init__() self.top = nn.Sequential( ConvModule( 3, channels, kernel_size=7, stride=2, padding=3, norm_cfg=norm_cfg, inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1)) def forward(self, img): return self.top(img) @BACKBONES.register_module() class RSN(BaseBackbone): """Residual Steps Network backbone. Paper ref: Cai et al. "Learning Delicate Local Representations for Multi-Person Pose Estimation" (ECCV 2020). Args: unit_channels (int): Number of Channels in an upsample unit. Default: 256 num_stages (int): Number of stages in a multi-stage RSN. Default: 4 num_units (int): NUmber of downsample/upsample units in a single-stage RSN. Default: 4 Note: Make sure num_units == len(self.num_blocks) num_blocks (list): Number of RSBs (Residual Steps Block) in each downsample unit. Default: [2, 2, 2, 2] num_steps (int): Number of steps in a RSB. Default:4 norm_cfg (dict): dictionary to construct and config norm layer. Default: dict(type='BN') res_top_channels (int): Number of channels of feature from ResNet_top. Default: 64. expand_times (int): Times by which the in_channels are expanded in RSB. Default:26. Example: >>> from mmpose.models import RSN >>> import torch >>> self = RSN(num_stages=2,num_units=2,num_blocks=[2,2]) >>> self.eval() >>> inputs = torch.rand(1, 3, 511, 511) >>> level_outputs = self.forward(inputs) >>> for level_output in level_outputs: ... for feature in level_output: ... print(tuple(feature.shape)) ... (1, 256, 64, 64) (1, 256, 128, 128) (1, 256, 64, 64) (1, 256, 128, 128) """ def __init__(self, unit_channels=256, num_stages=4, num_units=4, num_blocks=[2, 2, 2, 2], num_steps=4, norm_cfg=dict(type='BN'), res_top_channels=64, expand_times=26): # Protect mutable default arguments norm_cfg = cp.deepcopy(norm_cfg) num_blocks = cp.deepcopy(num_blocks) super().__init__() self.unit_channels = unit_channels self.num_stages = num_stages self.num_units = num_units self.num_blocks = num_blocks self.num_steps = num_steps self.norm_cfg = norm_cfg assert self.num_stages > 0 assert self.num_steps > 1 assert self.num_units > 1 assert self.num_units == len(self.num_blocks) self.top = ResNet_top(norm_cfg=norm_cfg) self.multi_stage_rsn = nn.ModuleList([]) for i in range(self.num_stages): if i == 0: has_skip = False else: has_skip = True if i != self.num_stages - 1: gen_skip = True gen_cross_conv = True else: gen_skip = False gen_cross_conv = False self.multi_stage_rsn.append( Single_stage_RSN(has_skip, gen_skip, gen_cross_conv, unit_channels, num_units, num_steps, num_blocks, norm_cfg, res_top_channels, expand_times)) def forward(self, x): """Model forward function.""" out_feats = [] skip1 = None skip2 = None x = self.top(x) for i in range(self.num_stages): out, skip1, skip2, x = self.multi_stage_rsn[i](x, skip1, skip2) out_feats.append(out) return out_feats def init_weights(self, pretrained=None): """Initialize model weights.""" for m in self.multi_stage_rsn.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) elif isinstance(m, nn.BatchNorm2d): constant_init(m, 1) elif isinstance(m, nn.Linear): normal_init(m, std=0.01) for m in self.top.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m)