|
|
|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule |
|
|
|
from mmseg.registry import MODELS |
|
from ..utils import resize |
|
from .decode_head import BaseDecodeHead |
|
from .psp_head import PPM |
|
|
|
|
|
@MODELS.register_module()
class UPerHead(BaseDecodeHead):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    Data flow:
    - The last selected input feature goes through a Pooling Pyramid Module
      (``self.psp_modules``) and a 3x3 bottleneck conv.
    - Every other selected input gets a 1x1 lateral conv.
    - The levels are merged top-down by resize-and-add.
    - Each non-top level is refined by a 3x3 conv.
    - All levels are resized to the spatial size of the first level,
      concatenated, and fused by ``self.fpn_bottleneck``.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
        **kwargs: Forwarded to ``BaseDecodeHead``. This head always passes
            ``input_transform='multiple_select'`` itself, so that keyword
            must not appear in ``kwargs``; passing it raises a duplicate
            keyword ``TypeError``.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        # Conv / norm / activation settings shared by every sub-module below.
        cfg = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        top_channels = self.in_channels[-1]

        # NOTE: keep the sub-module construction order unchanged. If these
        # modules initialize weights randomly (not visible here), reordering
        # them would change the initial parameter values.

        # PSP branch applied to the last selected feature.
        self.psp_modules = PPM(
            pool_scales,
            top_channels,
            self.channels,
            align_corners=self.align_corners,
            **cfg)
        # Its input is the raw top feature concatenated with one
        # ``self.channels``-wide map per pooling scale (see psp_forward).
        self.bottleneck = ConvModule(
            top_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            **cfg)

        # One 1x1 lateral conv and one 3x3 refinement conv per non-top level,
        # created in interleaved order (lateral, then refine) for each level.
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for level_channels in self.in_channels[:-1]:
            lateral = ConvModule(
                level_channels,
                self.channels,
                1,
                inplace=False,
                **cfg)
            refine = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                inplace=False,
                **cfg)
            self.lateral_convs.append(lateral)
            self.fpn_convs.append(refine)

        # Fuses the channel-wise concatenation of all levels
        # (one ``self.channels``-wide map per selected input).
        self.fpn_bottleneck = ConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            3,
            padding=1,
            **cfg)

    def psp_forward(self, inputs):
        """Forward function of PSP module.

        Concatenates the last input with every PPM branch output along
        ``dim=1`` and fuses the result with ``self.bottleneck``.
        """
        top = inputs[-1]
        pyramid = torch.cat([top, *self.psp_modules(top)], dim=1)
        return self.bottleneck(pyramid)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is feature map for last layer of decoder head.
                H, W are the spatial size of the first fused level.
        """
        feats_in = self._transform_inputs(inputs)

        # Lateral projections for every level except the last. The PSP output
        # then becomes the top level.
        laterals = [
            conv(feats_in[idx]) for idx, conv in enumerate(self.lateral_convs)
        ]
        laterals.append(self.psp_forward(feats_in))
        num_levels = len(laterals)

        # Top-down pathway: starting from the last level, resize each level
        # to the one before it and add it in.
        for idx in reversed(range(1, num_levels)):
            below = laterals[idx - 1]
            laterals[idx - 1] = below + resize(
                laterals[idx],
                size=below.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        # Refine every non-top level with its 3x3 conv. The top (PSP) level is
        # used as-is. fpn_convs and laterals[:-1] have equal length by
        # construction in __init__.
        outs = [
            refine(lat)
            for refine, lat in zip(self.fpn_convs, laterals[:-1])
        ]
        outs.append(laterals[-1])

        # Bring every level to the spatial size of the first one, then fuse.
        target_size = outs[0].shape[2:]
        for idx in reversed(range(1, num_levels)):
            outs[idx] = resize(
                outs[idx],
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
        return self.fpn_bottleneck(torch.cat(outs, dim=1))

    def forward(self, inputs):
        """Forward function: decoder features followed by ``self.cls_seg``."""
        return self.cls_seg(self._forward_feature(inputs))
|
|