|
|
|
import math |
|
import warnings |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.utils.checkpoint as cp |
|
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer |
|
from mmcv.cnn.bricks.drop import build_dropout |
|
from mmcv.cnn.bricks.transformer import MultiheadAttention |
|
from mmengine.model import BaseModule, ModuleList, Sequential |
|
from mmengine.model.weight_init import (constant_init, normal_init, |
|
trunc_normal_init) |
|
|
|
from mmseg.registry import MODELS |
|
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw |
|
|
|
|
|
class MixFFN(BaseModule):
    """Convolutional feed-forward block (MixFFN) of Segformer.

    Compared with a plain FFN this block:

    1. Uses 1x1 convolutions in place of linear layers.
    2. Inserts a 3x3 depth-wise convolution (``groups`` equal to the hidden
       channel count) to encode positional information.

    Args:
        embed_dims (int): Number of input and output channels of the block.
        feedforward_channels (int): Number of hidden channels.
        act_cfg (dict, optional): Config of the activation layer.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability given to the ``nn.Dropout``
            that is applied after the activation and again after the
            output projection. Default: 0.0.
        dropout_layer (dict, optional): Config of the dropout applied to
            the block output before the shortcut is added. ``nn.Identity``
            is used when this is falsy. Default: None.
        init_cfg (dict, optional): Initialization config, forwarded to
            ``BaseModule``. Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        hidden_channels = feedforward_channels
        pe_kernel = 3

        # 1x1 projection: embed_dims -> hidden channels.
        expand = Conv2d(
            in_channels=embed_dims,
            out_channels=hidden_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        # 3x3 depth-wise convolution; padding keeps the spatial size.
        depthwise = Conv2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=pe_kernel,
            stride=1,
            padding=(pe_kernel - 1) // 2,
            bias=True,
            groups=hidden_channels)
        # 1x1 projection: hidden channels -> embed_dims.
        project = Conv2d(
            in_channels=hidden_channels,
            out_channels=embed_dims,
            kernel_size=1,
            stride=1,
            bias=True)
        # A single dropout module instance is placed at two positions.
        ffn_dropout = nn.Dropout(ffn_drop)

        self.layers = Sequential(expand, depthwise, self.activate,
                                 ffn_dropout, project, ffn_dropout)

        if dropout_layer:
            self.dropout_layer = build_dropout(dropout_layer)
        else:
            self.dropout_layer = torch.nn.Identity()

    def forward(self, x, hw_shape, identity=None):
        """Run the block and add the shortcut.

        Args:
            x: Input tokens; converted with ``nlc_to_nchw(x, hw_shape)``
                before the convolutions and back with ``nchw_to_nlc``.
            hw_shape: Spatial shape handed to ``nlc_to_nchw``.
            identity: Tensor used as the shortcut. ``x`` is used when it
                is None.

        Returns:
            ``identity + dropout_layer(ffn(x))``.
        """
        shortcut = x if identity is None else identity
        feat = nlc_to_nchw(x, hw_shape)
        feat = nchw_to_nlc(self.layers(feat))
        return shortcut + self.dropout_layer(feat)
|
|
|
|
|
class EfficientMultiheadAttention(MultiheadAttention):
    """An implementation of Efficient Multi-head Attention of Segformer.

    This module is modified from MultiheadAttention which is a module from
    mmcv.cnn.bricks.transformer. When ``sr_ratio > 1`` the key/value input
    is spatially reduced by a strided convolution followed by a norm layer
    before attention is computed.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: True.
        qkv_bias (bool): enable bias for qkv if True. Default: False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
            Attention of Segformer. Default: 1.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 init_cfg=None,
                 batch_first=True,
                 qkv_bias=False,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            dropout_layer=dropout_layer,
            init_cfg=init_cfg,
            batch_first=batch_first,
            bias=qkv_bias)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Non-overlapping strided conv: kernel size equals stride.
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # Element [1] of the build_norm_layer result is used as the layer.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        # NOTE(review): imported locally, presumably to avoid a circular
        # import with the mmseg package -- confirm before moving it.
        from mmseg import digit_version, mmcv_version
        if mmcv_version < digit_version('1.3.17'):
            # Adjacent literals are concatenated verbatim, so each piece
            # must carry its own trailing space.
            warnings.warn('The legacy version of forward function in '
                          'EfficientMultiheadAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer support in the '
                          'future. Please upgrade your mmcv.')
            self.forward = self.legacy_forward

    def _reduce_kv(self, x, hw_shape):
        """Return the key/value input, spatially reduced if sr_ratio > 1."""
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            return self.norm(x_kv)
        return x

    def forward(self, x, hw_shape, identity=None):
        """Attend from ``x`` to its (optionally reduced) copy.

        Args:
            x: Input tokens, used as the query and as the source of
                key/value.
            hw_shape: Spatial shape handed to ``nlc_to_nchw`` when the
                key/value input is reduced.
            identity: Tensor used as the shortcut. ``x`` is used when it
                is None.

        Returns:
            ``identity + dropout_layer(proj_drop(attention_output))``.
        """
        x_q = x
        x_kv = self._reduce_kv(x, hw_shape)

        if identity is None:
            identity = x_q

        # NOTE(review): the first two dims are swapped before and after
        # ``self.attn`` when batch_first is set; presumably the wrapped
        # attention module works on (num_query, batch, embed_dims) --
        # confirm against mmcv's MultiheadAttention.
        if self.batch_first:
            x_q = x_q.transpose(0, 1)
            x_kv = x_kv.transpose(0, 1)

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))

    def legacy_forward(self, x, hw_shape, identity=None):
        """multi head attention forward in mmcv version < 1.3.17.

        Takes the same arguments and returns the same expression as
        ``forward``, but applies no batch_first transpose around
        ``self.attn``.
        """
        x_q = x
        x_kv = self._reduce_kv(x, hw_shape)

        if identity is None:
            identity = x_q

        # need_weights=False: only the attention output ([0]) is used here.
        out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]

        return identity + self.dropout_layer(self.proj_drop(out))
