|
|
|
import numpy as np |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule |
|
|
|
from mmseg.registry import MODELS |
|
from ..utils import Upsample, resize |
|
from .decode_head import BaseDecodeHead |
|
|
|
|
|
@MODELS.register_module()
class FPNHead(BaseDecodeHead):
    """Semantic segmentation head of Panoptic Feature Pyramid Networks.

    Implements the `Semantic FPN <https://arxiv.org/abs/1901.02446>`_ head.
    One "scale head" is built per selected input feature map. Each scale
    head is a stack of 3x3 ``ConvModule`` layers; for every level whose
    stride differs from the first one, each conv is followed by a 2x
    bilinear ``Upsample``. The per-level results are resized to the spatial
    size of the first level's result, summed, and passed to ``cls_seg``.

    Args:
        feature_strides (tuple[int]): The strides of the input feature maps,
            one per entry of ``in_channels``. All strides are expected to be
            powers of 2. The first one must be the smallest stride, i.e. the
            feature map of largest resolution.
        **kwargs: Forwarded unchanged to ``BaseDecodeHead.__init__``
            (``input_transform`` is fixed to ``'multiple_select'``).
    """

    def __init__(self, feature_strides, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        # One stride per input level, and the finest level must come first.
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        base_stride = feature_strides[0]
        self.scale_heads = nn.ModuleList()
        for level, stride in enumerate(feature_strides):
            # Number of conv stages equals the number of 2x steps between
            # this level and the finest one, but never fewer than one.
            num_stages = max(
                1, int(np.log2(stride) - np.log2(base_stride)))
            needs_upsample = stride != base_stride

            layers = []
            for stage in range(num_stages):
                # Only the first stage consumes the backbone channels; later
                # stages operate on the head's own channel width.
                if stage == 0:
                    stage_in_channels = self.in_channels[level]
                else:
                    stage_in_channels = self.channels
                layers.append(
                    ConvModule(
                        stage_in_channels,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if needs_upsample:
                    layers.append(
                        Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
            self.scale_heads.append(nn.Sequential(*layers))

    def forward(self, inputs):
        """Fuse the selected feature maps and predict segmentation logits.

        Args:
            inputs: Multi-level features; passed through
                ``self._transform_inputs`` to obtain one entry per level.

        Returns:
            The result of ``self.cls_seg`` applied to the summed features.
        """
        feats = self._transform_inputs(inputs)

        # The first (finest) level defines the target spatial size.
        fused = self.scale_heads[0](feats[0])
        num_levels = len(self.feature_strides)
        for level in range(1, num_levels):
            level_out = self.scale_heads[level](feats[level])
            # Out-of-place addition: a new tensor is created on every step.
            fused = fused + resize(
                level_out,
                size=fused.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        return self.cls_seg(fused)
|
|