# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential

from mmseg.registry import MODELS
from ..utils import resize
from .bisenetv1 import AttentionRefinementModule


class STDCModule(BaseModule):
"""STDCModule. | |
Args: | |
in_channels (int): The number of input channels. | |
out_channels (int): The number of output channels before scaling. | |
stride (int): The number of stride for the first conv layer. | |
norm_cfg (dict): Config dict for normalization layer. Default: None. | |
act_cfg (dict): The activation config for conv layers. | |
num_convs (int): Numbers of conv layers. | |
fusion_type (str): Type of fusion operation. Default: 'add'. | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
Default: None. | |
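
    Example (an illustrative sketch; the channel and shape values are
    arbitrary, not taken from the original paper):
        >>> import torch
        >>> self = STDCModule(32, 64, stride=2, norm_cfg=dict(type='BN'),
        ...                   act_cfg=dict(type='ReLU'),
        ...                   fusion_type='cat').eval()
        >>> inputs = torch.rand(1, 32, 64, 64)
        >>> outputs = self.forward(inputs)
        >>> print(outputs.shape)
        torch.Size([1, 64, 32, 32])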
""" | |
def __init__(self, | |
in_channels, | |
out_channels, | |
stride, | |
norm_cfg=None, | |
act_cfg=None, | |
num_convs=4, | |
fusion_type='add', | |
init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
assert num_convs > 1 | |
assert fusion_type in ['add', 'cat'] | |
self.stride = stride | |
self.with_downsample = True if self.stride == 2 else False | |
self.fusion_type = fusion_type | |
self.layers = ModuleList() | |
conv_0 = ConvModule( | |
in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg) | |
        if self.with_downsample:
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)
            if self.fusion_type == 'add':
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)
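
        # Each following conv halves the channel width; the last two convs
        # share the same width, so the concatenated outputs of all layers
        # sum to `out_channels`.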
        for i in range(1, num_convs):
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        if self.fusion_type == 'add':
            out = self.forward_add(inputs)
        else:
            out = self.forward_cat(inputs)
        return out

    def forward_add(self, inputs):
        layer_outputs = []
        x = inputs.clone()
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
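        # Downsample the identity branch so its spatial size matches the
        # concatenated conv features before the element-wise addition.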
        if self.with_downsample:
            inputs = self.skip(inputs)

        return torch.cat(layer_outputs, dim=1) + inputs

    def forward_cat(self, inputs):
        x0 = self.layers[0](inputs)
        layer_outputs = [x0]
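        # Chain the remaining convs; when downsampling, the stored `x0` is
        # later replaced by its average-pooled counterpart so that all
        # concatenated features share the same spatial size.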
        for i, layer in enumerate(self.layers[1:]):
            if i == 0:
                if self.with_downsample:
                    x = layer(self.downsample(x0))
                else:
                    x = layer(x0)
            else:
                x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            layer_outputs[0] = self.skip(x0)
        return torch.cat(layer_outputs, dim=1)


class FeatureFusionModule(BaseModule):
    """Feature Fusion Module. This module differs from the
    FeatureFusionModule in BiSeNetV1: it uses two ConvModules in
    `self.attention`, whose intermediate channel number is computed from
    the given `scale_factor`, while the FeatureFusionModule in BiSeNetV1
    uses only one ConvModule in `self.conv_atten`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        scale_factor (int): The number of channel scale factor.
            Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
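
    Example (an illustrative sketch; the input shapes are arbitrary):
        >>> import torch
        >>> self = FeatureFusionModule(in_channels=512,
        ...                            out_channels=256).eval()
        >>> spatial_inputs = torch.rand(1, 256, 64, 128)
        >>> context_inputs = torch.rand(1, 256, 64, 128)
        >>> outputs = self.forward(spatial_inputs, context_inputs)
        >>> print(outputs.shape)
        torch.Size([1, 256, 64, 128])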
""" | |
def __init__(self, | |
in_channels, | |
out_channels, | |
scale_factor=4, | |
norm_cfg=dict(type='BN'), | |
act_cfg=dict(type='ReLU'), | |
init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
channels = out_channels // scale_factor | |
self.conv0 = ConvModule( | |
in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) | |
self.attention = nn.Sequential( | |
nn.AdaptiveAvgPool2d((1, 1)), | |
ConvModule( | |
out_channels, | |
channels, | |
1, | |
norm_cfg=None, | |
bias=False, | |
act_cfg=act_cfg), | |
ConvModule( | |
channels, | |
out_channels, | |
1, | |
norm_cfg=None, | |
bias=False, | |
act_cfg=None), nn.Sigmoid()) | |

    def forward(self, spatial_inputs, context_inputs):
        inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
        x = self.conv0(inputs)
        attn = self.attention(x)
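        # Re-weight the channels with the global attention vector and keep
        # an identity path so the fused feature preserves the original
        # response.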
        x_attn = x * attn
        return x_attn + x


@MODELS.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For
    Real-time Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure,
            `STDCNet1` and `STDCNet2` denote the two main backbones in the
            paper, whose FLOPs are 813M and 1446M, respectively.
        in_channels (int): The number of input channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): The number of conv layers in each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a conv layer at the module
            output. Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels, channels,
        ...                bottleneck_type, norm_cfg=dict(type='BN'),
        ...                act_cfg=dict(type='ReLU')).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
""" | |
arch_settings = { | |
'STDCNet1': [(2, 1), (2, 1), (2, 1)], | |
'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)] | |
} | |

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'], \
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'
        assert len(channels) == 5, \
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, denoted as `Stage1` and `Stage2` in the original paper.
        # Neither is used by the following modules, such as the Attention
        # Refinement Module and the Feature Fusion Module, so both are cut
        # from `outs`. Please refer to Figure 4 of the original paper for
        # more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList including several
        # shallow modules and STDCModules.
        # (len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides))
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        layers = []
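        # In each stage, only the first STDCModule may downsample (stride 2
        # per `arch_settings`); the remaining modules use stride 1 and keep
        # the stage's output width.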
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)


@MODELS.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path. The `outs` below is a list of three
    feature maps from shallow to deep, whose heights and widths go from big
    to small, respectively. The biggest feature map of `outs` is output for
    `STDCHead`, where the Detail Loss is calculated against the detail
    ground truth. The other two feature maps are each processed by an
    Attention Refinement Module. Besides, the biggest feature map of `outs`
    and the last output of the Attention Refinement Modules are concatenated
    for the Feature Fusion Module. The fused feature map `feat_fuse` is then
    output for `decode_head`. For more details, please refer to Figure 4 of
    the original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)): The number of channels of the last
            two feature maps from the stdc backbone. Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Return:
        outputs (tuple): A tuple of output feature maps for the auxiliary
            heads and the decode head.
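
    Example (an illustrative sketch; `in_channels` of `ffm_cfg` is set to
    384 here to match the concatenated channels, 256 from the backbone plus
    128 from the context path):
        >>> import torch
        >>> backbone_cfg = dict(
        ...     type='STDCNet',
        ...     stdc_type='STDCNet1',
        ...     in_channels=3,
        ...     channels=(32, 64, 256, 512, 1024),
        ...     bottleneck_type='cat',
        ...     norm_cfg=dict(type='BN'),
        ...     act_cfg=dict(type='ReLU'))
        >>> self = STDCContextPathNet(
        ...     backbone_cfg,
        ...     ffm_cfg=dict(in_channels=384, out_channels=256,
        ...                  scale_factor=4)).eval()
        >>> inputs = torch.rand(1, 3, 512, 1024)
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 64, 128])
        outputs[1].shape = torch.Size([1, 128, 32, 64])
        outputs[2].shape = torch.Size([1, 128, 64, 128])
        outputs[3].shape = torch.Size([1, 256, 64, 128])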
""" | |
def __init__(self, | |
backbone_cfg, | |
last_in_channels=(1024, 512), | |
out_channels=128, | |
ffm_cfg=dict( | |
in_channels=512, out_channels=256, scale_factor=4), | |
upsample_mode='nearest', | |
align_corners=None, | |
norm_cfg=dict(type='BN'), | |
init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
self.backbone = MODELS.build(backbone_cfg) | |
self.arms = ModuleList() | |
self.convs = ModuleList() | |
for channels in last_in_channels: | |
self.arms.append(AttentionRefinementModule(channels, out_channels)) | |
self.convs.append( | |
ConvModule( | |
out_channels, | |
out_channels, | |
3, | |
padding=1, | |
norm_cfg=norm_cfg)) | |
self.conv_avg = ConvModule( | |
last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg) | |
self.ffm = FeatureFusionModule(**ffm_cfg) | |
self.upsample_mode = upsample_mode | |
self.align_corners = align_corners | |

    def forward(self, x):
        outs = list(self.backbone(x))
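        # Global context: pool the deepest feature map to 1x1, compress its
        # channels, and upsample it back as the starting context feature.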
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)
        feature_up = resize(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        arms_out = []
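        # Top-down refinement: walk from the deepest feature map toward the
        # shallower ones, adding each refined map to the upsampled context
        # from the deeper level.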
        for i in range(len(self.arms)):
            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
            feature_up = resize(
                x_arm,
                size=outs[len(outs) - 1 - i - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = self.convs[i](feature_up)
            arms_out.append(feature_up)
        feat_fuse = self.ffm(outs[0], arms_out[1])

        # `outputs` has four feature maps:
        # `outs[0]` is output for the `STDCHead` auxiliary head,
        # the two feature maps of `arms_out` are output for auxiliary heads,
        # and `feat_fuse` is output for the decode head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]

        return tuple(outputs)