Robert001's picture
first commit
b334e29
raw
history blame
No virus
4.01 kB
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
from .psp_head import PPM
@HEADS.register_module()
class UPerHead(BaseDecodeHead):
"""Unified Perceptual Parsing for Scene Understanding.
This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module applied on the last feature. Default: (1, 2, 3, 6).
"""
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(UPerHead, self).__init__(
input_transform='multiple_select', **kwargs)
# PSP Module
self.psp_modules = PPM(
pool_scales,
self.in_channels[-1],
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.bottleneck = ConvModule(
self.in_channels[-1] + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# FPN Module
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the top layer
l_conv = ConvModule(
in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
fpn_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = ConvModule(
len(self.in_channels) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def psp_forward(self, inputs):
"""Forward function of PSP module."""
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
return output
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
# build laterals
laterals = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
laterals.append(self.psp_forward(inputs))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += resize(
laterals[i],
size=prev_shape,
mode='bilinear',
align_corners=self.align_corners)
# build outputs
fpn_outs = [
self.fpn_convs[i](laterals[i])
for i in range(used_backbone_levels - 1)
]
# append psp feature
fpn_outs.append(laterals[-1])
for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = resize(
fpn_outs[i],
size=fpn_outs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners)
fpn_outs = torch.cat(fpn_outs, dim=1)
output = self.fpn_bottleneck(fpn_outs)
output = self.cls_seg(output)
return output