# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, ConvModule, build_activation_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, Sequential
from torch.nn import functional as F

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import resize
from opencd.registry import MODELS
from ..necks.feature_fusion import FeatureFusionNeck


class FDAF(BaseModule):
    """Flow Dual-Alignment Fusion Module.

    The two inputs are concatenated along the channel axis and passed
    through ``flow_make`` (5x5 depth-wise conv -> InstanceNorm -> GELU ->
    1x1 conv), which outputs a 4-channel map. That map is split into two
    2-channel offset fields; each input is resampled with its own field
    and the *other* input is subtracted from the result.

    Args:
        in_channels (int): Input channels of features (per input).
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='IN')
        act_cfg (dict): Config of activation layers.
            Default: dict(type='GELU')

    Note:
        ``conv_cfg``, ``norm_cfg`` and ``act_cfg`` are only stored on the
        instance. The layers of ``flow_make`` are hard-coded and do not
        read them.
    """

    def __init__(self,
                 in_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='IN'),
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        # TODO: build ``flow_make`` from conv_cfg / norm_cfg / act_cfg
        # instead of the hard-coded Conv2d / InstanceNorm2d / GELU below.
        kernel_size = 5
        stacked_channels = in_channels * 2
        # NOTE: the layer order fixes the state-dict keys (flow_make.0,
        # flow_make.3); do not reorder.
        self.flow_make = Sequential(
            # Depth-wise conv over the concatenated pair of inputs.
            nn.Conv2d(
                stacked_channels,
                stacked_channels,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
                bias=True,
                groups=stacked_channels),
            nn.InstanceNorm2d(stacked_channels),
            nn.GELU(),
            # 4 output channels = two 2-channel offset fields.
            nn.Conv2d(
                stacked_channels, 4, kernel_size=1, padding=0, bias=False),
        )

    def forward(self, x1, x2, fusion_policy=None):
        """Forward function.

        Args:
            x1 (Tensor): First feature map, 4-D.
            x2 (Tensor): Second feature map, concatenated with ``x1``
                along dim 1, so the other dims must match.
            fusion_policy (str | None): Passed to
                ``FeatureFusionNeck.fusion``. When None the two
                difference maps are returned unfused. Default: None.

        Returns:
            tuple[Tensor, Tensor] | Tensor: ``(x1_feat, x2_feat)`` when
            ``fusion_policy`` is None, otherwise the result of
            ``FeatureFusionNeck.fusion(x1_feat, x2_feat, fusion_policy)``.
        """
        flow = self.flow_make(torch.cat([x1, x2], dim=1))
        f1, f2 = torch.chunk(flow, 2, dim=1)
        # Each input is warped by its own offsets, then the other input
        # is subtracted.
        x1_feat = self.warp(x1, f1) - x2
        x2_feat = self.warp(x2, f2) - x1

        if fusion_policy is None:
            return x1_feat, x2_feat

        return FeatureFusionNeck.fusion(x1_feat, x2_feat, fusion_policy)

    @staticmethod
    def warp(x, flow):
        """Resample ``x`` at a regular grid displaced by ``flow``.

        Args:
            x (Tensor): Input of shape (n, c, h, w).
            flow (Tensor): Offsets with 2 channels on dim 1. Channel 0 is
                divided by ``w`` and channel 1 by ``h`` before being
                added to the [-1, 1] sampling grid.

        Returns:
            Tensor: Output of ``F.grid_sample(x, grid,
            align_corners=True)``.
        """
        n, _, h, w = x.size()
        # Create the base grid directly on x's device so no host-to-device
        # copy is needed on every call.
        xs = torch.linspace(-1.0, 1.0, w, device=x.device)
        ys = torch.linspace(-1.0, 1.0, h, device=x.device)
        # (h, w, 2) with the last dim ordered (x, y), as grid_sample
        # expects.
        base_grid = torch.stack(
            (xs.view(1, w).expand(h, w), ys.view(h, 1).expand(h, w)),
            dim=2).to(dtype=x.dtype)
        # expand() yields an (n, h, w, 2) view without copying the grid
        # n times.
        base_grid = base_grid.unsqueeze(0).expand(n, -1, -1, -1)

        # Per-axis divisor for the offsets; broadcasts over (n, h, w, 2).
        norm = x.new_tensor([w, h])
        grid = base_grid + flow.permute(0, 2, 3, 1) / norm
        return F.grid_sample(x, grid, align_corners=True)


