| |
| import torch |
| import torch.nn as nn |
| from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule |
| from mmengine.model import BaseModule |
|
|
| from mmseg.registry import MODELS |
| from ..utils import resize |
|
|
|
|
@MODELS.register_module()
class JPU(BaseModule):
    """FastFCN: Rethinking Dilated Convolution in the Backbone
    for Semantic Segmentation.

    This Joint Pyramid Upsampling (JPU) neck is the implementation of
    `FastFCN <https://arxiv.org/abs/1903.11816>`_.

    Args:
        in_channels (Tuple[int], optional): The number of input channels
            for each convolution operations before upsampling.
            Default: (512, 1024, 2048).
        mid_channels (int): The number of output channels of JPU.
            Default: 512.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        dilations (tuple[int]): Dilation rate of each Depthwise
            Separable ConvModule. Default: (1, 2, 4, 8).
        align_corners (bool, optional): The align_corners argument of
            resize operation. Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 mid_channels=512,
                 start_level=0,
                 end_level=-1,
                 dilations=(1, 2, 4, 8),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, tuple)
        assert isinstance(dilations, tuple)
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.start_level = start_level
        self.num_ins = len(in_channels)
        if end_level == -1:
            self.backbone_end_level = self.num_ins
        else:
            self.backbone_end_level = end_level
            # The bound is only meaningful for an explicit end level; for the
            # sentinel -1 it would be trivially true.
            assert end_level <= len(in_channels)

        self.dilations = dilations
        self.align_corners = align_corners

        # One 3x3 conv per selected backbone level, projecting each level to
        # a common mid_channels width before fusion.
        self.conv_layers = nn.ModuleList()
        # One depthwise-separable conv per dilation rate, applied in parallel
        # to the channel-concatenated multi-level features.
        self.dilation_layers = nn.ModuleList()
        for level in range(self.start_level, self.backbone_end_level):
            self.conv_layers.append(
                nn.Sequential(
                    ConvModule(
                        self.in_channels[level],
                        self.mid_channels,
                        kernel_size=3,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg)))
        # Channel count of the concatenated features fed to every dilated
        # branch (loop-invariant, so computed once).
        concat_channels = (self.backbone_end_level -
                           self.start_level) * self.mid_channels
        for dilation in dilations:
            self.dilation_layers.append(
                nn.Sequential(
                    DepthwiseSeparableConvModule(
                        in_channels=concat_channels,
                        out_channels=self.mid_channels,
                        kernel_size=3,
                        stride=1,
                        # padding == dilation keeps the spatial size unchanged
                        # for a 3x3 kernel.
                        padding=dilation,
                        dilation=dilation,
                        dw_norm_cfg=norm_cfg,
                        dw_act_cfg=None,
                        pw_norm_cfg=norm_cfg,
                        pw_act_cfg=act_cfg)))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (tuple[Tensor]): Multi-level backbone features; its length
                must equal ``len(self.in_channels)``.

        Returns:
            tuple[Tensor]: The pass-through inputs for levels
            ``[start_level, backbone_end_level - 1)``, followed by the fused
            JPU feature map with ``len(dilations) * mid_channels`` channels.
        """
        assert len(inputs) == len(self.in_channels), 'Length of inputs must \
            be the same with self.in_channels!'

        # Project every selected level to mid_channels.
        feats = [
            self.conv_layers[i - self.start_level](inputs[i])
            for i in range(self.start_level, self.backbone_end_level)
        ]

        # Upsample all deeper levels to the spatial size of the shallowest
        # selected level so they can be concatenated.
        h, w = feats[0].shape[2:]
        feats[1:] = [
            resize(
                feat,
                size=(h, w),
                mode='bilinear',
                align_corners=self.align_corners) for feat in feats[1:]
        ]

        feat = torch.cat(feats, dim=1)
        # Parallel dilated branches capture multi-scale context; their
        # outputs are concatenated along the channel dimension.
        concat_feat = torch.cat(
            [layer(feat) for layer in self.dilation_layers], dim=1)

        # Pass through the untouched lower-level inputs and append the fused
        # feature map as the last output.
        outs = [
            inputs[i]
            for i in range(self.start_level, self.backbone_end_level - 1)
        ]
        outs.append(concat_feat)
        return tuple(outs)
|
|