# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
                      build_activation_layer, build_norm_layer)
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


class DetailBranch(BaseModule):
    """Detail Branch with wide channels and shallow layers to capture
    low-level details and generate high-resolution feature representation.

    Args:
        detail_channels (Tuple[int]): Size of channel numbers of each stage
            in Detail Branch, in paper it has 3 stages.
            Default: (64, 64, 128).
        in_channels (int): Number of channels of input image. Default: 3.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Feature map of Detail Branch.
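
    Example:
        A minimal shape check, assuming the default configuration; each of
        the three stages starts with a stride-2 conv, giving 8x downsampling
        overall:

        >>> import torch
        >>> branch = DetailBranch()
        >>> x = torch.randn(1, 3, 64, 64)
        >>> branch(x).shape
        torch.Size([1, 128, 8, 8])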
""" | |
def __init__(self, | |
detail_channels=(64, 64, 128), | |
in_channels=3, | |
conv_cfg=None, | |
norm_cfg=dict(type='BN'), | |
act_cfg=dict(type='ReLU'), | |
init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
detail_branch = [] | |
for i in range(len(detail_channels)): | |
if i == 0: | |
detail_branch.append( | |
nn.Sequential( | |
ConvModule( | |
in_channels=in_channels, | |
out_channels=detail_channels[i], | |
kernel_size=3, | |
stride=2, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg), | |
ConvModule( | |
in_channels=detail_channels[i], | |
out_channels=detail_channels[i], | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg))) | |
else: | |
detail_branch.append( | |
nn.Sequential( | |
ConvModule( | |
in_channels=detail_channels[i - 1], | |
out_channels=detail_channels[i], | |
kernel_size=3, | |
stride=2, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg), | |
ConvModule( | |
in_channels=detail_channels[i], | |
out_channels=detail_channels[i], | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg), | |
ConvModule( | |
in_channels=detail_channels[i], | |
out_channels=detail_channels[i], | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg))) | |
self.detail_branch = nn.ModuleList(detail_branch) | |
def forward(self, x): | |
for stage in self.detail_branch: | |
x = stage(x) | |
return x | |


class StemBlock(BaseModule):
    """Stem Block at the beginning of Semantic Branch.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): First feature map in Semantic Branch.
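
    Example:
        An illustrative sketch with the default settings; the two stride-2
        convolutions give an overall 4x downsampling:

        >>> import torch
        >>> stem = StemBlock()
        >>> x = torch.randn(1, 3, 64, 64)
        >>> stem(x).shape
        torch.Size([1, 16, 16, 16])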
""" | |
def __init__(self, | |
in_channels=3, | |
out_channels=16, | |
conv_cfg=None, | |
norm_cfg=dict(type='BN'), | |
act_cfg=dict(type='ReLU'), | |
init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
self.conv_first = ConvModule( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=2, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg) | |
self.convs = nn.Sequential( | |
ConvModule( | |
in_channels=out_channels, | |
out_channels=out_channels // 2, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg), | |
ConvModule( | |
in_channels=out_channels // 2, | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=2, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg)) | |
self.pool = nn.MaxPool2d( | |
kernel_size=3, stride=2, padding=1, ceil_mode=False) | |
self.fuse_last = ConvModule( | |
in_channels=out_channels * 2, | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg) | |
def forward(self, x): | |
x = self.conv_first(x) | |
x_left = self.convs(x) | |
x_right = self.pool(x) | |
x = self.fuse_last(torch.cat([x_left, x_right], dim=1)) | |
return x | |


class GELayer(BaseModule):
    """Gather-and-Expansion Layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        stride (int): Stride of GELayer. Default: 1.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Intermediate feature map in
            Semantic Branch.
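
    Example:
        A minimal sketch of both variants. With stride=1 the input is added
        back as an identity shortcut, so in_channels must equal
        out_channels; with stride=2 a depthwise-separable shortcut projects
        the input instead:

        >>> import torch
        >>> x = torch.randn(1, 16, 32, 32)
        >>> GELayer(16, 16, stride=1)(x).shape
        torch.Size([1, 16, 32, 32])
        >>> GELayer(16, 32, stride=2)(x).shape
        torch.Size([1, 32, 16, 16])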
""" | |
def __init__(self, | |
in_channels, | |
out_channels, | |
exp_ratio=6, | |
stride=1, | |
conv_cfg=None, | |
norm_cfg=dict(type='BN'), | |
act_cfg=dict(type='ReLU'), | |
init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
mid_channel = in_channels * exp_ratio | |
self.conv1 = ConvModule( | |
in_channels=in_channels, | |
out_channels=in_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg) | |
if stride == 1: | |
self.dwconv = nn.Sequential( | |
# ReLU in ConvModule not shown in paper | |
ConvModule( | |
in_channels=in_channels, | |
out_channels=mid_channel, | |
kernel_size=3, | |
stride=stride, | |
padding=1, | |
groups=in_channels, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg)) | |
self.shortcut = None | |
else: | |
self.dwconv = nn.Sequential( | |
ConvModule( | |
in_channels=in_channels, | |
out_channels=mid_channel, | |
kernel_size=3, | |
stride=stride, | |
padding=1, | |
groups=in_channels, | |
bias=False, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=None), | |
# ReLU in ConvModule not shown in paper | |
ConvModule( | |
in_channels=mid_channel, | |
out_channels=mid_channel, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
groups=mid_channel, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg), | |
) | |
self.shortcut = nn.Sequential( | |
DepthwiseSeparableConvModule( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
kernel_size=3, | |
stride=stride, | |
padding=1, | |
dw_norm_cfg=norm_cfg, | |
dw_act_cfg=None, | |
pw_norm_cfg=norm_cfg, | |
pw_act_cfg=None, | |
)) | |
self.conv2 = nn.Sequential( | |
ConvModule( | |
in_channels=mid_channel, | |
out_channels=out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
bias=False, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=None, | |
)) | |
self.act = build_activation_layer(act_cfg) | |
def forward(self, x): | |
identity = x | |
x = self.conv1(x) | |
x = self.dwconv(x) | |
x = self.conv2(x) | |
if self.shortcut is not None: | |
shortcut = self.shortcut(identity) | |
x = x + shortcut | |
else: | |
x = x + identity | |
x = self.act(x) | |
return x | |


class CEBlock(BaseModule):
    """Context Embedding Block for large receptive field in Semantic Branch.

    Args:
        in_channels (int): Number of input channels.
            Default: 3.
        out_channels (int): Number of output channels.
            Default: 16.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (torch.Tensor): Last feature map in Semantic Branch.
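
    Example:
        An illustrative sketch. Because the pooled context is added back to
        the input, in_channels must equal out_channels for the residual
        addition to broadcast:

        >>> import torch
        >>> ce = CEBlock(in_channels=128, out_channels=128)
        >>> x = torch.randn(1, 128, 8, 8)
        >>> ce(x).shape
        torch.Size([1, 128, 8, 8])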
""" | |
def __init__(self, | |
in_channels=3, | |
out_channels=16, | |
conv_cfg=None, | |
norm_cfg=dict(type='BN'), | |
act_cfg=dict(type='ReLU'), | |
init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.gap = nn.Sequential( | |
nn.AdaptiveAvgPool2d((1, 1)), | |
build_norm_layer(norm_cfg, self.in_channels)[1]) | |
self.conv_gap = ConvModule( | |
in_channels=self.in_channels, | |
out_channels=self.out_channels, | |
kernel_size=1, | |
stride=1, | |
padding=0, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg) | |
# Note: in paper here is naive conv2d, no bn-relu | |
self.conv_last = ConvModule( | |
in_channels=self.out_channels, | |
out_channels=self.out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg) | |
def forward(self, x): | |
identity = x | |
x = self.gap(x) | |
x = self.conv_gap(x) | |
x = identity + x | |
x = self.conv_last(x) | |
return x | |


class SemanticBranch(BaseModule):
    """Semantic Branch which is lightweight with narrow channels and deep
    layers to obtain high-level semantic context.

    Args:
        semantic_channels (Tuple[int]): Size of channel numbers of
            various stages in Semantic Branch.
            Default: (16, 32, 64, 128).
        in_channels (int): Number of channels of input image. Default: 3.
        exp_ratio (int): Expansion ratio for middle channels.
            Default: 6.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        semantic_outs (List[torch.Tensor]): List of several feature maps
            for auxiliary heads (Booster) and Bilateral
            Guided Aggregation Layer.
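
    Example:
        A minimal sketch with the default configuration. Five feature maps
        are returned, at strides 4, 8, 16, 32 and 32 (the last one from the
        Context Embedding Block):

        >>> import torch
        >>> branch = SemanticBranch()
        >>> x = torch.randn(1, 3, 64, 64)
        >>> outs = branch(x)
        >>> len(outs)
        5
        >>> outs[0].shape  # stride 4, from the stem
        torch.Size([1, 16, 16, 16])
        >>> outs[-1].shape  # stride 32, from the CEBlock
        torch.Size([1, 128, 2, 2])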
""" | |
def __init__(self, | |
semantic_channels=(16, 32, 64, 128), | |
in_channels=3, | |
exp_ratio=6, | |
init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
self.in_channels = in_channels | |
self.semantic_channels = semantic_channels | |
self.semantic_stages = [] | |
for i in range(len(semantic_channels)): | |
stage_name = f'stage{i + 1}' | |
self.semantic_stages.append(stage_name) | |
if i == 0: | |
self.add_module( | |
stage_name, | |
StemBlock(self.in_channels, semantic_channels[i])) | |
elif i == (len(semantic_channels) - 1): | |
self.add_module( | |
stage_name, | |
nn.Sequential( | |
GELayer(semantic_channels[i - 1], semantic_channels[i], | |
exp_ratio, 2), | |
GELayer(semantic_channels[i], semantic_channels[i], | |
exp_ratio, 1), | |
GELayer(semantic_channels[i], semantic_channels[i], | |
exp_ratio, 1), | |
GELayer(semantic_channels[i], semantic_channels[i], | |
exp_ratio, 1))) | |
else: | |
self.add_module( | |
stage_name, | |
nn.Sequential( | |
GELayer(semantic_channels[i - 1], semantic_channels[i], | |
exp_ratio, 2), | |
GELayer(semantic_channels[i], semantic_channels[i], | |
exp_ratio, 1))) | |
self.add_module(f'stage{len(semantic_channels)}_CEBlock', | |
CEBlock(semantic_channels[-1], semantic_channels[-1])) | |
self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock') | |
def forward(self, x): | |
semantic_outs = [] | |
for stage_name in self.semantic_stages: | |
semantic_stage = getattr(self, stage_name) | |
x = semantic_stage(x) | |
semantic_outs.append(x) | |
return semantic_outs | |


class BGALayer(BaseModule):
    """Bilateral Guided Aggregation Layer to fuse the complementary
    information from both Detail Branch and Semantic Branch.

    Args:
        out_channels (int): Number of output channels.
            Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        output (torch.Tensor): Output feature map for Segment heads.
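
    Example:
        An illustrative sketch. x_d comes from the Detail Branch at stride 8
        and x_s from the Semantic Branch at stride 32; both carry
        out_channels channels, and the fused output keeps the detail
        resolution:

        >>> import torch
        >>> bga = BGALayer(out_channels=128)
        >>> x_d = torch.randn(1, 128, 8, 8)
        >>> x_s = torch.randn(1, 128, 2, 2)
        >>> bga(x_d, x_s).shape
        torch.Size([1, 128, 8, 8])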
""" | |
def __init__(self, | |
out_channels=128, | |
align_corners=False, | |
conv_cfg=None, | |
norm_cfg=dict(type='BN'), | |
act_cfg=dict(type='ReLU'), | |
init_cfg=None): | |
super().__init__(init_cfg=init_cfg) | |
self.out_channels = out_channels | |
self.align_corners = align_corners | |
self.detail_dwconv = nn.Sequential( | |
DepthwiseSeparableConvModule( | |
in_channels=self.out_channels, | |
out_channels=self.out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
dw_norm_cfg=norm_cfg, | |
dw_act_cfg=None, | |
pw_norm_cfg=None, | |
pw_act_cfg=None, | |
)) | |
self.detail_down = nn.Sequential( | |
ConvModule( | |
in_channels=self.out_channels, | |
out_channels=self.out_channels, | |
kernel_size=3, | |
stride=2, | |
padding=1, | |
bias=False, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=None), | |
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)) | |
self.semantic_conv = nn.Sequential( | |
ConvModule( | |
in_channels=self.out_channels, | |
out_channels=self.out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
bias=False, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=None)) | |
self.semantic_dwconv = nn.Sequential( | |
DepthwiseSeparableConvModule( | |
in_channels=self.out_channels, | |
out_channels=self.out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
dw_norm_cfg=norm_cfg, | |
dw_act_cfg=None, | |
pw_norm_cfg=None, | |
pw_act_cfg=None, | |
)) | |
self.conv = ConvModule( | |
in_channels=self.out_channels, | |
out_channels=self.out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
inplace=True, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg, | |
) | |
def forward(self, x_d, x_s): | |
detail_dwconv = self.detail_dwconv(x_d) | |
detail_down = self.detail_down(x_d) | |
semantic_conv = self.semantic_conv(x_s) | |
semantic_dwconv = self.semantic_dwconv(x_s) | |
semantic_conv = resize( | |
input=semantic_conv, | |
size=detail_dwconv.shape[2:], | |
mode='bilinear', | |
align_corners=self.align_corners) | |
fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv) | |
fuse_2 = detail_down * torch.sigmoid(semantic_dwconv) | |
fuse_2 = resize( | |
input=fuse_2, | |
size=fuse_1.shape[2:], | |
mode='bilinear', | |
align_corners=self.align_corners) | |
output = self.conv(fuse_1 + fuse_2) | |
return output | |


@MODELS.register_module()
class BiSeNetV2(BaseModule):
    """BiSeNetV2: Bilateral Network with Guided Aggregation for
    Real-time Semantic Segmentation.

    This backbone is the implementation of
    `BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.

    Args:
        in_channels (int): Number of channels of input image. Default: 3.
        detail_channels (Tuple[int], optional): Channels of each stage
            in Detail Branch. Default: (64, 64, 128).
        semantic_channels (Tuple[int], optional): Channels of each stage
            in Semantic Branch. Default: (16, 32, 64, 128).
            See Table 1 and Figure 3 of paper for more details.
        semantic_expansion_ratio (int, optional): The expansion factor
            for the middle channels in Semantic Branch.
            Default: 6.
        bga_channels (int, optional): Number of middle channels in
            Bilateral Guided Aggregation Layer. Default: 128.
        out_indices (Tuple[int] | int, optional): Output from which stages.
            Default: (0, 1, 2, 3, 4).
        align_corners (bool, optional): The align_corners argument of
            resize operation in Bilateral Guided Aggregation Layer.
            Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
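
    Example:
        A minimal forward-pass sketch with the default configuration,
        assuming a spatial size divisible by 32. Five feature maps are
        returned: the fused BGA output first, then the intermediate
        Semantic Branch outputs for the auxiliary heads:

        >>> import torch
        >>> backbone = BiSeNetV2()
        >>> x = torch.randn(1, 3, 64, 64)
        >>> outs = backbone(x)
        >>> outs[0].shape  # fused output from the BGA layer, stride 8
        torch.Size([1, 128, 8, 8])
        >>> outs[1].shape  # first Semantic Branch output, stride 4
        torch.Size([1, 16, 16, 16])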
""" | |
def __init__(self, | |
in_channels=3, | |
detail_channels=(64, 64, 128), | |
semantic_channels=(16, 32, 64, 128), | |
semantic_expansion_ratio=6, | |
bga_channels=128, | |
out_indices=(0, 1, 2, 3, 4), | |
align_corners=False, | |
conv_cfg=None, | |
norm_cfg=dict(type='BN'), | |
act_cfg=dict(type='ReLU'), | |
init_cfg=None): | |
if init_cfg is None: | |
init_cfg = [ | |
dict(type='Kaiming', layer='Conv2d'), | |
dict( | |
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) | |
] | |
super().__init__(init_cfg=init_cfg) | |
self.in_channels = in_channels | |
self.out_indices = out_indices | |
self.detail_channels = detail_channels | |
self.semantic_channels = semantic_channels | |
self.semantic_expansion_ratio = semantic_expansion_ratio | |
self.bga_channels = bga_channels | |
self.align_corners = align_corners | |
self.conv_cfg = conv_cfg | |
self.norm_cfg = norm_cfg | |
self.act_cfg = act_cfg | |
self.detail = DetailBranch(self.detail_channels, self.in_channels) | |
self.semantic = SemanticBranch(self.semantic_channels, | |
self.in_channels, | |
self.semantic_expansion_ratio) | |
self.bga = BGALayer(self.bga_channels, self.align_corners) | |
def forward(self, x): | |
# stole refactoring code from Coin Cheung, thanks | |
x_detail = self.detail(x) | |
x_semantic_lst = self.semantic(x) | |
x_head = self.bga(x_detail, x_semantic_lst[-1]) | |
outs = [x_head] + x_semantic_lst[:-1] | |
outs = [outs[i] for i in self.out_indices] | |
return tuple(outs) | |