# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, ConvModule, build_activation_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, Sequential
from torch.nn import functional as F
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import resize
from opencd.registry import MODELS
from ..necks.feature_fusion import FeatureFusionNeck


class FDAF(BaseModule):
    """Flow Dual-Alignment Fusion Module.

    Args:
        in_channels (int): Input channels of features.
        conv_cfg (dict | None): Config of conv layers.
            Default: None
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='IN')
        act_cfg (dict): Config of activation layers.
            Default: dict(type='GELU')
    """

    def __init__(self,
in_channels,
conv_cfg=None,
norm_cfg=dict(type='IN'),
act_cfg=dict(type='GELU')):
super(FDAF, self).__init__()
self.in_channels = in_channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
        # NOTE: `conv_cfg`, `norm_cfg` and `act_cfg` are stored but the flow
        # estimator below is currently hard-wired to InstanceNorm2d + GELU.
        kernel_size = 5
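        # A depthwise 5x5 conv gathers local spatial context per channel and a
        # 1x1 conv projects to 4 channels: two (dx, dy) offset fields, one for
        # warping each temporal branch toward the other.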
        self.flow_make = Sequential(
            nn.Conv2d(
                in_channels * 2,
                in_channels * 2,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
                bias=True,
                groups=in_channels * 2),
            nn.InstanceNorm2d(in_channels * 2),
            nn.GELU(),
            nn.Conv2d(in_channels * 2, 4, kernel_size=1, padding=0, bias=False))

    def forward(self, x1, x2, fusion_policy=None):
"""Forward function."""
output = torch.cat([x1, x2], dim=1)
flow = self.flow_make(output)
f1, f2 = torch.chunk(flow, 2, dim=1)
x1_feat = self.warp(x1, f1) - x2
x2_feat = self.warp(x2, f2) - x1
        if fusion_policy is None:
return x1_feat, x2_feat
output = FeatureFusionNeck.fusion(x1_feat, x2_feat, fusion_policy)
return output

    @staticmethod
    def warp(x, flow):
        """Warp ``x`` with the offset field ``flow`` via bilinear sampling.

        ``flow`` holds per-pixel (dx, dy) offsets; they are divided by the
        spatial size and added to a base grid in [-1, 1] for
        ``F.grid_sample``.
        """
        n, c, h, w = x.size()
        norm = torch.tensor([[[[w, h]]]]).type_as(x).to(x.device)
        # Base sampling grid with normalized coordinates in [-1, 1].
        col = torch.linspace(-1.0, 1.0, h).view(-1, 1).repeat(1, w)
        row = torch.linspace(-1.0, 1.0, w).repeat(h, 1)
        grid = torch.cat((row.unsqueeze(2), col.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(x).to(x.device)
        # Offset the base grid by the normalized flow, then resample.
        grid = grid + flow.permute(0, 2, 3, 1) / norm
        output = F.grid_sample(x, grid, align_corners=True)
        return output
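
# A minimal usage sketch of FDAF (illustrative shapes; not from any Open-CD
# config): align two bi-temporal feature maps and optionally fuse them.
# >>> fdaf = FDAF(in_channels=64)
# >>> t1, t2 = torch.rand(2, 64, 32, 32), torch.rand(2, 64, 32, 32)
# >>> d1, d2 = fdaf(t1, t2)            # two difference maps, (2, 64, 32, 32) each
# >>> fused = fdaf(t1, t2, 'concat')   # channel-concatenated, (2, 128, 32, 32)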


class MixFFN(BaseModule):
    """An implementation of the MixFFN of Segformer.

    Here MixFFN is used as the projection head of Changer.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
embed_dims,
feedforward_channels,
act_cfg=dict(type='GELU'),
ffn_drop=0.,
dropout_layer=None,
init_cfg=None):
super(MixFFN, self).__init__(init_cfg)
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
in_channels = embed_dims
fc1 = Conv2d(
in_channels=in_channels,
out_channels=feedforward_channels,
kernel_size=1,
stride=1,
bias=True)
        # 3x3 depthwise conv to provide positional encoding information
pe_conv = Conv2d(
in_channels=feedforward_channels,
out_channels=feedforward_channels,
kernel_size=3,
stride=1,
padding=(3 - 1) // 2,
bias=True,
groups=feedforward_channels)
fc2 = Conv2d(
in_channels=feedforward_channels,
out_channels=in_channels,
kernel_size=1,
stride=1,
bias=True)
drop = nn.Dropout(ffn_drop)
layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
self.layers = Sequential(*layers)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()

    def forward(self, x, identity=None):
out = self.layers(x)
if identity is None:
identity = x
return identity + self.dropout_layer(out)
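
# A minimal sketch of MixFFN as a convolutional FFN (illustrative shapes and
# assumed hyper-parameters): it maps (N, C, H, W) -> (N, C, H, W) with a
# residual shortcut.
# >>> ffn = MixFFN(embed_dims=64, feedforward_channels=256)
# >>> x = torch.rand(2, 64, 32, 32)
# >>> y = ffn(x)  # same shape as x: identity + FFN(x)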


@MODELS.register_module()
class Changer(BaseDecodeHead):
    """The Head of Changer.

    This head is the implementation of
    `Changer <https://arxiv.org/abs/2209.08290>`_.

    Args:
        interpolate_mode: The interpolate mode of MLP head upsample operation.
            Default: 'bilinear'.
    """

    def __init__(self, interpolate_mode='bilinear', **kwargs):
super().__init__(input_transform='multiple_select', **kwargs)
self.interpolate_mode = interpolate_mode
num_inputs = len(self.in_channels)
assert num_inputs == len(self.in_index)
self.convs = nn.ModuleList()
for i in range(num_inputs):
self.convs.append(
ConvModule(
in_channels=self.in_channels[i],
out_channels=self.channels,
kernel_size=1,
stride=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.fusion_conv = ConvModule(
in_channels=self.channels * num_inputs,
out_channels=self.channels // 2,
kernel_size=1,
norm_cfg=self.norm_cfg)
self.neck_layer = FDAF(in_channels=self.channels // 2)
# projection head
self.discriminator = MixFFN(
embed_dims=self.channels,
feedforward_channels=self.channels,
ffn_drop=0.,
dropout_layer=dict(type='DropPath', drop_prob=0.),
act_cfg=dict(type='GELU'))
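        # Pipeline: per-scale 1x1 convs -> fuse to channels//2 per phase ->
        # FDAF alignment + 'concat' (back to `channels`) -> MixFFN -> cls_seg.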

    def base_forward(self, inputs):
outs = []
for idx in range(len(inputs)):
x = inputs[idx]
conv = self.convs[idx]
outs.append(
resize(
input=conv(x),
size=inputs[0].shape[2:],
mode=self.interpolate_mode,
align_corners=self.align_corners))
out = self.fusion_conv(torch.cat(outs, dim=1))
return out

    def forward(self, inputs):
        # Receive 4-stage backbone feature maps: 1/4, 1/8, 1/16, 1/32.
        inputs = self._transform_inputs(inputs)
        inputs1 = []
        inputs2 = []
        for feat in inputs:
            # Each feature map stacks the two temporal phases along the
            # channel axis; split it back into per-phase features.
            f1, f2 = torch.chunk(feat, 2, dim=1)
            inputs1.append(f1)
            inputs2.append(f2)
out1 = self.base_forward(inputs1)
out2 = self.base_forward(inputs2)
out = self.neck_layer(out1, out2, 'concat')
out = self.discriminator(out)
out = self.cls_seg(out)
return out
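
# A hedged construction sketch for the head (channel sizes below are
# assumptions for illustration, not values from a released Open-CD config).
# Each entry of `in_channels` is the per-phase channel count; the incoming
# feature maps are expected to carry both phases concatenated, i.e. twice
# these widths.
# >>> head = Changer(
# ...     interpolate_mode='bilinear',
# ...     in_channels=[64, 128, 256, 512],
# ...     in_index=[0, 1, 2, 3],
# ...     channels=128,
# ...     num_classes=2)
# >>> feats = [torch.rand(2, 2 * c, 256 // s, 256 // s)
# ...          for c, s in zip([64, 128, 256, 512], [4, 8, 16, 32])]
# >>> logits = head(feats)  # (2, 2, 64, 64), 1/4 of the assumed 256x256 input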