Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmcv.cnn import ConvModule, xavier_init | |
| from mmcv.cnn.bricks import NonLocal2d | |
| from ..builder import NECKS | |
# NOTE(review): NECKS is imported at file level but is not applied to this
# class in the visible source -- confirm whether a registry decorator was lost.
class BFP(nn.Module):
    """BFP (Balanced Feature Pyramids).

    BFP takes multi-level features as inputs and gathers them into a single
    one, then refines the gathered feature and scatters the refined result
    back to the multi-level features. This module is used in Libra R-CNN
    (CVPR 2019), see the paper `Libra R-CNN: Towards Balanced Learning for
    Object Detection <https://arxiv.org/abs/1904.02701>`_ for details.

    Args:
        in_channels (int): Number of input channels (feature maps of all
            levels should have the same channels).
        num_levels (int): Number of input feature levels.
        refine_level (int): Index of integration and refine level of BSF in
            multi-level features from bottom to top.
        refine_type (str): Type of the refine op, currently support
            [None, 'conv', 'non_local'].
        conv_cfg (dict): The config dict for convolution layers.
        norm_cfg (dict): The config dict for normalization layers.
    """

    def __init__(self,
                 in_channels,
                 num_levels,
                 refine_level=2,
                 refine_type=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(BFP, self).__init__()
        assert refine_type in [None, 'conv', 'non_local'], (
            "refine_type must be one of None, 'conv' or 'non_local', "
            'got {!r}'.format(refine_type))

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.refine_level = refine_level
        self.refine_type = refine_type
        assert 0 <= self.refine_level < self.num_levels, (
            'refine_level must satisfy 0 <= refine_level < num_levels, '
            'got refine_level={!r}, num_levels={!r}'.format(
                refine_level, num_levels))

        # ``self.refine`` is only created when a refine op is requested;
        # ``forward`` guards its use with the same ``refine_type`` check.
        if self.refine_type == 'conv':
            self.refine = ConvModule(
                self.in_channels,
                self.in_channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
        elif self.refine_type == 'non_local':
            self.refine = NonLocal2d(
                self.in_channels,
                reduction=1,
                use_scale=False,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)

    def init_weights(self):
        """Initialize the weights of BFP module.

        Applies ``xavier_init`` with a uniform distribution to every
        ``nn.Conv2d`` found among this module's submodules.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[Tensor]): One feature map per level; its length
                must equal ``num_levels``. The dimensions from index 2
                onwards are used as the spatial size (presumably
                (N, C, H, W) tensors -- confirm against callers).

        Returns:
            tuple[Tensor]: One tensor per level, each being that level's
            input plus the (optionally refined) gathered feature resized to
            that input's spatial size.
        """
        assert len(inputs) == self.num_levels, (
            'expected {} input levels, got {}'.format(
                self.num_levels, len(inputs)))

        # step 1: gather multi-level features by resize and average
        feats = []
        gather_size = inputs[self.refine_level].size()[2:]
        for i in range(self.num_levels):
            if i < self.refine_level:
                # Levels before the refine level are max-pooled to its size
                # (presumably they are the higher-resolution maps).
                gathered = F.adaptive_max_pool2d(
                    inputs[i], output_size=gather_size)
            else:
                gathered = F.interpolate(
                    inputs[i], size=gather_size, mode='nearest')
            feats.append(gathered)

        bsf = sum(feats) / len(feats)

        # step 2: refine gathered features
        if self.refine_type is not None:
            bsf = self.refine(bsf)

        # step 3: scatter refined features to multi-levels by a residual path
        outs = []
        for i in range(self.num_levels):
            out_size = inputs[i].size()[2:]
            if i < self.refine_level:
                # Mirror of step 1: resize back with the opposite operation.
                residual = F.interpolate(bsf, size=out_size, mode='nearest')
            else:
                residual = F.adaptive_max_pool2d(bsf, output_size=out_size)
            outs.append(residual + inputs[i])

        return tuple(outs)