| | |
| | from typing import Tuple |
| |
|
| | import torch.nn.functional as F |
| | from mmcv.cnn import ConvModule |
| | from mmcv.cnn.bricks import NonLocal2d |
| | from mmengine.model import BaseModule |
| | from torch import Tensor |
| |
|
| | from mmdet.registry import MODELS |
| | from mmdet.utils import OptConfigType, OptMultiConfig |
| |
|
| |
|
@MODELS.register_module()
class BFP(BaseModule):
    """Balanced Feature Pyramids (BFP) neck.

    The module processes a pyramid of feature maps in three stages: every
    level is resized to the spatial size of the ``refine_level`` map and the
    resized maps are averaged; the averaged map is optionally passed through
    a refine layer; finally the result is resized back to each level's own
    spatial size and added to that level's input. It was introduced in
    `Libra R-CNN: Towards Balanced Learning for Object Detection
    <https://arxiv.org/abs/1904.02701>`_ (CVPR 2019).

    Args:
        in_channels (int): Channel count of the refine layer's input and
            output. Every input level should carry this many channels.
        num_levels (int): How many feature levels ``forward`` receives.
        refine_level (int): Index (counted from the bottom of the pyramid)
            of the level whose spatial size is used for integration and
            refinement. Must satisfy ``0 <= refine_level < num_levels``.
            Defaults to 2.
        refine_type (str): Which refine op to apply to the integrated
            feature. One of ``None`` (no refinement), ``'conv'`` or
            ``'non_local'``. Defaults to None.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Convolution layer
            config forwarded to the refine layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Normalization layer
            config forwarded to the refine layer. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or
            dict], optional): Initialization config dict. Defaults to
            uniform Xavier initialization of ``Conv2d`` layers.
    """

    def __init__(
        self,
        in_channels: int,
        num_levels: int,
        refine_level: int = 2,
        refine_type: str = None,
        conv_cfg: OptConfigType = None,
        norm_cfg: OptConfigType = None,
        init_cfg: OptMultiConfig = dict(
            type='Xavier', layer='Conv2d', distribution='uniform')
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        # Validate the configuration before any state is stored.
        assert refine_type in (None, 'conv', 'non_local')
        assert 0 <= refine_level < num_levels

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.refine_level = refine_level
        self.refine_type = refine_type
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # Both refine variants take the same conv/norm configuration.
        layer_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
        if refine_type == 'conv':
            # 3x3 convolution; padding=1 keeps the spatial size unchanged.
            self.refine = ConvModule(
                in_channels, in_channels, 3, padding=1, **layer_cfg)
        elif refine_type == 'non_local':
            self.refine = NonLocal2d(
                in_channels, reduction=1, use_scale=False, **layer_cfg)
        # With refine_type=None no ``refine`` attribute is created; forward()
        # checks ``refine_type`` before using it.

    def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        """Integrate, optionally refine, and redistribute the features.

        Args:
            inputs (tuple[Tensor]): One feature map per level; its length
                must equal ``num_levels``. The spatial size is read from
                dimensions 2 onward of each tensor.

        Returns:
            tuple[Tensor]: One tensor per level, each being the input of
            that level plus the integrated feature resized to the input's
            spatial size.
        """
        assert len(inputs) == self.num_levels

        pivot = self.refine_level
        level_ids = range(self.num_levels)

        # Stage 1: bring every level to the pivot level's spatial size and
        # average. Levels below the pivot go through adaptive max pooling,
        # the pivot and the levels above it through nearest interpolation.
        target_hw = inputs[pivot].size()[2:]
        resized = [
            F.adaptive_max_pool2d(inputs[idx], output_size=target_hw)
            if idx < pivot else
            F.interpolate(inputs[idx], size=target_hw, mode='nearest')
            for idx in level_ids
        ]
        fused = sum(resized) / len(resized)

        # Stage 2: optional refinement of the integrated feature.
        if self.refine_type is not None:
            fused = self.refine(fused)

        # Stage 3: resize the integrated feature back to each level with the
        # mirrored operations and add it to the level's input as a residual.
        scattered = []
        for idx in level_ids:
            level_feat = inputs[idx]
            level_hw = level_feat.size()[2:]
            if idx < pivot:
                delta = F.interpolate(fused, size=level_hw, mode='nearest')
            else:
                delta = F.adaptive_max_pool2d(fused, output_size=level_hw)
            scattered.append(delta + level_feat)

        return tuple(scattered)
| |
|