|
|
|
|
|
class TransformerEncoderLayer(BaseModule):
    """A single Segformer encoder layer: attention followed by MixFFN.

    Each sub-block is preceded by its own norm layer and receives the
    un-normalized input as the shortcut (``identity``).

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Passed as ``proj_drop`` to the attention block
            and as ``ffn_drop`` to the FFN. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate, used as the
            ``drop_prob`` of the ``DropPath`` config given to both
            sub-blocks. Default: 0.0.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim) or (n, batch, embed_dim).
            Default: True.
        sr_ratio (int): The ratio of spatial reduction of Efficient
            Multi-head Attention of Segformer. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 sr_ratio=1,
                 with_cp=False):
        super().__init__()

        # Element [1] of the build_norm_layer result is used as the layer.
        _, self.norm1 = build_norm_layer(norm_cfg, embed_dims)[:2]

        attn_kwargs = dict(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            batch_first=batch_first,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)
        self.attn = EfficientMultiheadAttention(**attn_kwargs)

        _, self.norm2 = build_norm_layer(norm_cfg, embed_dims)[:2]

        ffn_kwargs = dict(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)
        self.ffn = MixFFN(**ffn_kwargs)

        self.with_cp = with_cp

    def forward(self, x, hw_shape):
        """Apply attention then FFN, optionally under checkpointing.

        Args:
            x: Input tokens.
            hw_shape: Spatial shape forwarded to both sub-blocks.

        Returns:
            The output of the FFN sub-block.
        """

        def _attn_then_ffn(tokens):
            tokens = self.attn(self.norm1(tokens), hw_shape, identity=tokens)
            return self.ffn(self.norm2(tokens), hw_shape, identity=tokens)

        # Checkpointing is only used when enabled and the input needs grad.
        if not (self.with_cp and x.requires_grad):
            return _attn_then_ffn(x)
        return cp.checkpoint(_attn_then_ffn, x)
|
|
|
|
|
@MODELS.register_module()
class MixVisionTransformer(BaseModule):
    """The backbone of Segformer.

    This backbone is the implementation of `SegFormer: Simple and
    Efficient Design for Semantic Segmentation with
    Transformers <https://arxiv.org/abs/2105.15203>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Base embedding dimension. Stage ``i`` uses
            ``embed_dims * num_heads[i]`` channels. Default: 64.
        num_stages (int): The num of stages. Default: 4.
        num_layers (Sequence[int]): The layer number of each transformer encode
            layer. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 4, 8].
        patch_sizes (Sequence[int]): The patch_size of each overlapped patch
            embedding. Default: [7, 3, 3, 3].
        strides (Sequence[int]): The stride of each overlapped patch embedding.
            Default: [4, 2, 2, 2].
        sr_ratios (Sequence[int]): The spatial reduction rate of each
            transformer encode layer. Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Output from which stages. A
            single int is treated as a one-element sequence.
            Default: (0, 1, 2, 3).
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
            Default: 4.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): stochastic depth rate. Default 0.0
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN', eps=1e-6).
        pretrained (str, optional): model pretrained path. Deprecated in
            favour of ``init_cfg``; cannot be combined with it.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.

    Raises:
        AssertionError: If both ``init_cfg`` and ``pretrained`` are set, if
            the per-stage sequences do not all have length ``num_stages``,
            or if an output index is not smaller than ``num_stages``.
        TypeError: If ``pretrained`` is neither a str nor None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=64,
                 num_stages=4,
                 num_layers=[3, 4, 6, 3],
                 num_heads=[1, 2, 4, 8],
                 patch_sizes=[7, 3, 3, 3],
                 strides=[4, 2, 2, 2],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratio=4,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 pretrained=None,
                 init_cfg=None,
                 with_cp=False):
        super().__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.embed_dims = embed_dims
        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        self.with_cp = with_cp
        assert num_stages == len(num_layers) == len(num_heads) \
            == len(patch_sizes) == len(strides) == len(sr_ratios)

        # The docstring allows a bare int; ``max()`` below and the ``in``
        # test in ``forward`` both need a sequence, so wrap it.
        if isinstance(out_indices, int):
            out_indices = (out_indices, )
        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages

        # Per-layer drop-path rates, linearly spaced from 0 to
        # drop_path_rate over all layers of all stages.
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]

        # ``cur`` is the index into ``dpr`` of the current stage's first layer.
        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            embed_dims_i = embed_dims * num_heads[i]
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims_i,
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=patch_sizes[i] // 2,
                norm_cfg=norm_cfg)
            layer = ModuleList([
                TransformerEncoderLayer(
                    embed_dims=embed_dims_i,
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratio * embed_dims_i,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    with_cp=with_cp,
                    sr_ratio=sr_ratios[i]) for idx in range(num_layer)
            ])
            # The next stage consumes this stage's output channels.
            in_channels = embed_dims_i
            # Element [1] of the build_norm_layer result is used as the layer.
            norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
            # Each stage is stored as [patch_embed, blocks, norm].
            self.layers.append(ModuleList([patch_embed, layer, norm]))
            cur += num_layer

    def init_weights(self):
        """Initialize weights.

        Without an ``init_cfg``, every ``nn.Linear`` gets a truncated
        normal init (std 0.02), every ``nn.LayerNorm`` is set to weight 1
        and bias 0, and every ``nn.Conv2d`` gets a normal init with
        ``std = sqrt(2 / fan_out)``. Otherwise initialization is delegated
        to the parent class.
        """
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    # fan_out = kernel area * out_channels, per group.
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        else:
            super().init_weights()

    def forward(self, x):
        """Run all stages and collect the selected stage outputs.

        Args:
            x: Input passed to the first stage's patch embedding.

        Returns:
            list: For each stage index in ``out_indices``, that stage's
            output after its norm layer and ``nlc_to_nchw`` conversion,
            in stage order.
        """
        outs = []

        for i, layer in enumerate(self.layers):
            # layer = [patch_embed, blocks, norm]
            x, hw_shape = layer[0](x)
            for block in layer[1]:
                x = block(x, hw_shape)
            x = layer[2](x)
            x = nlc_to_nchw(x, hw_shape)
            if i in self.out_indices:
                outs.append(x)

        return outs
|
|