Robert001's picture
first commit
b334e29
raw
history blame
No virus
4.47 kB
import torch
from mmcv.runner import force_fp32
from mmdet.models.builder import ROI_EXTRACTORS
from .base_roi_extractor import BaseRoIExtractor
@ROI_EXTRACTORS.register_module()
class SingleRoIExtractor(BaseRoIExtractor):
    """Extract RoI features from a single level feature map.

    If there are multiple input feature levels, each RoI is mapped to a level
    according to its scale. The mapping rule is proposed in
    `FPN <https://arxiv.org/abs/1612.03144>`_.

    Args:
        roi_layer (dict): Specify RoI layer type and arguments.
        out_channels (int): Output channels of RoI layers.
        featmap_strides (List[int]): Strides of input feature maps.
        finest_scale (int): Scale threshold of mapping to level 0. Default: 56.
    """

    def __init__(self,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 finest_scale=56):
        super(SingleRoIExtractor, self).__init__(roi_layer, out_channels,
                                                 featmap_strides)
        self.finest_scale = finest_scale

    def map_roi_levels(self, rois, num_levels):
        """Map rois to corresponding feature levels by scales.

        - scale < finest_scale * 2: level 0
        - finest_scale * 2 <= scale < finest_scale * 4: level 1
        - finest_scale * 4 <= scale < finest_scale * 8: level 2
        - scale >= finest_scale * 8: level 3

        The scale of a RoI is ``sqrt(width * height)``, where columns 1-4 of
        ``rois`` are read as ``(x1, y1, x2, y2)``. Negative widths or heights
        (degenerate boxes) are treated as 0, so such RoIs are assigned to
        level 0 instead of producing a NaN scale.

        Args:
            rois (Tensor): Input RoIs, shape (k, 5).
            num_levels (int): Total level number.

        Returns:
            Tensor: Level index (0-based) of each RoI, shape (k, )
        """
        # Clamp each side at 0: a negative area would make sqrt() return NaN,
        # which survives log2/floor/clamp and is then cast to an integer
        # index with an unspecified value, leaving the RoI matched to no
        # level at all.
        widths = (rois[:, 3] - rois[:, 1]).clamp(min=0)
        heights = (rois[:, 4] - rois[:, 2]).clamp(min=0)
        scale = torch.sqrt(widths * heights)
        # The 1e-6 offset keeps log2 finite for zero-area RoIs.
        target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
        target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
        return target_lvls

    @force_fp32(apply_to=('feats', ), out_fp16=True)
    def forward(self, feats, rois, roi_scale_factor=None):
        """Extract RoI features, taking each RoI from one feature level.

        Args:
            feats (Sequence[Tensor]): Feature maps, one per level; level ``i``
                is processed by ``self.roi_layers[i]``.
            rois (Tensor): Input RoIs, shape (k, 5).
            roi_scale_factor (float, optional): If given, RoIs are rescaled
                with ``self.roi_rescale`` after their levels have been
                chosen. It is not applied when there is only one feature
                level. Default: None.

        Returns:
            Tensor: RoI features. On the multi-level path the shape is
            (k, out_channels, out_h, out_w), where (out_h, out_w) is the
            output size of the first RoI layer; on the single-level path the
            output of ``self.roi_layers[0]`` is returned directly.
        """
        out_size = self.roi_layers[0].output_size
        num_levels = len(feats)
        expand_dims = (-1, self.out_channels * out_size[0] * out_size[1])
        # The export state does not change during a single forward call.
        in_onnx_export = torch.onnx.is_in_onnx_export()

        if in_onnx_export:
            # Work around to export mask-rcnn to onnx: build the zero output
            # buffer from `rois` itself rather than from its Python-side size.
            roi_feats = rois[:, :1].clone().detach()
            roi_feats = roi_feats.expand(*expand_dims)
            roi_feats = roi_feats.reshape(-1, self.out_channels, *out_size)
            roi_feats = roi_feats * 0
        else:
            roi_feats = feats[0].new_zeros(
                rois.size(0), self.out_channels, *out_size)
        # TODO: remove this when parrots supports
        if torch.__version__ == 'parrots':
            roi_feats.requires_grad = True

        if num_levels == 1:
            # NOTE(review): `roi_scale_factor` is ignored on this path
            # because the rescale below is never reached -- confirm intended.
            if len(rois) == 0:
                return roi_feats
            return self.roi_layers[0](feats[0], rois)

        # Levels are chosen from the original RoIs, before any rescaling.
        target_lvls = self.map_roi_levels(rois, num_levels)

        if roi_scale_factor is not None:
            rois = self.roi_rescale(rois, roi_scale_factor)

        for i in range(num_levels):
            mask = target_lvls == i
            if in_onnx_export:
                # To keep all roi_align nodes exported to onnx
                # and skip nonzero op: run every level on all RoIs and zero
                # out the ones that belong to another level.
                mask = mask.float().unsqueeze(-1).expand(*expand_dims).reshape(
                    roi_feats.shape)
                roi_feats_t = self.roi_layers[i](feats[i], rois)
                roi_feats_t *= mask
                roi_feats += roi_feats_t
                continue
            inds = mask.nonzero(as_tuple=False).squeeze(1)
            if inds.numel() > 0:
                rois_ = rois[inds]
                roi_feats_t = self.roi_layers[i](feats[i], rois_)
                roi_feats[inds] = roi_feats_t
            else:
                # Sometimes some pyramid levels will not be used for RoI
                # feature extraction and this will cause an incomplete
                # computation graph in one GPU, which is different from those
                # in other GPUs and will cause a hanging error.
                # Therefore, we add it to ensure each feature pyramid is
                # included in the computation graph to avoid runtime bugs.
                roi_feats += sum(
                    x.view(-1)[0]
                    for x in self.parameters()) * 0. + feats[i].sum() * 0.
        return roi_feats