Robert001's picture
first commit
b334e29
raw
history blame
No virus
3.53 kB
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .aspp_head import ASPPHead, ASPPModule
class DepthwiseSeparableASPPModule(ASPPModule):
    """ASPP module whose dilated branches use depthwise separable convs.

    After the parent ``ASPPModule`` has been constructed, every branch whose
    dilation rate is greater than 1 is replaced in place by a 3x3
    ``DepthwiseSeparableConvModule`` with the same dilation. Branches with a
    dilation of 1 or less keep whatever the parent class built.

    Args:
        **kwargs: Forwarded unchanged to ``ASPPModule.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Settings shared by every replacement branch. Note that ``conv_cfg``
        # is intentionally not part of this set: the original code does not
        # pass it to the separable conv either.
        shared_cfg = dict(norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)

        for branch_idx, rate in enumerate(self.dilations):
            if rate <= 1:
                # Leave the branch created by the parent class untouched.
                continue
            # Index assignment relies on the parent being an indexable module
            # container (presumably an ``nn.ModuleList`` -- confirm in
            # ``aspp_head``). padding == dilation is the usual choice for
            # keeping a 3x3 dilated conv size-preserving.
            self[branch_idx] = DepthwiseSeparableConvModule(
                self.in_channels,
                self.channels,
                3,
                dilation=rate,
                padding=rate,
                **shared_cfg)
@HEADS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of c1 decoder. If it is 0,
            no c1 decoder branch is built and ``c1_channels`` is ignored.
        c1_channels (int): The intermediate channels of c1 decoder.
        **kwargs: Forwarded unchanged to ``ASPPHead.__init__``.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
        assert c1_in_channels >= 0

        # Replace the ASPP branches created by the parent with the depthwise
        # separable variant.
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if c1_in_channels > 0:
            # 1x1 projection applied to ``inputs[0]`` before it is fused with
            # the ASPP output in ``forward``.
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            sep_in_channels = self.channels + c1_channels
        else:
            self.c1_bottleneck = None
            # ``forward`` skips the concatenation when there is no c1 branch,
            # so the fusion convs only ever see the ASPP bottleneck output.
            # Sizing them with ``+ c1_channels`` here would cause a channel
            # mismatch at runtime whenever ``c1_channels`` is non-zero.
            sep_in_channels = self.channels

        # Two stacked 3x3 depthwise separable convs applied after the
        # (optional) fusion with the c1 features.
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                sep_in_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Indexable collection of feature maps. ``inputs[0]`` is
                fed to the c1 branch when that branch exists; the ASPP input
                is whatever ``self._transform_inputs(inputs)`` selects
                (helper inherited, not defined in this class).

        Returns:
            The result of ``self.cls_seg`` applied to the refined features.
        """
        x = self._transform_inputs(inputs)

        # Image-level pooling branch, resized back to the size given by
        # dims 2+ of ``x`` (assumes an (N, C, H, W) layout -- TODO confirm).
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        # All branches are concatenated along dim 1 before the bottleneck.
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)

        if self.c1_bottleneck is not None:
            c1_output = self.c1_bottleneck(inputs[0])
            # Bring the ASPP output to the c1 feature size so the two can be
            # concatenated along dim 1.
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)

        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output