KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential
from mmseg.registry import MODELS
from ..utils import resize
from .bisenetv1 import AttentionRefinementModule
class STDCModule(BaseModule):
    """Short-Term Dense Concatenate (STDC) module.

    A chain of ``num_convs`` conv layers whose outputs are concatenated
    along the channel dimension. The first layer produces
    ``out_channels // 2`` channels, layer ``i`` (``0 < i < num_convs - 1``)
    produces ``out_channels // 2**(i + 1)`` channels and the last layer
    repeats the width of its input. The widths therefore sum to
    ``out_channels`` when it is divisible by ``2 ** (num_convs - 1)``.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels of the
            concatenated features.
        stride (int): Stride of the module. Only ``stride == 2`` enables
            the downsampling branch; any other value builds the
            non-downsampling variant.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for the 3x3 conv layers.
            Default: None.
        num_convs (int): Number of conv layers, must be greater than 1.
            Default: 4.
        fusion_type (str): Shortcut fusion, ``'add'`` (the input, projected
            by ``self.skip`` when downsampling, is summed with the
            concatenation) or ``'cat'`` (the first layer's output, average
            pooled when downsampling, is part of the concatenation).
            Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = self.stride == 2
        self.fusion_type = fusion_type
        self.layers = ModuleList()

        half_channels = out_channels // 2
        # 1x1 entry conv; act_cfg is not passed, so ConvModule's own
        # default activation setting applies.
        entry_conv = ConvModule(
            in_channels, half_channels, kernel_size=1, norm_cfg=norm_cfg)

        if not self.with_downsample:
            self.layers.append(entry_conv)
        else:
            # Stride-2 depthwise 3x3 conv (groups == channels), no activation.
            self.downsample = ConvModule(
                half_channels,
                half_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=half_channels,
                norm_cfg=norm_cfg,
                act_cfg=None)
            if self.fusion_type == 'add':
                # The downsample is folded into the first layer, and the
                # shortcut gets its own stride-2 depthwise conv followed by a
                # 1x1 projection to ``out_channels`` so it can be summed.
                self.layers.append(nn.Sequential(entry_conv, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(entry_conv)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

        last_index = num_convs - 1
        for i in range(1, num_convs):
            # The final layer keeps its input width instead of halving it,
            # so the concatenated widths add up to ``out_channels``.
            divisor = 2**i if i == last_index else 2**(i + 1)
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // divisor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        """Dispatch to the fusion-specific forward pass."""
        if self.fusion_type != 'add':
            return self.forward_cat(inputs)
        return self.forward_add(inputs)

    def forward_add(self, inputs):
        """Concatenate all layer outputs and add the (projected) input."""
        feat = inputs.clone()
        pieces = []
        for layer in self.layers:
            feat = layer(feat)
            pieces.append(feat)
        residual = self.skip(inputs) if self.with_downsample else inputs
        return torch.cat(pieces, dim=1) + residual

    def forward_cat(self, inputs):
        """Concatenate the (pooled) first output with the deeper outputs."""
        first = self.layers[0](inputs)
        pieces = [first]
        feat = first
        for idx, layer in enumerate(self.layers[1:]):
            # Only the input of the second layer passes through downsample.
            if idx == 0 and self.with_downsample:
                feat = self.downsample(feat)
            feat = layer(feat)
            pieces.append(feat)
        if self.with_downsample:
            # Pool the first output so it matches the downsampled maps.
            pieces[0] = self.skip(first)
        return torch.cat(pieces, dim=1)
class FeatureFusionModule(BaseModule):
    """Feature Fusion Module used by the STDC context path.

    The two inputs are concatenated along the channel dimension and
    projected by a 1x1 ``conv0``. A channel-attention branch (global average
    pool -> 1x1 conv to ``out_channels // scale_factor`` -> 1x1 conv back to
    ``out_channels`` -> sigmoid) re-weights that projection, and the result
    is ``x * attn + x``. Unlike the FeatureFusionModule of BiSeNetV1 (a
    single ConvModule in ``self.conv_atten``), ``self.attention`` here holds
    two ConvModules.

    Args:
        in_channels (int): Number of channels of the concatenated inputs.
        out_channels (int): The number of output channels.
        scale_factor (int): Channel reduction factor of the attention
            bottleneck. Default: 4.
        norm_cfg (dict): Config dict for normalization layer of ``conv0``.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for ``conv0`` and the first
            attention conv. Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        hidden_channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        squeeze = ConvModule(
            out_channels,
            hidden_channels,
            1,
            norm_cfg=None,
            bias=False,
            act_cfg=act_cfg)
        excite = ConvModule(
            hidden_channels,
            out_channels,
            1,
            norm_cfg=None,
            bias=False,
            act_cfg=None)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)), squeeze, excite, nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        """Fuse spatial and context features with channel attention."""
        fused = self.conv0(
            torch.cat([spatial_inputs, context_inputs], dim=1))
        weights = self.attention(fused)
        return fused * weights + fused
