# Copyright (c) Open-CD. All rights reserved. import torch import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.models.decode_heads.decode_head import BaseDecodeHead from mmseg.models.utils import resize from opencd.registry import MODELS @MODELS.register_module() class TinyHead(BaseDecodeHead): """ This head is the implementation of `TinyCDv2 `_. Args: feature_strides (tuple[int]): The strides for input feature maps. stack_lateral. All strides suppose to be power of 2. The first one is of largest resolution. priori_attn (bool): Whether use Priori Guiding Connection. Default to False. """ def __init__(self, feature_strides, priori_attn=False, **kwargs): super().__init__(input_transform='multiple_select', **kwargs) assert len(feature_strides) == len(self.in_channels) assert min(feature_strides) == feature_strides[0] if priori_attn: attn_channels = self.in_channels[0] self.in_channels = self.in_channels[1:] feature_strides = feature_strides[1:] self.feature_strides = feature_strides self.priori_attn = priori_attn self.scale_heads = nn.ModuleList() for i in range(len(feature_strides)): scale_head = [] scale_head.append( ConvModule( in_channels=self.in_channels[i], out_channels=self.channels, kernel_size=1, stride=1, groups=1, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)) self.scale_heads.append(nn.Sequential(*scale_head)) if self.priori_attn: self.gen_diff_attn = ConvModule( in_channels=attn_channels // 2, out_channels=self.channels, kernel_size=1, stride=1, groups=1, norm_cfg=None, act_cfg=None ) def forward(self, inputs): x = self._transform_inputs(inputs) if self.priori_attn: early_x = x[0] x = x[1:] output = self.scale_heads[0](x[0]) for i in range(1, len(self.feature_strides)): # non inplace output = output + resize( self.scale_heads[i](x[i]), size=output.shape[2:], mode='bilinear', align_corners=self.align_corners) if self.priori_attn: x1_, x2_ = torch.chunk(early_x, 2, dim=1) diff_x = torch.abs(x1_ - x2_) diff_x = self.gen_diff_attn(diff_x) if diff_x.shape != output.shape: output = resize(output, diff_x.shape[2:], mode='bilinear', align_corners=self.align_corners) output = output * torch.sigmoid(diff_x) + output output = self.cls_seg(output) return output