|
|
|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule, build_upsample_layer |
|
|
|
|
|
class UpConvBlock(nn.Module):
    """Upsample convolution block in decoder for UNet.

    This upsample convolution block consists of one upsample module
    followed by one convolution block. The upsample module expands the
    high-level low-resolution feature map and the convolution block fuses
    the upsampled high-level low-resolution feature map and the low-level
    high-resolution feature map from encoder.

    Args:
        conv_block (type): Convolution block class (not an instance). It is
            instantiated here with the keyword arguments ``in_channels``,
            ``out_channels``, ``num_convs``, ``stride``, ``dilation``,
            ``with_cp``, ``conv_cfg``, ``norm_cfg``, ``act_cfg``, ``dcn``
            and ``plugins``.
        in_channels (int): Number of input channels of the high-level
            low-resolution feature map (``x`` in :meth:`forward`).
        skip_channels (int): Number of input channels of the low-level
            high-resolution feature map from encoder (``skip`` in
            :meth:`forward`).
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block. Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Forwarded to
            conv_block and to the upsample module built from upsample_cfg.
            Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict | None): The upsample config of the upsample module
            in decoder. Default: dict(type='InterpConv'). Set it to None when
            the high-level feature map already has the same spatial size as
            the skip feature map (low-level feature map from encoder). In
            that case no upsample module is built; a 1x1 ConvModule is used
            instead to project ``in_channels`` to ``skip_channels``.
        dcn (dict | None): Deformable convolution config. Not implemented
            yet, so it must be None. Default: None.
        plugins (dict | None): Plugins for convolutional layers. Not
            implemented yet, so it must be None. Default: None.
    """

    # NOTE: the dict defaults below are shared between calls. This class only
    # passes them through to its builders and never mutates them itself.
    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        super().__init__()
        # Kept as asserts (not another exception type) so callers that expect
        # AssertionError for these unsupported options keep working.
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        # The upsample branch below outputs `skip_channels` maps, which
        # forward() concatenates with the `skip_channels` skip map; hence the
        # fused block sees 2 * skip_channels input channels.
        self.conv_block = conv_block(
            in_channels=2 * skip_channels,
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dcn=None,
            plugins=None)
        if upsample_cfg is not None:
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            # No resizing needed: only match the channel count of `skip`.
            self.upsample = ConvModule(
                in_channels,
                skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, skip, x):
        """Forward function.

        Args:
            skip (Tensor): Low-level high-resolution feature map from the
                encoder, with ``skip_channels`` channels.
            x (Tensor): High-level low-resolution feature map, with
                ``in_channels`` channels.

        Returns:
            Tensor: Output of ``conv_block`` applied to the concatenation
            of ``skip`` and ``upsample(x)`` along dim 1.
        """
        x = self.upsample(x)
        # dim 1 is presumably the channel axis (NCHW) — TODO confirm for
        # non-2D conv_cfg. torch.cat requires every other dim of `skip` and
        # the upsampled `x` to match.
        out = torch.cat([skip, x], dim=1)
        out = self.conv_block(out)

        return out
|
|