@MODELS.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure, a key of
            ``arch_settings``. `STDCNet1` and `STDCNet2` denote the two main
            backbones in the paper, whose FLOPs are 813M and 1446M,
            respectively.
        in_channels (int): The number of input channels.
        channels (tuple[int]): The output channels of each of the 5 stages
            (two shallow conv stages followed by three STDC stages).
        bottleneck_type (str): The type of STDC Module, must be 'add' or
            'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers in each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to apply an extra 1x1 conv to the
            deepest output. Default: False.
        pretrained (str, optional): Model pretrained path. It is stored as
            ``self.pretrained`` (and under the legacy misspelled name
            ``self.prtrained``); nothing in this class reads it.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> norm_cfg = dict(type='BN')
        >>> act_cfg = dict(type='ReLU')
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels, channels,
        ...                bottleneck_type, norm_cfg, act_cfg).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    # Per-stage strides of the STDC modules; one tuple per STDC stage.
    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'],\
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        assert len(channels) == 5,\
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        # Legacy misspelled alias, kept so existing code reading it works.
        self.prtrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper.
        # They are both not used for following modules like Attention
        # Refinement Module and Feature Fusion Module.
        # Thus they would be cut from `outs`. Please refer to Figure 4
        # of original paper for more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            # Each new stage maps channels[idx] -> channels[idx + 1], where
            # idx is the index of the previous (last appended) stage.
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList including several
        # shallow modules and STDCModules.
        # (len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides))
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        """Build one STDC stage: one STDCModule per entry of ``strides``.

        Only the first module sees ``in_channels``; the rest map
        ``out_channels`` to ``out_channels``.
        """
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        """Run all stages and return the outputs of the STDC stages.

        Returns:
            tuple[Tensor]: Outputs of the stages after the shallow ones, in
            shallow-to-deep order; the last one goes through
            ``final_conv`` when ``with_final_conv`` is True.
        """
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)
@MODELS.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path.

    The backbone returns ``outs``, a list of feature maps used here as if
    ordered from shallow to deep, i.e. ``outs[0]`` has the biggest height
    and width and ``outs[-1]`` the smallest (this is the order produced by
    ``STDCNet``). The biggest map ``outs[0]`` is outputted for
    ``STDCHead``, where Detail Loss is calculated against the Detail
    Ground-truth. The deepest maps are refined one by one, deep to shallow,
    by Attention Refinement Modules (ARMs); each refined map is upsampled to
    the size of the next shallower map. ``outs[0]`` and the last ARM output
    are then fused by the Feature Fusion Module, and this fused map
    ``feat_fuse`` is outputted for ``decode_head``. More details please
    refer to Figure 4 of the original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)): Channels of the deepest backbone
            feature maps, deepest first, i.e. the channels of
            ``outs[-1]``, ``outs[-2]``, ... One ARM is built per entry.
            Default: (1024, 512).
        out_channels (int): The channels of the refined feature maps.
            Default: 128.
        ffm_cfg (dict): Keyword arguments of the Feature Fusion Module. Its
            ``in_channels`` must equal the channels of ``outs[0]`` plus
            ``out_channels``, since those two maps are concatenated. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        tuple[Tensor]: ``outs[0]`` (for the ``STDCHead`` auxiliary head),
        then one refined map per ARM (for auxiliary heads), then
        ``feat_fuse`` (for the decode head).
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        # One ARM + 3x3 conv pair per deep feature map, deepest first.
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        # 1x1 conv applied to the globally pooled deepest feature map.
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        """Refine backbone features along the context path and fuse them."""
        outs = list(self.backbone(x))
        num_outs = len(outs)

        # Global context: pool the deepest map, project it, and broadcast it
        # back to that map's spatial size.
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)
        feature_up = resize(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)

        arms_out = []
        for i, (arm, conv) in enumerate(zip(self.arms, self.convs)):
            # Refine the i-th deepest map, add the running context, then
            # upsample to the next shallower map's size.
            x_arm = arm(outs[num_outs - 1 - i]) + feature_up
            feature_up = resize(
                x_arm,
                size=outs[num_outs - 2 - i].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = conv(feature_up)
            arms_out.append(feature_up)

        # Fuse the shallowest backbone map with the last (shallowest) ARM
        # output; with the default two ARMs this is ``arms_out[1]``.
        feat_fuse = self.ffm(outs[0], arms_out[-1])

        # The `outputs` holds, in order:
        # `outs[0]` for the `STDCHead` auxiliary head,
        # one map per ARM in `arms_out` for auxiliary heads,
        # `feat_fuse` for the decoder head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
        return tuple(outputs)