class MixFFN(BaseModule):
    """An implementation of MixFFN of Segformer.

    Here MixFFN is used as projection head of Changer.

    The layer stack is: 1x1 conv -> 3x3 depth-wise conv -> activation ->
    dropout -> 1x1 conv -> dropout, followed by a residual connection.

    Args:
        embed_dims (int): The feature dimension (input and output
            channels).
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. When falsy, an identity is used.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = Conv2d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        # 3x3 depth wise conv to provide positional encode information
        pe_conv = Conv2d(
            in_channels=feedforward_channels,
            out_channels=feedforward_channels,
            kernel_size=3,
            stride=1,
            padding=(3 - 1) // 2,
            bias=True,
            groups=feedforward_channels)
        fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        # One Dropout instance is used at both positions in the stack.
        drop = nn.Dropout(ffn_drop)
        # NOTE: the layer order fixes the state-dict keys (layers.0,
        # layers.1, layers.4); do not reorder.
        self.layers = Sequential(fc1, pe_conv, self.activate, drop, fc2,
                                 drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()

    def forward(self, x, identity=None):
        """Apply the FFN and add the shortcut.

        Args:
            x (Tensor): Input feature map.
            identity (Tensor | None): Tensor used for the shortcut. When
                None, ``x`` itself is used. Default: None.

        Returns:
            Tensor: ``identity + dropout_layer(layers(x))``.
        """
        out = self.layers(x)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
@MODELS.register_module()
class Changer(BaseDecodeHead):
    """The Head of Changer.

    This head is the decode head of the Changer change-detection model.
    Every selected input feature map is split in half along the channel
    axis. Both halves go through the same per-level 1x1 convs and fusion
    conv, are combined by :class:`FDAF` with the ``'concat'`` policy,
    refined by :class:`MixFFN` and classified with ``cls_seg``.

    Args:
        interpolate_mode: The interpolate mode of MLP head upsample
            operation. Default: 'bilinear'.
        **kwargs: Forwarded to ``BaseDecodeHead``.
            ``input_transform='multiple_select'`` is always set by this
            class.
    """

    def __init__(self, interpolate_mode='bilinear', **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)
        assert num_inputs == len(self.in_index)

        # One 1x1 projection per selected input level.
        self.convs = nn.ModuleList([
            ConvModule(
                in_channels=level_channels,
                out_channels=self.channels,
                kernel_size=1,
                stride=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg) for level_channels in self.in_channels
        ])

        # Halves the width so that concatenating the two branches after
        # FDAF yields ``self.channels`` again. ``act_cfg`` is not passed
        # here, so ConvModule's own default applies.
        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels // 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

        self.neck_layer = FDAF(in_channels=self.channels // 2)

        # projection head
        self.discriminator = MixFFN(
            embed_dims=self.channels,
            feedforward_channels=self.channels,
            ffn_drop=0.,
            dropout_layer=dict(type='DropPath', drop_prob=0.),
            act_cfg=dict(type='GELU'))

    def base_forward(self, inputs):
        """Project, resize and fuse the feature maps of one branch.

        Args:
            inputs (list[Tensor]): One feature map per selected level.
                ``inputs[idx]`` is fed to ``self.convs[idx]``.

        Returns:
            Tensor: Output of ``fusion_conv`` applied to the channel-wise
            concatenation of all levels, each resized to the spatial size
            of ``inputs[0]``.
        """
        target_size = inputs[0].shape[2:]
        outs = [
            resize(
                input=self.convs[idx](feat),
                size=target_size,
                mode=self.interpolate_mode,
                align_corners=self.align_corners)
            for idx, feat in enumerate(inputs)
        ]
        return self.fusion_conv(torch.cat(outs, dim=1))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): Backbone feature maps. After
                ``_transform_inputs`` each one is split into two equal
                chunks along dim 1, one per branch.

        Returns:
            Tensor: Output of ``cls_seg``.
        """
        # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
        inputs = self._transform_inputs(inputs)
        inputs1 = []
        inputs2 = []
        for feat in inputs:
            f1, f2 = torch.chunk(feat, 2, dim=1)
            inputs1.append(f1)
            inputs2.append(f2)

        # Both branches share the same convs (same weights).
        out1 = self.base_forward(inputs1)
        out2 = self.base_forward(inputs2)
        out = self.neck_layer(out1, out2, 'concat')

        out = self.discriminator(out)
        out = self.cls_seg(out)

        return out