# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_upsample_layer, xavier_init
from mmcv.ops.carafe import CARAFEPack
from mmcv.runner import BaseModule, ModuleList

from ..builder import NECKS


@NECKS.register_module()
class FPN_CARAFE(BaseModule):
    """FPN_CARAFE is a more flexible implementation of FPN. It allows more
    choice for upsample methods during the top-down pathway.

    It can reproduce the performance of ICCV 2019 paper
    CARAFE: Content-Aware ReAssembly of FEatures
    Please refer to https://arxiv.org/abs/1905.02188 for more details.

    Args:
        in_channels (list[int]): Number of channels for each input feature
            map.
        out_channels (int): Output channels of feature pyramids.
        num_outs (int): Number of output stages.
        start_level (int): Start level of feature pyramids.
            (Default: 0)
        end_level (int): End level of feature pyramids.
            (Default: -1 indicates the last level).
        norm_cfg (dict): Dictionary to construct and config norm layer.
            (Default: None). When it is None the convolutions use a bias.
        act_cfg (dict): Config of the activation in ConvModule.
            (Default: None indicates w/o activation).
        order (tuple[str]): Order of components in ConvModule. Must be
            ``('conv', 'norm', 'act')`` or ``('act', 'conv', 'norm')``.
        upsample_cfg (dict): Dictionary to construct and config upsample
            layer. Its ``type`` must be one of ``'nearest'``,
            ``'bilinear'``, ``'deconv'``, ``'pixel_shuffle'``,
            ``'carafe'`` or None. ``'deconv'`` and ``'pixel_shuffle'``
            additionally require a positive ``upsample_kernel`` entry.
            With type None no upsample layer is built and the top-down
            pathway adds the higher level feature without resizing it.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Must be None. Default: None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 norm_cfg=None,
                 act_cfg=None,
                 order=('conv', 'norm', 'act'),
                 upsample_cfg=dict(
                     type='carafe',
                     up_kernel=5,
                     up_group=1,
                     encoder_kernel=3,
                     encoder_dilation=1),
                 init_cfg=None):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super(FPN_CARAFE, self).__init__(init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.with_bias = norm_cfg is None
        # Work on a copy: ``upsample_kernel`` is popped below and the
        # default argument is a dict shared between all instances.
        self.upsample_cfg = upsample_cfg.copy()
        self.upsample = self.upsample_cfg.get('type')
        self.relu = nn.ReLU(inplace=False)

        self.order = order
        assert order in [('conv', 'norm', 'act'), ('act', 'conv', 'norm')]

        assert self.upsample in [
            'nearest', 'bilinear', 'deconv', 'pixel_shuffle', 'carafe', None
        ]
        if self.upsample in ['deconv', 'pixel_shuffle']:
            # Key lookup instead of attribute access so that a plain
            # ``dict`` config works as well as an attribute-style one.
            assert ('upsample_kernel' in self.upsample_cfg
                    and self.upsample_cfg['upsample_kernel'] > 0)
            self.upsample_kernel = self.upsample_cfg.pop('upsample_kernel')

        if end_level == -1 or end_level == self.num_ins - 1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level

        self.lateral_convs = ModuleList()
        self.fpn_convs = ModuleList()
        self.upsample_modules = ModuleList()

        # Keyword arguments shared by every ConvModule of the neck.
        conv_kwargs = dict(
            norm_cfg=norm_cfg,
            bias=self.with_bias,
            act_cfg=act_cfg,
            inplace=False,
            order=self.order)

        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(in_channels[i], out_channels, 1,
                                **conv_kwargs)
            fpn_conv = ConvModule(
                out_channels, out_channels, 3, padding=1, **conv_kwargs)
            # The top backbone level has nothing above it to upsample
            # from; with type None ``forward`` never upsamples at all.
            is_top_level = i == self.backbone_end_level - 1
            if not is_top_level and self.upsample is not None:
                self.upsample_modules.append(
                    self._build_upsample_module(out_channels,
                                                self.upsample_cfg))
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # add extra conv layers (e.g., RetinaNet)
        extra_out_levels = (
            num_outs - self.backbone_end_level + self.start_level)
        for i in range(extra_out_levels):
            # The first extra level reads the last backbone input, the
            # following ones read the previous extra lateral.
            extra_in_channels = (
                self.in_channels[self.backbone_end_level - 1]
                if i == 0 else out_channels)
            extra_l_conv = ConvModule(
                extra_in_channels,
                out_channels,
                3,
                stride=2,
                padding=1,
                **conv_kwargs)
            if self.upsample is not None:
                # Only CARAFE forwards the remaining user options to the
                # layers of the extra levels; the other types start from
                # a config that holds nothing but the type.
                if self.upsample == 'carafe':
                    base_cfg = self.upsample_cfg
                else:
                    base_cfg = dict(type=self.upsample)
                self.upsample_modules.append(
                    self._build_upsample_module(out_channels, base_cfg))
            extra_fpn_conv = ConvModule(
                out_channels, out_channels, 3, padding=1, **conv_kwargs)
            self.fpn_convs.append(extra_fpn_conv)
            self.lateral_convs.append(extra_l_conv)

    def _build_upsample_module(self, out_channels, base_cfg):
        """Build one 2x upsample layer of type ``self.upsample``.

        Args:
            out_channels (int): Number of channels of the feature map
                that will be upsampled.
            base_cfg (dict): Config to start from. It must contain the
                ``type`` key. It is copied, not modified.

        Returns:
            nn.Module: The layer returned by ``build_upsample_layer``.
        """
        cfg = base_cfg.copy()
        if self.upsample == 'deconv':
            cfg.update(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=self.upsample_kernel,
                stride=2,
                padding=(self.upsample_kernel - 1) // 2,
                output_padding=(self.upsample_kernel - 1) // 2)
        elif self.upsample == 'pixel_shuffle':
            cfg.update(
                in_channels=out_channels,
                out_channels=out_channels,
                scale_factor=2,
                upsample_kernel=self.upsample_kernel)
        elif self.upsample == 'carafe':
            cfg.update(channels=out_channels, scale_factor=2)
        else:
            # suppress warnings
            align_corners = (None if self.upsample == 'nearest' else False)
            cfg.update(
                scale_factor=2,
                mode=self.upsample,
                align_corners=align_corners)
        return build_upsample_layer(cfg)

    # default init_weights for conv(msra) and norm in ConvModule
    def init_weights(self):
        """Initialize the weights of module."""
        super(FPN_CARAFE, self).init_weights()
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                xavier_init(m, distribution='uniform')
        # Keep this as a second pass: it must run after the loop above so
        # that CARAFEPack's own initialization is the one that remains
        # for any convolution it owns.
        for m in self.modules():
            if isinstance(m, CARAFEPack):
                m.init_weights()

    def slice_as(self, src, dst):
        """Slice ``src`` as ``dst``

        Note:
            ``src`` should have the same or larger size than ``dst``.

        Args:
            src (torch.Tensor): Tensors to be sliced.
            dst (torch.Tensor): ``src`` will be sliced to have the same
                size as ``dst``.

        Returns:
            torch.Tensor: Sliced tensor.
        """
        assert (src.size(2) >= dst.size(2)) and (src.size(3) >= dst.size(3))
        if src.size(2) == dst.size(2) and src.size(3) == dst.size(3):
            return src
        return src[:, :, :dst.size(2), :dst.size(3)]

    def tensor_add(self, a, b):
        """Add tensors ``a`` and ``b`` that might have different sizes.

        When the sizes differ, ``b`` is cropped to the size of ``a`` with
        :meth:`slice_as` before the addition.
        """
        if a.size() == b.size():
            return a + b
        return a + self.slice_as(b, a)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[torch.Tensor]): One feature map per entry of
                ``in_channels``.

        Returns:
            tuple[torch.Tensor]: One output per entry of ``fpn_convs``.
        """
        assert len(inputs) == len(self.in_channels)

        # build laterals
        num_backbone_levels = self.backbone_end_level - self.start_level
        last_input_idx = len(inputs) - 1
        laterals = []
        for i, lateral_conv in enumerate(self.lateral_convs):
            if i <= num_backbone_levels:
                # ``i == num_backbone_levels`` is the first extra level;
                # the clamp makes it read the last backbone input.
                lateral_input = inputs[min(i + self.start_level,
                                           last_input_idx)]
            else:
                lateral_input = laterals[-1]
            laterals.append(lateral_conv(lateral_input))

        # build top-down path
        for i in range(len(laterals) - 1, 0, -1):
            if self.upsample is not None:
                upsample_feat = self.upsample_modules[i - 1](laterals[i])
            else:
                upsample_feat = laterals[i]
            laterals[i - 1] = self.tensor_add(laterals[i - 1], upsample_feat)

        # build outputs
        outs = [
            fpn_conv(lateral)
            for fpn_conv, lateral in zip(self.fpn_convs, laterals)
        ]
        return tuple(outs)