Robert001's picture
first commit
b334e29
raw
history blame
No virus
2.77 kB
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from mmcv import ops
class BaseRoIExtractor(nn.Module, metaclass=ABCMeta):
    """Base class for RoI extractor.

    Builds one RoI layer per input feature map level and provides a helper
    to rescale RoIs around their centers. Subclasses must implement
    :meth:`forward`.

    Args:
        roi_layer (dict): Specify RoI layer type and arguments.
        out_channels (int): Output channels of RoI layers.
        featmap_strides (List[int]): Strides of input feature maps.
    """

    def __init__(self, roi_layer, out_channels, featmap_strides):
        super().__init__()
        # NOTE: the layers are built before the attributes below are set, so
        # an overriding ``build_roi_layers`` must rely on its arguments only.
        self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
        self.out_channels = out_channels
        self.featmap_strides = featmap_strides
        # Not read anywhere in this class; presumably consumed by external
        # mixed-precision tooling -- TODO confirm.
        self.fp16_enabled = False

    @property
    def num_inputs(self):
        """int: Number of input feature maps."""
        return len(self.featmap_strides)

    def init_weights(self):
        """Do nothing; the base class has no weights of its own to set."""
        pass

    def build_roi_layers(self, layer_cfg, featmap_strides):
        """Build RoI operator to extract feature from each level feature map.

        Args:
            layer_cfg (dict): Dictionary to construct and config RoI layer
                operation. The ``type`` entry is either the name of a module
                attribute under ``mmcv/ops`` such as ``RoIAlign``, or the
                layer class (any callable) itself. The remaining entries are
                forwarded to the layer constructor as keyword arguments.
                The dictionary is copied and not modified.
            featmap_strides (List[int]): The stride of input feature map w.r.t
                to the original image size, which would be used to scale RoI
                coordinate (original image coordinate system) to feature
                coordinate system.

        Returns:
            nn.ModuleList: The RoI extractor modules for each level feature
                map.

        Raises:
            KeyError: If ``layer_cfg`` has no ``type`` entry.
            AssertionError: If ``type`` is a string that does not name an
                attribute of ``mmcv.ops``.
        """
        # Shallow copy so popping ``type`` does not mutate the caller's cfg.
        cfg = layer_cfg.copy()
        layer_type = cfg.pop('type')
        if isinstance(layer_type, str):
            # Raised explicitly (instead of a bare ``assert``) so the check
            # survives ``python -O``; the exception type is kept for
            # backward compatibility with existing callers.
            if not hasattr(ops, layer_type):
                raise AssertionError(
                    'RoI layer type {!r} is not defined in mmcv.ops'.format(
                        layer_type))
            layer_cls = getattr(ops, layer_type)
        else:
            # A class/callable was supplied directly; use it as is.
            layer_cls = layer_type
        # One layer per level; ``spatial_scale`` is the inverse of the stride.
        return nn.ModuleList(
            [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])

    def roi_rescale(self, rois, scale_factor):
        """Scale RoI coordinates by scale factor.

        Width and height are multiplied by ``scale_factor`` while the center
        of every RoI is kept fixed. Column 0 is passed through unchanged
        (presumably the batch index -- TODO confirm against callers).

        Args:
            rois (torch.Tensor): RoI (Region of Interest), shape (n, 5);
                columns 1-4 are read as (x1, y1, x2, y2).
            scale_factor (float): Scale factor that RoI will be multiplied by.

        Returns:
            torch.Tensor: Scaled RoI.
        """
        x1, y1 = rois[:, 1], rois[:, 2]
        x2, y2 = rois[:, 3], rois[:, 4]
        center_x = (x1 + x2) * 0.5
        center_y = (y1 + y2) * 0.5
        # Half extents after scaling, applied symmetrically around the center.
        half_w = (x2 - x1) * scale_factor * 0.5
        half_h = (y2 - y1) * scale_factor * 0.5
        return torch.stack(
            (rois[:, 0], center_x - half_w, center_y - half_h,
             center_x + half_w, center_y + half_h),
            dim=-1)

    @abstractmethod
    def forward(self, feats, rois, roi_scale_factor=None):
        """Extract RoI features; must be implemented by subclasses.

        The semantics of the arguments are defined by the concrete
        subclass. ``roi_scale_factor`` defaults to ``None``.
        """
        pass