# Copyright (c) OpenMMLab. All rights reserved. import torch import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.registry import MODELS from .decode_head import BaseDecodeHead @MODELS.register_module() class FCNHead(BaseDecodeHead): """Fully Convolution Networks for Semantic Segmentation. This head is implemented of `FCNNet `_. Args: num_convs (int): Number of convs in the head. Default: 2. kernel_size (int): The kernel size for convs in the head. Default: 3. concat_input (bool): Whether concat the input and output of convs before classification layer. dilation (int): The dilation rate for convs in the head. Default: 1. """ def __init__(self, num_convs=2, kernel_size=3, concat_input=True, dilation=1, **kwargs): assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) self.num_convs = num_convs self.concat_input = concat_input self.kernel_size = kernel_size super().__init__(**kwargs) if num_convs == 0: assert self.in_channels == self.channels conv_padding = (kernel_size // 2) * dilation convs = [] convs.append( ConvModule( self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) for i in range(num_convs - 1): convs.append( ConvModule( self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) if num_convs == 0: self.convs = nn.Identity() else: self.convs = nn.Sequential(*convs) if self.concat_input: self.conv_cat = ConvModule( self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def _forward_feature(self, inputs): """Forward function for feature maps before classifying each pixel with ``self.cls_seg`` fc. Args: inputs (list[Tensor]): List of multi-level img features. Returns: feats (Tensor): A tensor of shape (batch_size, self.channels, H, W) which is feature map for last layer of decoder head. """ x = self._transform_inputs(inputs) feats = self.convs(x) if self.concat_input: feats = self.conv_cat(torch.cat([x, feats], dim=1)) return feats def forward(self, inputs): """Forward function.""" output = self._forward_feature(inputs) output = self.cls_seg(output) return output