import math

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead


class ReassembleBlocks(BaseModule):
    """ViTPostProcessBlock: process the cls_token in the ViT backbone output
    and rearrange the patch-token vectors into feature maps.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): Output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation, one of 'ignore',
            'add' or 'project'. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels=768,
                 out_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 init_cfg=None):
        super().__init__(init_cfg)

        assert readout_type in ['ignore', 'add', 'project']
        self.readout_type = readout_type
        self.patch_size = patch_size

        # 1x1 convs that project the ViT features to per-stage channel widths.
        self.projects = nn.ModuleList([
            ConvModule(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                act_cfg=None,
            ) for out_channel in out_channels
        ])

        # Rescale the four stages into a feature pyramid: 4x and 2x
        # upsampling, identity, and 2x downsampling, respectively.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])
        if self.readout_type == 'project':
            # One linear readout projection per stage, fusing the cls_token
            # into every patch token.
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        Linear(2 * in_channels, in_channels),
                        build_activation_layer(dict(type='GELU'))))

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == 'project':
                # Concatenate the cls_token to every patch token and project
                # back to the original channel width.
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == 'add':
                # Broadcast-add the cls_token to every patch token.
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                # 'ignore': drop the cls_token.
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out
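
# A minimal shape sketch of ReassembleBlocks, kept as a comment so the module
# stays import-clean. The batch size, image size and ViT-B/16 dimensions are
# assumptions for illustration, not requirements of this module:
#
#   blocks = ReassembleBlocks(in_channels=768,
#                             out_channels=[96, 192, 384, 768])
#   # Four ViT stages on a 512x512 input: each entry is a (patch feature
#   # map, cls_token) pair.
#   feats = [(torch.randn(1, 768, 32, 32), torch.randn(1, 768))
#            for _ in range(4)]
#   outs = blocks(feats)
#   # Output shapes form a pyramid at strides 4, 8, 16 and 32:
#   # (1, 96, 128, 128), (1, 192, 64, 64), (1, 384, 32, 32), (1, 768, 16, 16)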


class PreActResidualConvUnit(BaseModule):
    """ResidualConvUnit, a pre-activation residual unit.

    Args:
        in_channels (int): Number of channels in the input feature map.
        act_cfg (dict): Dictionary to construct and config the activation
            layer.
        norm_cfg (dict): Dictionary to construct and config the norm layer.
        stride (int): Stride of the first conv. Default: 1.
        dilation (int): Dilation rate of the conv layers. Default: 1.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 stride=1,
                 dilation=1,
                 init_cfg=None):
        super().__init__(init_cfg)

        # order=('act', 'conv', 'norm') gives the pre-activation ordering:
        # the activation runs before each convolution.
        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))

        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))

    def forward(self, inputs):
        # Keep a copy of the input for the residual connection.
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_
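
# A minimal sketch (assumed shapes): with the default stride the unit is
# shape-preserving, so it can be stacked freely inside the fusion blocks.
#
#   unit = PreActResidualConvUnit(
#       in_channels=256, act_cfg=dict(type='ReLU'), norm_cfg=dict(type='BN'))
#   y = unit(torch.randn(2, 256, 64, 64))  # -> torch.Size([2, 256, 64, 64])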


class FeatureFusionBlock(BaseModule):
    """FeatureFusionBlock, merge feature maps from different stages.

    Args:
        in_channels (int): Input channels.
        act_cfg (dict): The activation config for ResidualConvUnit.
        norm_cfg (dict): Config dict for the normalization layer.
        expand (bool): Whether to halve the channels (in_channels // 2) in
            the output projection. Default: False.
        align_corners (bool): align_corners setting for bilinear upsample.
            Default: True.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 expand=False,
                 align_corners=True,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(
            self.in_channels,
            self.out_channels,
            kernel_size=1,
            act_cfg=None,
            bias=True)

        # res_conv_unit1 refines the skip feature; res_conv_unit2 refines
        # the fused result.
        self.res_conv_unit1 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
        self.res_conv_unit2 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            # Align the skip feature to the running path before adding it.
            if x.shape != inputs[1].shape:
                res = resize(
                    inputs[1],
                    size=(x.shape[2], x.shape[3]),
                    mode='bilinear',
                    align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.project(x)
        return x
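
# A minimal sketch of one fusion step (assumed shapes): the block adds a skip
# feature to the running decoder path, refines the sum and upsamples by 2x.
#
#   fuse = FeatureFusionBlock(
#       in_channels=256, act_cfg=dict(type='ReLU'), norm_cfg=dict(type='BN'))
#   out = fuse(torch.randn(1, 256, 16, 16),   # running decoder path
#              torch.randn(1, 256, 32, 32))   # skip feature from the encoder
#   # out.shape == torch.Size([1, 256, 32, 32]) after the 2x upsample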


@MODELS.register_module()
class DPTHead(BaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is the implementation of `DPT
    <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of the post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether to expand the channels in the post
            process block. Default: False.
        act_cfg (dict): The activation config for the residual conv unit.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for the normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self,
                 embed_dims=768,
                 post_process_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 expand_channels=False,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super().__init__(**kwargs)

        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims,
                                                  post_process_channels,
                                                  readout_type, patch_size)

        # Cast to int: math.pow returns a float, but conv channel counts
        # must be ints.
        self.post_process_channels = [
            int(channel * math.pow(2, i)) if expand_channels else channel
            for i, channel in enumerate(post_process_channels)
        ]
        # 3x3 convs mapping each reassembled stage to the head's width.
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(
                ConvModule(
                    channel,
                    self.channels,
                    kernel_size=3,
                    padding=1,
                    act_cfg=None,
                    bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(
                FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
        # The first (deepest) fusion block receives no skip feature, so its
        # first residual unit is unused.
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(
            self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels

    def forward(self, inputs):
        assert len(inputs) == self.num_reassemble_blocks
        x = self._transform_inputs(inputs)
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        # Fuse from the deepest stage outwards, adding one skip per step.
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.cls_seg(out)
        return out
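
# A minimal config-style sketch of the head. The dataset-specific settings
# (num_classes, channels) below are assumptions for illustration; see the
# DPT configs shipped with mmsegmentation for tested values:
#
#   head = DPTHead(
#       embed_dims=768,
#       post_process_channels=[96, 192, 384, 768],
#       readout_type='project',
#       in_channels=(768, 768, 768, 768),
#       channels=256,
#       num_classes=150,
#       in_index=(0, 1, 2, 3),
#       input_transform='multiple_select')
#   # `inputs` holds one (feature map, cls_token) pair per ViT stage, as in
#   # the ReassembleBlocks sketch above; the head returns per-pixel logits.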