KyanChen's picture
Upload 1861 files
3b96cb1
raw
history blame
10.3 kB
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
class ReassembleBlocks(BaseModule):
    """ViTPostProcessBlock, process cls_token in ViT backbone output and
    rearrange the feature vector to feature map.

    Every stage ``i`` is handled by its own branch:

    1. an optional readout step merges the cls token into the patch
       features (``readout_type``),
    2. a 1x1 ``ConvModule`` projects ``in_channels`` to
       ``out_channels[i]``,
    3. a fixed resize layer rescales the map: stage 0 is upsampled 4x
       (transposed conv), stage 1 is upsampled 2x (transposed conv),
       stage 2 is left unchanged and stage 3 is downsampled 2x (stride-2
       conv).

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (Sequence[int]): output channels of each stage. At
            least four entries are needed, because the resize layers index
            ``out_channels[0]`` to ``out_channels[3]``.
            Default: (96, 192, 384, 768).
        readout_type (str): Type of readout operation, one of ``'ignore'``
            (cls token dropped), ``'add'`` (cls token added to every patch
            feature) or ``'project'`` (cls token concatenated to every patch
            feature, then Linear + GELU). Default: 'ignore'.
        patch_size (int): The patch size. It is stored on the module but
            not used by this class. Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels=768,
                 out_channels=(96, 192, 384, 768),
                 readout_type='ignore',
                 patch_size=16,
                 init_cfg=None):
        super().__init__(init_cfg)

        assert readout_type in ['ignore', 'add', 'project']
        self.readout_type = readout_type
        self.patch_size = patch_size

        # One 1x1 projection per stage: in_channels -> out_channels[i].
        self.projects = nn.ModuleList([
            ConvModule(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                act_cfg=None,
            ) for out_channel in out_channels
        ])

        # Fixed per-stage rescaling: 4x up, 2x up, unchanged, 2x down.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])

        if self.readout_type == 'project':
            # Per stage: map the concatenation [patch feature, cls token]
            # (2 * in_channels) back to in_channels.
            self.readout_projects = nn.ModuleList([
                nn.Sequential(
                    Linear(2 * in_channels, in_channels),
                    build_activation_layer(dict(type='GELU')))
                for _ in range(len(self.projects))
            ])

    def forward(self, inputs):
        """Turn every ``(feature, cls_token)`` pair into a resized map.

        Args:
            inputs (list): One entry per stage. Each entry must have length
                2: ``entry[0]`` is the patch feature map (it feeds 2D convs,
                so presumably (B, C, H, W)). ``entry[1]`` is the cls token
                (presumably (B, C), judging by the broadcasting below; TODO
                confirm).

        Returns:
            list[Tensor]: Stage ``i`` projected to ``out_channels[i]``
            channels and rescaled by ``resize_layers[i]``.
        """
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == 'project':
                # Flatten the spatial dims into a token axis: (B, N, C).
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == 'add':
                # Broadcast the cls token over every spatial position.
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            # readout_type == 'ignore': the cls token is simply dropped.
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out
class PreActResidualConvUnit(BaseModule):
    """ResidualConvUnit, pre-activate residual unit.

    Two 3x3 ``ConvModule`` s are applied one after the other, both with
    ``order=('act', 'conv', 'norm')`` and without conv bias. The block input
    is then added back to their output. The channel count stays at
    ``in_channels`` throughout.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_cfg (dict): dictionary to construct and config activation layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        stride (int): stride of the first block. Default: 1
        dilation (int): dilation rate for the first conv; its padding is
            set equal to it. Default: 1.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 stride=1,
                 dilation=1,
                 init_cfg=None):
        super().__init__(init_cfg)
        # Settings common to both convolutions.
        common = dict(
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))
        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            **common)
        self.conv2 = ConvModule(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            padding=1,
            **common)

    def forward(self, inputs):
        """Run ``conv1`` then ``conv2`` and add the (cloned) input back.

        NOTE(review): the clone presumably protects the residual from a
        leading activation that runs in place (e.g. ReLU with inplace=True)
        before ``conv1``. Confirm against the ConvModule config before
        removing it.
        """
        identity = inputs.clone()
        residual = self.conv2(self.conv1(inputs))
        return residual + identity
