# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.model import BaseModule

from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType


@MODELS.register_module()
class DDRNet(BaseModule):
    """DDRNet backbone.

    This backbone is the implementation of `Deep Dual-resolution Networks for
    Real-time and Accurate Semantic Segmentation of Road Scenes
    <http://arxiv.org/abs/2101.06085>`_.
    Modified from https://github.com/ydhongHIT/DDRNet.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        channels (int): The base channels of DDRNet. Default: 32.
        ppm_channels (int): The channels of the PPM module. Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 32,
                 ppm_channels: int = 128,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN',
                                                requires_grad=True),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.ppm_channels = ppm_channels

        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # stage 0-2
        self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
        self.relu = nn.ReLU()

        # low resolution(context) branch
        self.context_branch_layers = nn.ModuleList()
        for i in range(3):
            self.context_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2**(i + 1),
                    planes=channels * 8 if i > 0 else channels * 4,
                    num_blocks=2 if i < 2 else 1,
                    stride=2))

        # bilateral fusion
        self.compression_1 = ConvModule(
            channels * 4,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_1 = ConvModule(
            channels * 2,
            channels * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.compression_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_2 = nn.Sequential(
            ConvModule(
                channels * 2,
                channels * 4,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels * 4,
                channels * 8,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=None))

        # high resolution(spatial) branch
        self.spatial_branch_layers = nn.ModuleList()
        for i in range(3):
            self.spatial_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2,
                    planes=channels * 2,
                    num_blocks=2 if i < 2 else 1,
                ))

        self.spp = DAPPM(
            channels * 16, ppm_channels, channels * 4, num_scales=5)

    def _make_stem_layer(self, in_channels, channels, num_blocks):
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]

        layers.extend([
            self._make_layer(BasicBlock, channels, channels, num_blocks),
            nn.ReLU(),
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2),
            nn.ReLU(),
        ])

        return nn.Sequential(*layers)

    def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = [
            block(
                in_channels=inplanes,
                channels=planes,
                stride=stride,
                downsample=downsample)
        ]
        inplanes = planes * block.expansion
        for i in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=inplanes,
                    channels=planes,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        out_size = (x.shape[-2] // 8, x.shape[-1] // 8)

        # stage 0-2
        x = self.stem(x)

        # stage3
        x_c = self.context_branch_layers[0](x)
        x_s = self.spatial_branch_layers[0](x)
        comp_c = self.compression_1(self.relu(x_c))
        x_c += self.down_1(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        if self.training:
            temp_context = x_s.clone()

        # stage4
        x_c = self.context_branch_layers[1](self.relu(x_c))
        x_s = self.spatial_branch_layers[1](self.relu(x_s))
        comp_c = self.compression_2(self.relu(x_c))
        x_c += self.down_2(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # stage5
        x_s = self.spatial_branch_layers[2](self.relu(x_s))
        x_c = self.context_branch_layers[2](self.relu(x_c))
        x_c = self.spp(x_c)
        x_c = resize(
            x_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        return (temp_context, x_s + x_c) if self.training else x_s + x_c
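
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module). Run it from a
# separate script so this file is only imported, and DDRNet only registered,
# once. The shape assertion assumes the default hyper-parameters above
# (channels=32), i.e. channels * 4 = 128 output channels at 1/8 of the input
# resolution:
#
#     import torch
#     from mmseg.registry import MODELS
#
#     backbone = MODELS.build(
#         dict(type='DDRNet', in_channels=3, channels=32, ppm_channels=128))
#     backbone.eval()
#     with torch.no_grad():
#         out = backbone(torch.rand(1, 3, 256, 512))
#     assert out.shape == (1, 128, 32, 64)
#
# In training mode ``forward`` instead returns a ``(temp_context, fused)``
# tuple, where ``temp_context`` is the spatial-branch feature saved after the
# first fusion, typically used for auxiliary supervision during training.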