Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule | |
from mmengine.utils import is_tuple_of | |
from mmseg.registry import MODELS | |
from ..utils import resize | |
from .decode_head import BaseDecodeHead | |
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    The head consumes ``len(branch_channels) + 1`` feature maps selected by
    ``_transform_inputs``. The last selected map goes through the LR-ASPP
    block: a 1x1 conv branch multiplied by a sigmoid image-pooling gate. The
    result is then repeatedly upsampled to, concatenated with, and fused
    with each earlier map, starting from ``inputs[len(branch_channels) - 1]``
    down to ``inputs[0]``, before ``cls_seg`` is applied.

    NOTE(review): ``MODELS`` is imported in this module but this class is not
    decorated with ``@MODELS.register_module()``; if the head is meant to be
    built from configs it presumably needs that decorator -- TODO confirm.

    Args:
        branch_channels (tuple[int]): The number of output channels in each
            skip branch; its length must equal ``len(in_channels) - 1``.
            Default: (32, 64).
        **kwargs: Forwarded to ``BaseDecodeHead``. ``input_transform`` must
            be ``'multiple_select'``.

    Raises:
        ValueError: If ``input_transform`` is not ``'multiple_select'``.
        AssertionError: If ``branch_channels`` is not a tuple of ints or its
            length does not match ``len(in_channels) - 1``.
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super().__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int), (
            'branch_channels must be a tuple of ints, '
            f'but got {branch_channels!r}')
        assert len(branch_channels) == len(self.in_channels) - 1, (
            'LRASPPHead expects len(branch_channels) == '
            f'len(in_channels) - 1, but got {len(branch_channels)} '
            f'branch channels for {len(self.in_channels)} inputs')
        self.branch_channels = branch_channels

        # One 1x1 projection (convs) and one fusion conv (conv_ups) per skip
        # feature map. nn.Sequential is only used as an indexable container
        # (it is never called as a whole); the 'conv{i}' / 'conv_up{i}'
        # module names determine the parameter keys in the state dict, so
        # they must stay stable for existing checkpoints to load.
        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i, branch_ch in enumerate(branch_channels):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(self.in_channels[i], branch_ch, 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    # Input is the upsampled head features concatenated
                    # (along dim 1 in forward) with the projected skip.
                    self.channels + branch_ch,
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        # Plain 1x1 conv (with bias) applied to the gated ASPP output before
        # the first fusion step.
        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        # Image-pooling branch that produces a sigmoid gate for aspp_conv.
        # It is applied to the same input map as aspp_conv, so its input
        # channel count must also be in_channels[-1] (a fixed index such as
        # in_channels[2] is only correct when exactly three inputs are used).
        # NOTE: the pooling window/stride are fixed; nn.AvgPool2d (no padding)
        # raises if the input map is smaller than 49 in either spatial dim.
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                self.in_channels[-1],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Backbone feature maps. After ``_transform_inputs`` they
                are indexed as a sequence of tensors; assumed to be
                (N, C, H, W) since they are bilinearly resized on dims 2+
                and concatenated on dim 1 -- confirm against the backbone.

        Returns:
            Tensor: Output of ``self.cls_seg`` on the fused features.
        """
        inputs = self._transform_inputs(inputs)
        x = inputs[-1]
        # LR-ASPP: 1x1 conv features gated by the image-pooling branch, which
        # is resized back to the spatial size of the (unchanged) input map.
        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)
        # Fuse with the skip maps from the last branch down to branch 0.
        for i in reversed(range(len(self.branch_channels))):
            skip = inputs[i]
            x = resize(
                x,
                size=skip.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](skip)], 1)
            x = self.conv_ups[i](x)
        return self.cls_seg(x)