class FeatureFusionBlock(BaseModule):
    """FeatureFusionBlock, merge feature map from different stages.

    The running feature map can optionally be merged with a skip feature
    (via ``res_conv_unit1``). It then passes through ``res_conv_unit2``, is
    upsampled 2x bilinearly and is projected by a 1x1 conv.

    Args:
        in_channels (int): Input channels.
        act_cfg (dict): The activation config for ResidualConvUnit.
        norm_cfg (dict): Config dict for normalization layer.
        expand (bool): Whether expand the channels in post process block.
            When True, the output projection maps to ``in_channels // 2``
            channels. Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 expand=False,
                 align_corners=True,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners
        self.out_channels = in_channels // 2 if expand else in_channels

        self.project = ConvModule(
            self.in_channels,
            self.out_channels,
            kernel_size=1,
            act_cfg=None,
            bias=True)
        unit_cfg = dict(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
        self.res_conv_unit1 = PreActResidualConvUnit(**unit_cfg)
        self.res_conv_unit2 = PreActResidualConvUnit(**unit_cfg)

    def forward(self, *inputs):
        """Fuse an optional skip feature into ``inputs[0]`` and upsample.

        Args:
            *inputs (Tensor): One or two feature maps. ``inputs[0]`` is the
                running feature. If exactly two maps are given, the second
                is bilinearly resized to the spatial size of the first
                (only when their shapes differ), passed through
                ``res_conv_unit1`` and added to the first.

        Returns:
            Tensor: The map after ``res_conv_unit2``, 2x bilinear upsampling
            and the 1x1 output projection.
        """
        x = inputs[0]
        if len(inputs) == 2:
            skip = inputs[1]
            if skip.shape != x.shape:
                height, width = x.shape[2], x.shape[3]
                skip = resize(
                    skip,
                    size=(height, width),
                    mode='bilinear',
                    align_corners=False)
            x = x + self.res_conv_unit1(skip)
        x = self.res_conv_unit2(x)
        upsampled = resize(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        return self.project(upsampled)
@MODELS.register_module()
class DPTHead(BaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.

    Pipeline:

    1. the ViT features are reassembled into per-stage maps by
       ``ReassembleBlocks``,
    2. each map is projected to ``self.channels`` by a 3x3 conv,
    3. the maps are fused by a chain of ``FeatureFusionBlock`` s, starting
       from the last map and walking back to the first,
    4. a final 3x3 ``ConvModule`` is applied, followed by ``cls_seg``.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (Sequence[int]): Out channels of post process
            conv layers. Default: (96, 192, 384, 768).
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether expand the channels in post process
            block. When True, entry ``i`` of ``post_process_channels`` is
            multiplied by ``2 ** i``. The multiplied value is used both as
            the reassemble output channels and as the conv input channels.
            Default: False.
        act_cfg (dict): The activation config for residual conv unit.
            Default dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        **kwargs: Forwarded to ``BaseDecodeHead`` (this is where e.g.
            ``channels`` comes from).
    """

    def __init__(self,
                 embed_dims=768,
                 post_process_channels=(96, 192, 384, 768),
                 readout_type='ignore',
                 patch_size=16,
                 expand_channels=False,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super().__init__(**kwargs)
        self.expand_channels = expand_channels
        # Integer arithmetic on purpose: conv layers need int channel
        # counts, and ``math.pow`` would produce floats here.
        self.post_process_channels = [
            channel * 2**i if expand_channels else channel
            for i, channel in enumerate(post_process_channels)
        ]
        # The reassemble blocks emit the (possibly expanded) channels so
        # that their outputs match the input channels of ``self.convs``.
        # With expand_channels=False this list equals the argument.
        self.reassemble_blocks = ReassembleBlocks(
            in_channels=embed_dims,
            out_channels=self.post_process_channels,
            readout_type=readout_type,
            patch_size=patch_size)

        self.convs = nn.ModuleList([
            ConvModule(
                channel,
                self.channels,
                kernel_size=3,
                padding=1,
                act_cfg=None,
                bias=False) for channel in self.post_process_channels
        ])
        self.fusion_blocks = nn.ModuleList([
            FeatureFusionBlock(self.channels, act_cfg, norm_cfg)
            for _ in range(len(self.convs))
        ])
        # ``forward`` calls the first fusion block with a single input, so
        # its skip-connection unit is never used; drop it.
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(
            self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg)

        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list): Multi-level backbone outputs; its length must
                equal the number of reassemble blocks. After
                ``_transform_inputs`` each entry is consumed by
                ``ReassembleBlocks`` as a ``(feature, cls_token)`` pair.

        Returns:
            Tensor: Output of ``cls_seg``.
        """
        assert len(inputs) == self.num_reassemble_blocks
        x = self._transform_inputs(inputs)
        x = self.reassemble_blocks(x)
        x = [conv(feature) for conv, feature in zip(self.convs, x)]
        # Start from the last map and fuse towards the first one.
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, self.num_fusion_blocks):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.cls_seg(out)
        return out