import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule

from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class PPMConcat(nn.ModuleList):
    """Pyramid Pooling Module that only concatenates the features of each
    layer.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 3, 6, 8).
    """

    def __init__(self, pool_scales=(1, 3, 6, 8)):
        super(PPMConcat, self).__init__(
            [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])

    def forward(self, feats):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(feats)
            ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
        concat_outs = torch.cat(ppm_outs, dim=2)
        return concat_outs


class SelfAttentionBlock(_SelfAttentionBlock):
    """Self-attention block used in ANN.

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether share projection weight between key
            and query projection.
        query_scale (int): The scale of query feature map.
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, share_key_query, query_scale, key_pool_scales,
                 conv_cfg, norm_cfg, act_cfg):
        key_psp = PPMConcat(key_pool_scales)
        if query_scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=query_scale)
        else:
            query_downsample = None
        super(SelfAttentionBlock, self).__init__(
            key_in_channels=low_in_channels,
            query_in_channels=high_in_channels,
            channels=channels,
            out_channels=out_channels,
            share_key_query=share_key_query,
            query_downsample=query_downsample,
            key_downsample=key_psp,
            key_query_num_convs=1,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
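
# Shape sketch (illustrative only; the input sizes are assumed): PPMConcat
# flattens each adaptively pooled map to (N, C, scale * scale) and concatenates
# along the last axis, so the default scales (1, 3, 6, 8) always yield
# 1 + 9 + 36 + 64 = 110 key/value positions regardless of input resolution:
#
#   feats = torch.rand(2, 64, 45, 37)
#   PPMConcat((1, 3, 6, 8))(feats).shape  # -> torch.Size([2, 64, 110])
#
# SelfAttentionBlock above passes this module as ``key_downsample``, which is
# what makes the attention asymmetric: queries keep their (optionally
# max-pooled) spatial resolution, while keys and values are reduced to the
# pooled positions.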
""" def __init__(self, low_in_channels, high_in_channels, channels, out_channels, query_scales, key_pool_scales, conv_cfg, norm_cfg, act_cfg): super(AFNB, self).__init__() self.stages = nn.ModuleList() for query_scale in query_scales: self.stages.append( SelfAttentionBlock( low_in_channels=low_in_channels, high_in_channels=high_in_channels, channels=channels, out_channels=out_channels, share_key_query=False, query_scale=query_scale, key_pool_scales=key_pool_scales, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) self.bottleneck = ConvModule( out_channels + high_in_channels, out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None) def forward(self, low_feats, high_feats): """Forward function.""" priors = [stage(high_feats, low_feats) for stage in self.stages] context = torch.stack(priors, dim=0).sum(dim=0) output = self.bottleneck(torch.cat([context, high_feats], 1)) return output class APNB(nn.Module): """Asymmetric Pyramid Non-local Block (APNB) Args: in_channels (int): Input channels of key/query feature, which is the key feature for self-attention. channels (int): Output channels of key/query transform. out_channels (int): Output channels. query_scales (tuple[int]): The scales of query feature map. Default: (1,) key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid Module of key feature. conv_cfg (dict|None): Config of conv layers. norm_cfg (dict|None): Config of norm layers. act_cfg (dict|None): Config of activation layers. """ def __init__(self, in_channels, channels, out_channels, query_scales, key_pool_scales, conv_cfg, norm_cfg, act_cfg): super(APNB, self).__init__() self.stages = nn.ModuleList() for query_scale in query_scales: self.stages.append( SelfAttentionBlock( low_in_channels=in_channels, high_in_channels=in_channels, channels=channels, out_channels=out_channels, share_key_query=True, query_scale=query_scale, key_pool_scales=key_pool_scales, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) self.bottleneck = ConvModule( 2 * in_channels, out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) def forward(self, feats): """Forward function.""" priors = [stage(feats, feats) for stage in self.stages] context = torch.stack(priors, dim=0).sum(dim=0) output = self.bottleneck(torch.cat([context, feats], 1)) return output @HEADS.register_module() class ANNHead(BaseDecodeHead): """Asymmetric Non-local Neural Networks for Semantic Segmentation. This head is the implementation of `ANNNet `_. Args: project_channels (int): Projection channels for Nonlocal. query_scales (tuple[int]): The scales of query feature map. Default: (1,) key_pool_scales (tuple[int]): The pooling scales of key feature map. Default: (1, 3, 6, 8). 
""" def __init__(self, project_channels, query_scales=(1, ), key_pool_scales=(1, 3, 6, 8), **kwargs): super(ANNHead, self).__init__( input_transform='multiple_select', **kwargs) assert len(self.in_channels) == 2 low_in_channels, high_in_channels = self.in_channels self.project_channels = project_channels self.fusion = AFNB( low_in_channels=low_in_channels, high_in_channels=high_in_channels, out_channels=high_in_channels, channels=project_channels, query_scales=query_scales, key_pool_scales=key_pool_scales, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.bottleneck = ConvModule( high_in_channels, self.channels, 3, padding=1, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) self.context = APNB( in_channels=self.channels, out_channels=self.channels, channels=project_channels, query_scales=query_scales, key_pool_scales=key_pool_scales, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) def forward(self, inputs): """Forward function.""" low_feats, high_feats = self._transform_inputs(inputs) output = self.fusion(low_feats, high_feats) output = self.dropout(output) output = self.bottleneck(output) output = self.context(output) output = self.cls_seg(output) return output