# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, ConvModule, build_activation_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, Sequential
from torch.nn import functional as F

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import resize

from opencd.registry import MODELS
from ..necks.feature_fusion import FeatureFusionNeck


class FDAF(BaseModule):
    """Flow Dual-Alignment Fusion Module.

    Args:
        in_channels (int): Input channels of features.
        conv_cfg (dict | None): Config of conv layers.
            Default: None
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='IN')
        act_cfg (dict): Config of activation layers.
            Default: dict(type='GELU')
    """
    def __init__(self,
                 in_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='IN'),
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        kernel_size = 5
        # Predict two 2-channel flow fields (4 output channels in total) from
        # the concatenated bi-temporal features.
        self.flow_make = Sequential(
            nn.Conv2d(
                in_channels * 2,
                in_channels * 2,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
                bias=True,
                groups=in_channels * 2),
            nn.InstanceNorm2d(in_channels * 2),
            nn.GELU(),
            nn.Conv2d(in_channels * 2, 4, kernel_size=1, padding=0, bias=False),
        )
    def forward(self, x1, x2, fusion_policy=None):
        """Forward function."""
        output = torch.cat([x1, x2], dim=1)
        flow = self.flow_make(output)
        f1, f2 = torch.chunk(flow, 2, dim=1)
        # Warp each temporal feature with its predicted flow, then take the
        # difference against the other temporal feature.
        x1_feat = self.warp(x1, f1) - x2
        x2_feat = self.warp(x2, f2) - x1

        if fusion_policy is None:
            return x1_feat, x2_feat

        output = FeatureFusionNeck.fusion(x1_feat, x2_feat, fusion_policy)
        return output
    def warp(self, x, flow):
        """Warp feature map ``x`` with the per-pixel offsets in ``flow``."""
        n, c, h, w = x.size()

        # Normalize pixel offsets into grid_sample's [-1, 1] coordinate range.
        norm = torch.tensor([[[[w, h]]]]).type_as(x).to(x.device)
        col = torch.linspace(-1.0, 1.0, h).view(-1, 1).repeat(1, w)
        row = torch.linspace(-1.0, 1.0, w).repeat(h, 1)
        grid = torch.cat((row.unsqueeze(2), col.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(x).to(x.device)
        grid = grid + flow.permute(0, 2, 3, 1) / norm
        output = F.grid_sample(x, grid, align_corners=True)
        return output
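
# Usage sketch (illustrative, not part of the original file): FDAF expects two
# temporal feature maps of identical shape. With no fusion policy it returns
# the pair of flow-aligned difference features; with a policy such as 'concat'
# it returns a single fused map via FeatureFusionNeck.fusion.
#
#   fdaf = FDAF(in_channels=64)
#   t1, t2 = torch.rand(2, 64, 32, 32), torch.rand(2, 64, 32, 32)
#   d1, d2 = fdaf(t1, t2)                          # two (2, 64, 32, 32) maps
#   fused = fdaf(t1, t2, fusion_policy='concat')   # one (2, 128, 32, 32) map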


class MixFFN(BaseModule):
    """An implementation of MixFFN of Segformer.

    Here MixFFN is used as the projection head of Changer.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """
    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = Conv2d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        # 3x3 depth-wise conv to provide positional encoding information
        pe_conv = Conv2d(
            in_channels=feedforward_channels,
            out_channels=feedforward_channels,
            kernel_size=3,
            stride=1,
            padding=(3 - 1) // 2,
            bias=True,
            groups=feedforward_channels)
        fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        drop = nn.Dropout(ffn_drop)
        layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
    def forward(self, x, identity=None):
        out = self.layers(x)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
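
# Usage sketch (illustrative, not part of the original file): MixFFN is a
# convolutional FFN with a residual shortcut, applied to NCHW feature maps.
#
#   ffn = MixFFN(embed_dims=128, feedforward_channels=128)
#   feat = torch.rand(2, 128, 64, 64)
#   out = ffn(feat)   # same shape as the input: (2, 128, 64, 64)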


@MODELS.register_module()
class Changer(BaseDecodeHead):
    """The Head of Changer.

    This head is the implementation of
    `Changer <https://arxiv.org/abs/2209.08290>`_.

    Args:
        interpolate_mode: The interpolate mode of MLP head upsample operation.
            Default: 'bilinear'.
    """
    def __init__(self, interpolate_mode='bilinear', **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)
        assert num_inputs == len(self.in_index)

        self.convs = nn.ModuleList()
        for i in range(num_inputs):
            self.convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels // 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

        self.neck_layer = FDAF(in_channels=self.channels // 2)

        # projection head
        self.discriminator = MixFFN(
            embed_dims=self.channels,
            feedforward_channels=self.channels,
            ffn_drop=0.,
            dropout_layer=dict(type='DropPath', drop_prob=0.),
            act_cfg=dict(type='GELU'))
    def base_forward(self, inputs):
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            outs.append(
                resize(
                    input=conv(x),
                    size=inputs[0].shape[2:],
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))

        out = self.fusion_conv(torch.cat(outs, dim=1))

        return out
    def forward(self, inputs):
        # Receive the 4 stage backbone feature maps: 1/4, 1/8, 1/16, 1/32
        inputs = self._transform_inputs(inputs)
        inputs1 = []
        inputs2 = []
        # Each stage carries both temporal phases concatenated along the
        # channel dimension; split them back into two streams.
        for x in inputs:
            f1, f2 = torch.chunk(x, 2, dim=1)
            inputs1.append(f1)
            inputs2.append(f2)

        out1 = self.base_forward(inputs1)
        out2 = self.base_forward(inputs2)
        out = self.neck_layer(out1, out2, 'concat')
        out = self.discriminator(out)
        out = self.cls_seg(out)

        return out
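
# Usage sketch (illustrative; the values below are assumptions, not taken from
# any Open-CD config). Note that in_channels are per-temporal-branch channels,
# since forward() chunks each backbone feature in half along dim=1 before
# calling base_forward.
#
#   head = Changer(
#       in_channels=[64, 128, 256, 512],
#       in_index=[0, 1, 2, 3],
#       channels=128,
#       num_classes=2,
#       norm_cfg=dict(type='BN'),
#       align_corners=False)
#   # Backbone features hold both phases concatenated: 2x in_channels each.
#   feats = [torch.rand(2, 2 * c, 64 // s, 64 // s)
#            for c, s in zip([64, 128, 256, 512], [1, 2, 4, 8])]
#   logits = head(feats)   # change logits at the 1/4-resolution stage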