# ------------------------------------------------------------------------------ | |
# Adapted from https://github.com/HRNet/Lite-HRNet | |
# Original licence: Apache License 2.0. | |
# ------------------------------------------------------------------------------ | |
import mmcv | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.utils.checkpoint as cp | |
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, | |
build_conv_layer, build_norm_layer, constant_init, | |
normal_init) | |
from torch.nn.modules.batchnorm import _BatchNorm | |
from mmpose.utils import get_root_logger | |
from ..builder import BACKBONES | |
from .utils import channel_shuffle, load_checkpoint | |
class SpatialWeighting(nn.Module):
    """Spatial weighting module.

    Squeeze-and-excitation style attention: globally average-pool the input,
    pass it through two 1x1 convolutions, and rescale the input by the result.

    Args:
        channels (int): The channels of the module.
        ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        # A single act_cfg dict is shared by both conv layers.
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        reduced_channels = int(channels / ratio)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=reduced_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=reduced_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        """Rescale ``x`` by its channel-wise attention weights."""
        weight = self.conv2(self.conv1(self.global_avgpool(x)))
        return x * weight
class CrossResolutionWeighting(nn.Module):
    """Cross-resolution channel weighting module.

    Pools every branch to the smallest branch's spatial size, computes joint
    channel weights with two 1x1 convolutions, then upsamples each branch's
    weight slice back and applies it multiplicatively.

    Args:
        channels (int): The channels of the module.
        ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        # A single act_cfg dict is shared by both conv layers.
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        self.channels = channels
        total_channel = sum(channels)
        squeezed_channel = int(total_channel / ratio)
        self.conv1 = ConvModule(
            in_channels=total_channel,
            out_channels=squeezed_channel,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=squeezed_channel,
            out_channels=total_channel,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        """Apply cross-resolution weights to the list of branch tensors."""
        smallest_size = x[-1].size()[-2:]
        # Bring every higher-resolution branch down to the smallest size.
        pooled = [F.adaptive_avg_pool2d(s, smallest_size) for s in x[:-1]]
        pooled.append(x[-1])

        weights = self.conv2(self.conv1(torch.cat(pooled, dim=1)))
        # Slice the joint weights back into per-branch chunks.
        weights = torch.split(weights, self.channels, dim=1)
        return [
            branch * F.interpolate(w, size=branch.size()[-2:], mode='nearest')
            for branch, w in zip(x, weights)
        ]
class ConditionalChannelWeighting(nn.Module):
    """Conditional channel weighting block.

    Each branch is split channel-wise in half; one half is kept untouched
    while the other half goes through cross-resolution weighting, a depthwise
    3x3 conv, and spatial (SE-style) weighting. The halves are then
    concatenated and channel-shuffled.

    Args:
        in_channels (int): The input channels of the block.
        stride (int): Stride of the 3x3 convolution layer.
        reduce_ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 stride,
                 reduce_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.stride = stride
        assert stride in [1, 2]

        # Only half of each branch's channels pass through the weighting path.
        branch_channels = [channel // 2 for channel in in_channels]

        self.cross_resolution_weighting = CrossResolutionWeighting(
            branch_channels,
            ratio=reduce_ratio,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)
        # One depthwise 3x3 conv per branch, no activation.
        self.depthwise_convs = nn.ModuleList([
            ConvModule(
                channel,
                channel,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=channel,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None) for channel in branch_channels
        ])
        # Per-branch spatial weighting with a fixed reduction ratio of 4.
        self.spatial_weighting = nn.ModuleList([
            SpatialWeighting(channels=channel, ratio=4)
            for channel in branch_channels
        ])

    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor]): Multi-resolution branch features.

        Returns:
            list[Tensor]: Re-weighted branch features, same layout as input.
        """

        def _inner_forward(x):
            # Split every branch into an identity half and a transform half.
            x = [s.chunk(2, dim=1) for s in x]
            x1 = [s[0] for s in x]
            x2 = [s[1] for s in x]

            x2 = self.cross_resolution_weighting(x2)
            x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
            x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]

            # Re-join halves and shuffle to mix channels across groups.
            out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
            out = [channel_shuffle(s, 2) for s in out]
            return out

        # NOTE(review): `x` is a list of tensors here, so `x.requires_grad`
        # looks like it would raise if with_cp=True — confirm before enabling
        # checkpointing for this block.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class Stem(nn.Module):
    """Stem network block.

    Reduces spatial resolution by a total factor of 4: a stride-2 3x3 conv
    followed by a ShuffleNetV2-style split/transform/concat stage that is
    also stride 2.

    Args:
        in_channels (int): The input channels of the block.
        stem_channels (int): Output channels of the stem layer.
        out_channels (int): The output channels of the block.
        expand_ratio (int): adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 stem_channels,
                 out_channels,
                 expand_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp

        # First stride-2 conv: in_channels -> stem_channels.
        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='ReLU'))

        mid_channels = int(round(stem_channels * expand_ratio))
        branch_channels = stem_channels // 2
        # Choose branch1's output width so that concat(branch1(x1), x2')
        # yields exactly `out_channels` channels in both configurations.
        if stem_channels == self.out_channels:
            inc_channels = self.out_channels - branch_channels
        else:
            inc_channels = self.out_channels - stem_channels

        # Shortcut branch: depthwise stride-2 3x3 then pointwise 1x1.
        self.branch1 = nn.Sequential(
            ConvModule(
                branch_channels,
                branch_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=branch_channels,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_channels,
                inc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU')),
        )

        # Main branch: expand 1x1 -> depthwise stride-2 3x3 -> linear 1x1.
        self.expand_conv = ConvModule(
            branch_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        self.depthwise_conv = ConvModule(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.linear_conv = ConvModule(
            mid_channels,
            branch_channels
            if stem_channels == self.out_channels else stem_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input image tensor.

        Returns:
            Tensor: Feature map with `out_channels` channels at 1/4 the
            input resolution.
        """

        def _inner_forward(x):
            x = self.conv1(x)
            # Split channels: x1 goes through the shortcut branch, x2
            # through expand/depthwise/linear.
            x1, x2 = x.chunk(2, dim=1)

            x2 = self.expand_conv(x2)
            x2 = self.depthwise_conv(x2)
            x2 = self.linear_conv(x2)

            out = torch.cat((self.branch1(x1), x2), dim=1)
            out = channel_shuffle(out, 2)
            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class IterativeHead(nn.Module):
    """Extra iterative head for feature learning.

    Refines branches from lowest to highest resolution: each branch is summed
    with the upsampled result of the previous (lower-resolution) branch and
    then projected with a depthwise-separable conv.

    Args:
        in_channels (int): The input channels of the block.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self, in_channels, norm_cfg=dict(type='BN')):
        super().__init__()
        num_branches = len(in_channels)
        # Branches are processed lowest-resolution first.
        self.in_channels = in_channels[::-1]

        projects = []
        for i in range(num_branches):
            # Every projection except the last maps to the next branch's
            # width, so its output can be added to that branch.
            if i == num_branches - 1:
                out_channels = self.in_channels[i]
            else:
                out_channels = self.in_channels[i + 1]
            projects.append(
                DepthwiseSeparableConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    dw_act_cfg=None,
                    pw_act_cfg=dict(type='ReLU')))
        self.projects = nn.ModuleList(projects)

    def forward(self, x):
        """Refine branch features; input/output are highest-resolution first."""
        x = x[::-1]

        outputs = []
        previous = None
        for i, feat in enumerate(x):
            if previous is not None:
                previous = F.interpolate(
                    previous,
                    size=feat.size()[-2:],
                    mode='bilinear',
                    align_corners=True)
                feat = feat + previous
            feat = self.projects[i](feat)
            outputs.append(feat)
            previous = feat

        return outputs[::-1]
class ShuffleUnit(nn.Module):
    """InvertedResidual block for ShuffleNetV2 backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super().__init__()
        self.stride = stride
        self.with_cp = with_cp

        branch_features = out_channels // 2
        # With stride 1 the input is split channel-wise and only one half is
        # transformed, so the input width must equal the concatenated output
        # width. (A second assertion covering the stride > 1 case in the
        # original code was dead: whenever it executed, `stride != 1`
        # already held, so it could never fail and has been removed.)
        if self.stride == 1:
            assert in_channels == branch_features * 2, (
                f'in_channels ({in_channels}) should equal to '
                f'branch_features * 2 ({branch_features * 2}) '
                'when stride is 1')

        if self.stride > 1:
            # Downsampling shortcut branch: depthwise 3x3 (stride 2) then
            # pointwise 1x1 with activation.
            self.branch1 = nn.Sequential(
                ConvModule(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=self.stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                ConvModule(
                    in_channels,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            )

        # Main branch: pointwise 1x1 -> depthwise 3x3 -> pointwise 1x1.
        self.branch2 = nn.Sequential(
            ConvModule(
                in_channels if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=branch_features,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            if self.stride > 1:
                # Both branches downsample; concatenate their outputs.
                out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
            else:
                # Identity half + transformed half.
                x1, x2 = x.chunk(2, dim=1)
                out = torch.cat((x1, self.branch2(x2)), dim=1)

            out = channel_shuffle(out, 2)
            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class LiteHRModule(nn.Module):
    """High-Resolution Module for LiteHRNet.

    It contains conditional channel weighting blocks and
    shuffle blocks.

    Args:
        num_branches (int): Number of branches in the module.
        num_blocks (int): Number of blocks in the module.
        in_channels (list(int)): Number of input image channels.
        reduce_ratio (int): Channel reduction ratio.
        module_type (str): 'LITE' or 'NAIVE'
        multiscale_output (bool): Whether to output multi-scale features.
        with_fuse (bool): Whether to use fuse layers.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    def __init__(
        self,
        num_branches,
        num_blocks,
        in_channels,
        reduce_ratio,
        module_type,
        multiscale_output=False,
        with_fuse=True,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        with_cp=False,
    ):
        super().__init__()
        self._check_branches(num_branches, in_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches
        self.module_type = module_type
        self.multiscale_output = multiscale_output
        self.with_fuse = with_fuse
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp

        # 'LITE' uses conditional channel weighting across branches;
        # 'NAIVE' uses independent per-branch shuffle units.
        if self.module_type.upper() == 'LITE':
            self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio)
        elif self.module_type.upper() == 'NAIVE':
            self.layers = self._make_naive_branches(num_branches, num_blocks)
        else:
            raise ValueError("module_type should be either 'LITE' or 'NAIVE'.")
        if self.with_fuse:
            self.fuse_layers = self._make_fuse_layers()
            self.relu = nn.ReLU()

    def _check_branches(self, num_branches, in_channels):
        """Check input to avoid ValueError."""
        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                f'!= NUM_INCHANNELS({len(in_channels)})'
            raise ValueError(error_msg)

    def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1):
        """Make channel weighting blocks."""
        layers = []
        for i in range(num_blocks):
            layers.append(
                ConditionalChannelWeighting(
                    self.in_channels,
                    stride=stride,
                    reduce_ratio=reduce_ratio,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))
        return nn.Sequential(*layers)

    def _make_one_branch(self, branch_index, num_blocks, stride=1):
        """Make one branch."""
        layers = []
        # First block may downsample (stride), the rest keep resolution.
        layers.append(
            ShuffleUnit(
                self.in_channels[branch_index],
                self.in_channels[branch_index],
                stride=stride,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type='ReLU'),
                with_cp=self.with_cp))
        for i in range(1, num_blocks):
            layers.append(
                ShuffleUnit(
                    self.in_channels[branch_index],
                    self.in_channels[branch_index],
                    stride=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    with_cp=self.with_cp))
        return nn.Sequential(*layers)

    def _make_naive_branches(self, num_branches, num_blocks):
        """Make branches."""
        branches = []
        for i in range(num_branches):
            branches.append(self._make_one_branch(i, num_blocks))
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Make fuse layer.

        Builds fuse_layers[i][j]: the transform applied to branch j's output
        before it is summed into branch i. j > i is upsampling (1x1 conv +
        nearest upsample), j == i is identity (None), j < i is a chain of
        stride-2 depthwise+pointwise downsample stages.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution branch -> upsample by 2**(j - i).
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Higher-resolution branch -> i - j downsample stages.
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            # Final stage changes channels to in_channels[i]
                            # and has no trailing activation.
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            # Intermediate stages keep channels and end
                            # with ReLU.
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function."""
        if self.num_branches == 1:
            return [self.layers[0](x[0])]

        if self.module_type.upper() == 'LITE':
            out = self.layers(x)
        elif self.module_type.upper() == 'NAIVE':
            for i in range(self.num_branches):
                x[i] = self.layers[i](x[i])
            out = x

        if self.with_fuse:
            out_fuse = []
            for i in range(len(self.fuse_layers)):
                # `y = 0` will lead to decreased accuracy (0.5~1 mAP)
                y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
                for j in range(self.num_branches):
                    if i == j:
                        y += out[j]
                    else:
                        y += self.fuse_layers[i][j](out[j])
                out_fuse.append(self.relu(y))
            out = out_fuse
        if not self.multiscale_output:
            out = [out[0]]
        return out
class LiteHRNet(nn.Module):
    """Lite-HRNet backbone.

    `Lite-HRNet: A Lightweight High-Resolution Network
    <https://arxiv.org/abs/2104.06403>`_.

    Code adapted from 'https://github.com/HRNet/Lite-HRNet'.

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.

    Example:
        >>> from mmpose.models import LiteHRNet
        >>> import torch
        >>> extra=dict(
        >>>    stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
        >>>    num_stages=3,
        >>>    stages_spec=dict(
        >>>        num_modules=(2, 4, 2),
        >>>        num_branches=(2, 3, 4),
        >>>        num_blocks=(2, 2, 2),
        >>>        module_type=('LITE', 'LITE', 'LITE'),
        >>>        with_fuse=(True, True, True),
        >>>        reduce_ratios=(8, 8, 8),
        >>>        num_channels=(
        >>>            (40, 80),
        >>>            (40, 80, 160),
        >>>            (40, 80, 160, 320),
        >>>        )),
        >>>    with_head=False)
        >>> self = LiteHRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 40, 8, 8)
    """

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=False,
                 with_cp=False):
        super().__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.stem = Stem(
            in_channels,
            stem_channels=self.extra['stem']['stem_channels'],
            out_channels=self.extra['stem']['out_channels'],
            expand_ratio=self.extra['stem']['expand_ratio'],
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

        self.num_stages = self.extra['num_stages']
        self.stages_spec = self.extra['stages_spec']

        # Channel widths produced by the previous stage (starts at the stem).
        num_channels_last = [
            self.stem.out_channels,
        ]
        for i in range(self.num_stages):
            num_channels = self.stages_spec['num_channels'][i]
            num_channels = [num_channels[i] for i in range(len(num_channels))]
            # transition{i} adapts previous-stage branches (channel change
            # and/or new downsampled branches) to this stage's widths.
            setattr(
                self, f'transition{i}',
                self._make_transition_layer(num_channels_last, num_channels))

            stage, num_channels_last = self._make_stage(
                self.stages_spec, i, num_channels, multiscale_output=True)
            setattr(self, f'stage{i}', stage)

        self.with_head = self.extra['with_head']
        if self.with_head:
            self.head_layer = IterativeHead(
                in_channels=num_channels_last,
                norm_cfg=self.norm_cfg,
            )

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer.

        Returns one entry per current branch: None for an unchanged branch,
        a depthwise+pointwise conv for a channel change, or a chain of
        stride-2 downsample stages for each newly added branch.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_pre_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                groups=num_channels_pre_layer[i],
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_pre_layer[i])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU()))
                else:
                    transition_layers.append(None)
            else:
                # New branch: downsample the last previous branch once per
                # missing level; only the final stage changes the width.
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                in_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=in_channels,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels)[1],
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU()))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_stage(self,
                    stages_spec,
                    stage_index,
                    in_channels,
                    multiscale_output=True):
        """Build one stage as a sequence of LiteHRModules.

        Returns:
            tuple: (nn.Sequential of modules, output channel list).
        """
        num_modules = stages_spec['num_modules'][stage_index]
        num_branches = stages_spec['num_branches'][stage_index]
        num_blocks = stages_spec['num_blocks'][stage_index]
        reduce_ratio = stages_spec['reduce_ratios'][stage_index]
        with_fuse = stages_spec['with_fuse'][stage_index]
        module_type = stages_spec['module_type'][stage_index]

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            modules.append(
                LiteHRModule(
                    num_branches,
                    num_blocks,
                    in_channels,
                    reduce_ratio,
                    module_type,
                    multiscale_output=reset_multiscale_output,
                    with_fuse=with_fuse,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))
            in_channels = modules[-1].in_channels

        return nn.Sequential(*modules), in_channels

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function.

        Returns:
            list[Tensor]: Single-element list containing the
            highest-resolution output feature map.
        """
        x = self.stem(x)

        y_list = [x]
        for i in range(self.num_stages):
            x_list = []
            transition = getattr(self, f'transition{i}')
            for j in range(self.stages_spec['num_branches'][i]):
                if transition[j]:
                    # New branches (index beyond previous stage) are built
                    # from the last previous branch.
                    if j >= len(y_list):
                        x_list.append(transition[j](y_list[-1]))
                    else:
                        x_list.append(transition[j](y_list[j]))
                else:
                    x_list.append(y_list[j])
            y_list = getattr(self, f'stage{i}')(x_list)

        x = y_list
        if self.with_head:
            x = self.head_layer(x)

        return [x[0]]

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        if mode and self.norm_eval:
            # Keep BatchNorm running stats frozen during training.
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()