# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


@MODELS.register_module()
class JPU(BaseModule):
"""FastFCN: Rethinking Dilated Convolution in the Backbone | |
for Semantic Segmentation. | |
This Joint Pyramid Upsampling (JPU) neck is the implementation of | |
`FastFCN <https://arxiv.org/abs/1903.11816>`_. | |
Args: | |
in_channels (Tuple[int], optional): The number of input channels | |
for each convolution operations before upsampling. | |
Default: (512, 1024, 2048). | |
mid_channels (int): The number of output channels of JPU. | |
Default: 512. | |
start_level (int): Index of the start input backbone level used to | |
build the feature pyramid. Default: 0. | |
end_level (int): Index of the end input backbone level (exclusive) to | |
build the feature pyramid. Default: -1, which means the last level. | |
dilations (tuple[int]): Dilation rate of each Depthwise | |
Separable ConvModule. Default: (1, 2, 4, 8). | |
align_corners (bool, optional): The align_corners argument of | |
resize operation. Default: False. | |
conv_cfg (dict | None): Config of conv layers. | |
Default: None. | |
norm_cfg (dict | None): Config of norm layers. | |
Default: dict(type='BN'). | |
act_cfg (dict): Config of activation layers. | |
Default: dict(type='ReLU'). | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
Default: None. | |
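
    Example:
        A minimal forward sketch; the input shapes below are an assumption
        (a ResNet-style backbone with strides 8/16/32 on a 64x64 image):

        >>> import torch
        >>> inputs = (torch.randn(1, 512, 8, 8),
        ...           torch.randn(1, 1024, 4, 4),
        ...           torch.randn(1, 2048, 2, 2))
        >>> jpu = JPU()
        >>> outs = jpu(inputs)
        >>> for out in outs:
        ...     print(out.shape)
        torch.Size([1, 512, 8, 8])
        torch.Size([1, 1024, 4, 4])
        torch.Size([1, 2048, 8, 8])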
""" | |

    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 mid_channels=512,
                 start_level=0,
                 end_level=-1,
                 dilations=(1, 2, 4, 8),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, tuple)
        assert isinstance(dilations, tuple)
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.start_level = start_level
        self.num_ins = len(in_channels)
        if end_level == -1:
            self.backbone_end_level = self.num_ins
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)

        self.dilations = dilations
        self.align_corners = align_corners

        self.conv_layers = nn.ModuleList()
        self.dilation_layers = nn.ModuleList()
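        # One 3x3 conv per selected backbone level, projecting each level to
        # `mid_channels` so the upsampled maps can later be concatenated.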
        for i in range(self.start_level, self.backbone_end_level):
            conv_layer = nn.Sequential(
                ConvModule(
                    self.in_channels[i],
                    self.mid_channels,
                    kernel_size=3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            self.conv_layers.append(conv_layer)
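        # One depthwise-separable 3x3 branch per dilation rate; every branch
        # consumes the concatenation of all projected levels and captures
        # context at a different receptive field.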
        for i in range(len(dilations)):
            dilation_layer = nn.Sequential(
                DepthwiseSeparableConvModule(
                    in_channels=(self.backbone_end_level - self.start_level) *
                    self.mid_channels,
                    out_channels=self.mid_channels,
                    kernel_size=3,
                    stride=1,
                    padding=dilations[i],
                    dilation=dilations[i],
                    dw_norm_cfg=norm_cfg,
                    dw_act_cfg=None,
                    pw_norm_cfg=norm_cfg,
                    pw_act_cfg=act_cfg))
            self.dilation_layers.append(dilation_layer)

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels), (
            'Length of inputs must be the same as self.in_channels!')

        feats = [
            self.conv_layers[i - self.start_level](inputs[i])
            for i in range(self.start_level, self.backbone_end_level)
        ]
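        # Upsample all coarser levels to the spatial size of the finest
        # feature map (feats[0]).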
        h, w = feats[0].shape[2:]
        for i in range(1, len(feats)):
            feats[i] = resize(
                feats[i],
                size=(h, w),
                mode='bilinear',
                align_corners=self.align_corners)
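        # Fuse the aligned levels along the channel dimension, then run the
        # fused map through all dilated branches in parallel and concatenate
        # their outputs.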
        feat = torch.cat(feats, dim=1)
        concat_feat = torch.cat([
            self.dilation_layers[i](feat) for i in range(len(self.dilations))
        ],
                                dim=1)

        outs = []

        # Default: outs[2] is the output of JPU for decoder head, outs[1] is
        # the feature map from backbone for auxiliary head. Additionally,
        # outs[0] can also be used for auxiliary head.
        for i in range(self.start_level, self.backbone_end_level - 1):
            outs.append(inputs[i])
        outs.append(concat_feat)
        return tuple(outs)
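

# Usage sketch (illustrative only; the surrounding config context is an
# assumption based on typical FastFCN-style configs): once registered via
# ``@MODELS.register_module()`` above, the neck can be built from a config
# dict through the registry rather than instantiated directly:
#
#     from mmseg.registry import MODELS
#     neck = MODELS.build(
#         dict(
#             type='JPU',
#             in_channels=(512, 1024, 2048),
#             mid_channels=512,
#             dilations=(1, 2, 4, 8)))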