|
|
|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule |
|
|
|
from mmseg.registry import MODELS |
|
from .decode_head import BaseDecodeHead |
|
|
|
|
|
@MODELS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is an implementation of
    `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. ``0`` replaces the
            conv stack with an identity mapping, which requires
            ``in_channels == channels``. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concat the input and output of convs
            before classification layer. Default: True.
        dilation (int): The dilation rate for convs in the head. Must be a
            positive int. Default: 1.
        **kwargs: Forwarded unchanged to ``BaseDecodeHead.__init__``.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 dilation=1,
                 **kwargs):
        assert num_convs >= 0, \
            f'num_convs must be non-negative, but got {num_convs}'
        # Check the type before comparing so that a non-int dilation fails
        # this assertion instead of raising from the ``>`` comparison.
        assert isinstance(dilation, int) and dilation > 0, \
            f'dilation must be a positive int, but got {dilation!r}'
        # These plain attributes are assigned before the parent constructor
        # runs; the order is kept as-is in case the parent reads them.
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super().__init__(**kwargs)

        # ``in_channels``, ``channels`` and the ``*_cfg`` attributes used
        # below are read from ``self`` after the parent constructor has run.
        if num_convs == 0:
            # Without convs the input is passed straight through, so its
            # channel count must already match the head's channel count.
            assert self.in_channels == self.channels, \
                'in_channels must equal channels when num_convs is 0'
            self.convs = nn.Identity()
        else:
            # Padding grows with the dilation rate.
            conv_padding = (kernel_size // 2) * dilation
            # The first conv maps in_channels -> channels; every following
            # conv maps channels -> channels.
            stage_in_channels = (
                [self.in_channels] + [self.channels] * (num_convs - 1))
            self.convs = nn.Sequential(*[
                ConvModule(
                    in_channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=conv_padding,
                    dilation=dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg)
                for in_channels in stage_in_channels
            ])

        if self.concat_input:
            # Fuses the concatenated [input, conv output] back down to
            # ``channels``. This conv is built without a dilation argument,
            # so its padding is simply ``kernel_size // 2``.
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
        """
        x = self._transform_inputs(inputs)
        feats = self.convs(x)
        if self.concat_input:
            # Concatenate the head input with the conv output along dim 1
            # and fuse the result with ``conv_cat``.
            feats = self.conv_cat(torch.cat([x, feats], dim=1))
        return feats

    def forward(self, inputs):
        """Forward function.

        Computes the head features with ``_forward_feature`` and returns the
        result of passing them through ``self.cls_seg``.
        """
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
|
|