|
|
|
import warnings |
|
|
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule |
|
from mmengine.model import BaseModule |
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
|
|
from mmseg.registry import MODELS |
|
from ..utils import InvertedResidual, make_divisible |
|
|
|
|
|
@MODELS.register_module()
class MobileNetV2(BaseModule):
    """MobileNetV2 backbone.

    This backbone is the implementation of
    `MobileNetV2: Inverted Residuals and Linear Bottlenecks
    <https://arxiv.org/abs/1801.04381>`_.

    Args:
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        strides (Sequence[int]): Stride of the first block of each of the
            layers in ``arch_settings`` (one value per layer); the remaining
            blocks of a layer use stride 1. Default: (1, 2, 2, 2, 1, 2, 1).
        dilations (Sequence[int]): Dilation of the first block of each of
            the layers in ``arch_settings`` (one value per layer); the
            remaining blocks of a layer use dilation 1.
            Default: (1, 1, 1, 1, 1, 1, 1).
        out_indices (Sequence[int]): 0-based indices of the layers whose
            outputs are returned, each in ``range(0, 7)``.
            Default: (1, 2, 4, 6).
        frozen_stages (int): Stages to be frozen (all param fixed, and the
            frozen modules are kept in eval mode). -1 freezes nothing, 0
            freezes only the stem ``conv1``, and ``k`` (1 <= k <= 7) also
            freezes ``layer1`` ... ``layerk``. Default: -1.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        pretrained (str, optional): model pretrained path. Deprecated, use
            ``init_cfg`` instead. Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    # Parameters to build layers. Each entry is
    # [expand_ratio, channel, num_blocks] for one layer (layer1..layer7).
    arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
                     [6, 96, 3], [6, 160, 3], [6, 320, 1]]

    def __init__(self,
                 widen_factor=1.,
                 strides=(1, 2, 2, 2, 1, 2, 1),
                 dilations=(1, 1, 1, 1, 1, 1, 1),
                 out_indices=(1, 2, 4, 6),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 norm_eval=False,
                 with_cp=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                # Fallback init: Kaiming for convs, constant 1 for norms.
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        # Derive the layer count from arch_settings instead of hard-coding 7.
        num_layers = len(self.arch_settings)

        self.widen_factor = widen_factor
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_layers, \
            'strides and dilations must provide one value per layer'

        for index in out_indices:
            if index not in range(0, num_layers):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, {num_layers}). '
                                 f'But received {index}')

        # frozen_stages == k freezes conv1 and layer1..layerk, so k may go up
        # to num_layers (freeze the whole backbone) inclusive.
        if frozen_stages not in range(-1, num_layers + 1):
            raise ValueError('frozen_stages must be in '
                             f'range(-1, {num_layers + 1}). '
                             f'But received {frozen_stages}')

        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # Running channel count; make_layer updates it as layers are built.
        self.in_channels = make_divisible(32 * widen_factor, 8)

        # Stem: 3x3 stride-2 conv on the 3-channel input.
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        # Names of the registered submodules, in forward order.
        self.layers = []

        for i, layer_cfg in enumerate(self.arch_settings):
            expand_ratio, channel, num_blocks = layer_cfg
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self.make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                expand_ratio=expand_ratio)
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

    def make_layer(self, out_channels, num_blocks, stride, dilation,
                   expand_ratio):
        """Stack InvertedResidual blocks to build a layer for MobileNetV2.

        Side effect: ``self.in_channels`` is set to ``out_channels`` so the
        next layer is built with the correct input width.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block (others use 1).
            dilation (int): Dilation of the first block (others use 1).
            expand_ratio (int): Expand the number of channels of the
                hidden layer in InvertedResidual by this ratio.

        Returns:
            nn.Sequential: The stacked blocks.
        """
        layers = []
        for i in range(num_blocks):
            layers.append(
                InvertedResidual(
                    self.in_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    expand_ratio=expand_ratio,
                    dilation=dilation if i == 0 else 1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem and all layers, collecting the requested outputs.

        Returns:
            The single selected feature map when ``out_indices`` has one
            entry, otherwise a tuple of feature maps in layer order.
        """
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages`` layers.

        Frozen modules get ``requires_grad = False`` on their parameters and
        are switched to eval mode so their norm layers stop updating
        running statistics.
        """
        if self.frozen_stages >= 0:
            # Put the stem in eval mode too, consistent with the frozen
            # layers below; otherwise its norm layer keeps updating stats.
            self.conv1.eval()
            for param in self.conv1.parameters():
                param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            layer = getattr(self, f'layer{i}')
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode while keeping frozen parts frozen.

        After the regular mode switch, re-applies stage freezing and, when
        ``norm_eval`` is set and ``mode`` is True, puts every ``_BatchNorm``
        submodule back into eval mode.

        Returns:
            MobileNetV2: ``self``, following the ``nn.Module.train``
            contract so calls can be chained.
        """
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self
|
|