# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
import copy
import warnings

import torch
import torch.nn as nn
from mmcv import ConfigDict
from mmcv.ops import DeformConv2d, batched_nms
from mmcv.runner import BaseModule, ModuleList

from mmdet.core import (RegionAssigner, build_assigner, build_sampler,
                        images_to_levels, multi_apply)
from mmdet.core.utils import select_single_mlvl
from ..builder import HEADS, build_head
from .base_dense_head import BaseDenseHead
from .rpn_head import RPNHead


class AdaptiveConv(BaseModule):
    """AdaptiveConv used to adapt the sampling location with the anchors.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the conv kernel. Default: 3.
        stride (int or tuple, optional): Stride of the convolution.
            Default: 1.
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 1.
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 3.
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1.
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: False.
        type (str, optional): Type of adaptive conv, can be either 'offset'
            (arbitrary anchors) or 'dilation' (uniform anchor).
            Default: 'dilation'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=3,
                 groups=1,
                 bias=False,
                 type='dilation',
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv'))):
        super(AdaptiveConv, self).__init__(init_cfg)
        assert type in ['offset', 'dilation']
        self.adapt_type = type
        assert kernel_size == 3, 'Adaptive conv only supports kernel_size 3'
        if self.adapt_type == 'offset':
            assert stride == 1 and padding == 1 and groups == 1, \
                'Adaptive conv offset mode only supports stride 1, ' \
                'padding 1 and groups 1'
            self.conv = DeformConv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                stride=stride,
                groups=groups,
                bias=bias)
        else:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=dilation,
                dilation=dilation)

    def forward(self, x, offset):
        """Forward function."""
        if self.adapt_type == 'offset':
            N, _, H, W = x.shape
            assert offset is not None
            assert H * W == offset.shape[1]
            # reshape [N, NA, 18] to (N, 18, H, W)
            offset = offset.permute(0, 2, 1).reshape(N, -1, H, W)
            offset = offset.contiguous()
            x = self.conv(x, offset)
        else:
            assert offset is None
            x = self.conv(x)
        return x
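
# A minimal usage sketch of the two AdaptiveConv modes (illustrative only;
# the channel and spatial sizes are assumptions, the offset shape follows
# the assertions in forward()):
#
#   conv = AdaptiveConv(256, 256, type='dilation')
#   out = conv(torch.rand(2, 256, 32, 32), offset=None)
#
#   conv = AdaptiveConv(256, 256, type='offset')
#   # offset: (N, H*W, 18), one (y, x) pair per point of the 3x3 kernel
#   out = conv(torch.rand(2, 256, 32, 32), torch.zeros(2, 32 * 32, 18))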


@HEADS.register_module()
class StageCascadeRPNHead(RPNHead):
    """Stage of CascadeRPNHead.

    Args:
        in_channels (int): Number of channels in the input feature map.
        anchor_generator (dict): anchor generator config.
        adapt_cfg (dict): adaptation config.
        bridged_feature (bool, optional): whether to update rpn feature.
            Default: False.
        with_cls (bool, optional): whether to use classification branch.
            Default: True.
        sampling (bool, optional): whether to use sampling. Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     scales=[8],
                     ratios=[1.0],
                     strides=[4, 8, 16, 32, 64]),
                 adapt_cfg=dict(type='dilation', dilation=3),
                 bridged_feature=False,
                 with_cls=True,
                 sampling=True,
                 init_cfg=None,
                 **kwargs):
        self.with_cls = with_cls
        self.anchor_strides = anchor_generator['strides']
        self.anchor_scales = anchor_generator['scales']
        self.bridged_feature = bridged_feature
        self.adapt_cfg = adapt_cfg
        super(StageCascadeRPNHead, self).__init__(
            in_channels,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

        # override sampling and sampler
        self.sampling = sampling
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        if init_cfg is None:
            self.init_cfg = dict(
                type='Normal', std=0.01, override=[dict(name='rpn_reg')])
            if self.with_cls:
                self.init_cfg['override'].append(dict(name='rpn_cls'))

    def _init_layers(self):
        """Init layers of a CascadeRPN stage."""
        self.rpn_conv = AdaptiveConv(self.in_channels, self.feat_channels,
                                     **self.adapt_cfg)
        if self.with_cls:
            self.rpn_cls = nn.Conv2d(self.feat_channels,
                                     self.num_anchors * self.cls_out_channels,
                                     1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward_single(self, x, offset):
        """Forward function of single scale."""
        bridged_x = x
        x = self.relu(self.rpn_conv(x, offset))
        if self.bridged_feature:
            bridged_x = x  # update feature
        cls_score = self.rpn_cls(x) if self.with_cls else None
        bbox_pred = self.rpn_reg(x)
        return bridged_x, cls_score, bbox_pred

    def forward(self, feats, offset_list=None):
        """Forward function."""
        if offset_list is None:
            offset_list = [None for _ in range(len(feats))]
        return multi_apply(self.forward_single, feats, offset_list)
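
    # Unlike a plain RPNHead, forward() also returns the (possibly bridged)
    # features, so stages can be chained, e.g. (illustrative):
    #   feats, cls_scores, bbox_preds = stage(feats, offset_list)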

    def _region_targets_single(self,
                               anchors,
                               valid_flags,
                               gt_bboxes,
                               gt_bboxes_ignore,
                               gt_labels,
                               img_meta,
                               featmap_sizes,
                               label_channels=1):
        """Get anchor targets based on region for single level."""
        assign_result = self.assigner.assign(
            anchors,
            valid_flags,
            gt_bboxes,
            img_meta,
            featmap_sizes,
            self.anchor_scales[0],
            self.anchor_strides,
            gt_bboxes_ignore=gt_bboxes_ignore,
            gt_labels=None,
            allowed_border=self.train_cfg.allowed_border)
        flat_anchors = torch.cat(anchors)
        sampling_result = self.sampler.sample(assign_result, flat_anchors,
                                              gt_bboxes)

        num_anchors = flat_anchors.shape[0]
        bbox_targets = torch.zeros_like(flat_anchors)
        bbox_weights = torch.zeros_like(flat_anchors)
        labels = flat_anchors.new_zeros(num_anchors, dtype=torch.long)
        label_weights = flat_anchors.new_zeros(num_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            else:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)

    def region_targets(self,
                       anchor_list,
                       valid_flag_list,
                       gt_bboxes_list,
                       img_metas,
                       featmap_sizes,
                       gt_bboxes_ignore_list=None,
                       gt_labels_list=None,
                       label_channels=1,
                       unmap_outputs=True):
        """See :func:`StageCascadeRPNHead.get_targets`."""
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list) = multi_apply(
             self._region_targets_single,
             anchor_list,
             valid_flag_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             img_metas,
             featmap_sizes=featmap_sizes,
             label_channels=label_channels)
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes,
                    img_metas,
                    featmap_sizes,
                    gt_bboxes_ignore=None,
                    label_channels=1):
        """Compute regression and classification targets for anchors.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            gt_bboxes (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            featmap_sizes (list[Tensor]): Feature map size of each level.
            gt_bboxes_ignore (list[Tensor]): Ignore bboxes of each image.
            label_channels (int): Channel of label.

        Returns:
            cls_reg_targets (tuple)
        """
        if isinstance(self.assigner, RegionAssigner):
            cls_reg_targets = self.region_targets(
                anchor_list,
                valid_flag_list,
                gt_bboxes,
                img_metas,
                featmap_sizes,
                gt_bboxes_ignore_list=gt_bboxes_ignore,
                label_channels=label_channels)
        else:
            cls_reg_targets = super(StageCascadeRPNHead, self).get_targets(
                anchor_list,
                valid_flag_list,
                gt_bboxes,
                img_metas,
                gt_bboxes_ignore_list=gt_bboxes_ignore,
                label_channels=label_channels)
        return cls_reg_targets

    def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes):
        """Get offset for deformable conv based on anchor shape.

        NOTE: currently supports deformable kernel_size=3 and dilation=1.

        Args:
            anchor_list (list[list[tensor]]): [NI, NLVL, NA, 4] list of
                multi-level anchors
            anchor_strides (list[int]): anchor stride of each level

        Returns:
            offset_list (list[tensor]): [NLVL, NA, 2, 18]: offset of
                DeformConv kernel.
        """

        def _shape_offset(anchors, stride, ks=3, dilation=1):
            # currently support kernel_size=3 and dilation=1
            assert ks == 3 and dilation == 1
            pad = (ks - 1) // 2
            idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
            yy, xx = torch.meshgrid(idx, idx)  # return order matters
            xx = xx.reshape(-1)
            yy = yy.reshape(-1)
            w = (anchors[:, 2] - anchors[:, 0]) / stride
            h = (anchors[:, 3] - anchors[:, 1]) / stride
            w = w / (ks - 1) - dilation
            h = h / (ks - 1) - dilation
            offset_x = w[:, None] * xx  # (NA, ks**2)
            offset_y = h[:, None] * yy  # (NA, ks**2)
            return offset_x, offset_y

        def _ctr_offset(anchors, stride, featmap_size):
            feat_h, feat_w = featmap_size
            assert len(anchors) == feat_h * feat_w

            x = (anchors[:, 0] + anchors[:, 2]) * 0.5
            y = (anchors[:, 1] + anchors[:, 3]) * 0.5
            # compute centers on feature map
            x = x / stride
            y = y / stride
            # compute predefine centers
            xx = torch.arange(0, feat_w, device=anchors.device)
            yy = torch.arange(0, feat_h, device=anchors.device)
            yy, xx = torch.meshgrid(yy, xx)
            xx = xx.reshape(-1).type_as(x)
            yy = yy.reshape(-1).type_as(y)
            offset_x = x - xx  # (NA, )
            offset_y = y - yy  # (NA, )
            return offset_x, offset_y

        num_imgs = len(anchor_list)
        num_lvls = len(anchor_list[0])
        dtype = anchor_list[0][0].dtype
        device = anchor_list[0][0].device
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        offset_list = []
        for i in range(num_imgs):
            mlvl_offset = []
            for lvl in range(num_lvls):
                c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl],
                                                     anchor_strides[lvl],
                                                     featmap_sizes[lvl])
                s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl],
                                                       anchor_strides[lvl])

                # offset = ctr_offset + shape_offset
                offset_x = s_offset_x + c_offset_x[:, None]
                offset_y = s_offset_y + c_offset_y[:, None]
                # offset order (y0, x0, y1, x1, ..., y8, x8)
                offset = torch.stack([offset_y, offset_x], dim=-1)
                offset = offset.reshape(offset.size(0), -1)  # [NA, 2*ks**2]
                mlvl_offset.append(offset)
            offset_list.append(torch.cat(mlvl_offset))  # [totalNA, 2*ks**2]
        offset_list = images_to_levels(offset_list, num_level_anchors)
        return offset_list
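
    # Worked example of the offset arithmetic above (illustrative numbers):
    # for a 32x32 anchor on a stride-8 level, _shape_offset scales the 3x3
    # grid by w / (ks - 1) - dilation = (32 / 8) / 2 - 1 = 1, spreading the
    # kernel points one extra feature-map cell apart, while _ctr_offset
    # shifts all nine points by the distance between the anchor center and
    # its feature-map cell center.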

    def loss_single(self, cls_score, bbox_pred, anchors, labels,
                    label_weights, bbox_targets, bbox_weights,
                    num_total_samples):
        """Loss function on single scale."""
        # classification loss
        if self.with_cls:
            labels = labels.reshape(-1)
            label_weights = label_weights.reshape(-1)
            cls_score = cls_score.permute(
                0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            loss_cls = self.loss_cls(
                cls_score, labels, label_weights,
                avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.reshape(-1, 4)
        bbox_weights = bbox_weights.reshape(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            anchors = anchors.reshape(-1, 4)
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_reg = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)
        if self.with_cls:
            return loss_cls, loss_reg
        return None, loss_reg

    def loss(self,
             anchor_list,
             valid_flag_list,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            cls_scores (list[Tensor]): Box scores for each scale level.
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss. Default: None

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            featmap_sizes,
            gt_bboxes_ignore=gt_bboxes_ignore,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        if self.sampling:
            num_total_samples = num_total_pos + num_total_neg
        else:
            # 200 is hard-coded average factor,
            # which follows guided anchoring.
            num_total_samples = sum([label.numel()
                                     for label in labels_list]) / 200.0

        # change per-image, per-level anchor_list to per-level, per-image
        mlvl_anchor_list = list(zip(*anchor_list))
        # concat mlvl_anchor_list
        mlvl_anchor_list = [
            torch.cat(anchors, dim=0) for anchors in mlvl_anchor_list
        ]

        losses = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            mlvl_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples)
        if self.with_cls:
            return dict(loss_rpn_cls=losses[0], loss_rpn_reg=losses[1])
        return dict(loss_rpn_reg=losses[1])
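
    # multi_apply collects per-level results, so the dict above holds lists,
    # e.g. (illustrative): {'loss_rpn_cls': [t_lvl0, t_lvl1, ...],
    #                       'loss_rpn_reg': [t_lvl0, t_lvl1, ...]}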

    def get_bboxes(self,
                   anchor_list,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg,
                   rescale=False):
        """Get proposal predictions.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            img_metas (list[dict], Optional): Image meta info. Default None.
            cfg (mmcv.Config, Optional): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            list[Tensor]: Proposals of each image, each with shape (n, 5),
                where the first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                between 0 and 1.
        """
        assert len(cls_scores) == len(bbox_preds)

        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = select_single_mlvl(cls_scores, img_id)
            bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(cls_score_list,
                                                bbox_pred_list,
                                                anchor_list[img_id],
                                                img_shape, scale_factor, cfg,
                                                rescale)
            result_list.append(proposals)
        return result_list

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has
                shape (num_anchors * 4, H, W).
            mlvl_anchors (list[Tensor]): Box reference from all scale
                levels of a single image, each item has shape
                (num_total_anchors, 4).
            img_shape (tuple[int]): Shape of the input image,
                (height, width, 3).
            scale_factor (ndarray): Scale factor of the image, arranged as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
                are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                5-th column is a score between 0 and 1.
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        # bboxes from different level should be independent during NMS,
        # level_ids are used as labels for batched NMS to separate them
        level_ids = []
        mlvl_scores = []
        mlvl_bbox_preds = []
        mlvl_valid_anchors = []
        nms_pre = cfg.get('nms_pre', -1)
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # We set FG labels to [0, num_class-1] and BG label to
                # num_class in RPN head since mmdet v2.5, which is unified
                # to be consistent with other heads. In mmdet v2.0 to v2.4
                # we keep BG label as 0 and FG label as 1 in rpn head.
                scores = rpn_cls_score.softmax(dim=1)[:, 0]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            anchors = mlvl_anchors[idx]

            if 0 < nms_pre < scores.shape[0]:
                # sort is faster than topk
                # _, topk_inds = scores.topk(cfg.nms_pre)
                ranked_scores, rank_inds = scores.sort(descending=True)
                topk_inds = rank_inds[:nms_pre]
                scores = ranked_scores[:nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]

            mlvl_scores.append(scores)
            mlvl_bbox_preds.append(rpn_bbox_pred)
            mlvl_valid_anchors.append(anchors)
            level_ids.append(
                scores.new_full((scores.size(0), ), idx, dtype=torch.long))

        scores = torch.cat(mlvl_scores)
        anchors = torch.cat(mlvl_valid_anchors)
        rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
        proposals = self.bbox_coder.decode(
            anchors, rpn_bbox_pred, max_shape=img_shape)
        ids = torch.cat(level_ids)

        if cfg.min_bbox_size >= 0:
            w = proposals[:, 2] - proposals[:, 0]
            h = proposals[:, 3] - proposals[:, 1]
            valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
            if not valid_mask.all():
                proposals = proposals[valid_mask]
                scores = scores[valid_mask]
                ids = ids[valid_mask]

        # deprecated arguments warning
        if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
            warnings.warn(
                'In rpn_proposal or test_cfg, '
                'nms_thr has been moved to a dict named nms as '
                'iou_threshold, max_num has been renamed as max_per_img, '
                'name of original arguments and the way to specify '
                'iou_threshold of NMS will be deprecated.')
        if 'nms' not in cfg:
            cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
        if 'max_num' in cfg:
            if 'max_per_img' in cfg:
                assert cfg.max_num == cfg.max_per_img, \
                    'You set max_num and max_per_img at the same time, ' \
                    f'but get {cfg.max_num} and {cfg.max_per_img} ' \
                    'respectively. Please delete max_num which will be ' \
                    'deprecated.'
            else:
                cfg.max_per_img = cfg.max_num
        if 'nms_thr' in cfg:
            assert cfg.nms.iou_threshold == cfg.nms_thr, \
                'You set iou_threshold in nms and nms_thr at the same ' \
                f'time, but get {cfg.nms.iou_threshold} and {cfg.nms_thr} ' \
                'respectively. Please delete the nms_thr which will be ' \
                'deprecated.'
        if proposals.numel() > 0:
            dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
        else:
            return proposals.new_zeros(0, 5)

        return dets[:cfg.max_per_img]
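
    # A sketch of the post-processing keys this method reads (the values are
    # illustrative assumptions, not the shipped defaults):
    #   test_cfg = dict(
    #       nms_pre=2000,
    #       min_bbox_size=0,
    #       nms=dict(type='nms', iou_threshold=0.8),
    #       max_per_img=300)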

    def refine_bboxes(self, anchor_list, bbox_preds, img_metas):
        """Refine bboxes through stages."""
        num_levels = len(bbox_preds)
        new_anchor_list = []
        for img_id in range(len(img_metas)):
            mlvl_anchors = []
            for i in range(num_levels):
                bbox_pred = bbox_preds[i][img_id].detach()
                bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
                img_shape = img_metas[img_id]['img_shape']
                bboxes = self.bbox_coder.decode(anchor_list[img_id][i],
                                                bbox_pred, img_shape)
                mlvl_anchors.append(bboxes)
            new_anchor_list.append(mlvl_anchors)
        return new_anchor_list


@HEADS.register_module()
class CascadeRPNHead(BaseDenseHead):
    """The CascadeRPNHead will predict more accurate region proposals, which
    is required for two-stage detectors (such as Fast/Faster R-CNN).
    CascadeRPN consists of a sequence of RPNStage to progressively improve
    the accuracy of the detected proposals.

    More details can be found in ``https://arxiv.org/abs/1909.06720``.

    Args:
        num_stages (int): number of CascadeRPN stages.
        stages (list[dict]): list of configs to build the stages.
        train_cfg (list[dict]): list of configs at training time, one for
            each stage.
        test_cfg (dict): config at testing time.
    """

    def __init__(self, num_stages, stages, train_cfg, test_cfg,
                 init_cfg=None):
        super(CascadeRPNHead, self).__init__(init_cfg)
        assert num_stages == len(stages)
        self.num_stages = num_stages
        # Be careful! Pretrained weights cannot be loaded when
        # nn.ModuleList is used, hence mmcv's ModuleList here.
        self.stages = ModuleList()
        for i in range(len(stages)):
            train_cfg_i = train_cfg[i] if train_cfg is not None else None
            stages[i].update(train_cfg=train_cfg_i)
            stages[i].update(test_cfg=test_cfg)
            self.stages.append(build_head(stages[i]))
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def loss(self):
        """loss() is implemented in StageCascadeRPNHead."""
        pass

    def get_bboxes(self):
        """get_bboxes() is implemented in StageCascadeRPNHead."""
        pass

    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None):
        """Forward train function."""
        assert gt_labels is None, 'RPN does not require gt_labels'

        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, valid_flag_list = self.stages[0].get_anchors(
            featmap_sizes, img_metas, device=device)

        losses = dict()
        for i in range(self.num_stages):
            stage = self.stages[i]

            if stage.adapt_cfg['type'] == 'offset':
                offset_list = stage.anchor_offset(anchor_list,
                                                  stage.anchor_strides,
                                                  featmap_sizes)
            else:
                offset_list = None
            x, cls_score, bbox_pred = stage(x, offset_list)
            rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score,
                               bbox_pred, gt_bboxes, img_metas)
            stage_loss = stage.loss(*rpn_loss_inputs)
            for name, value in stage_loss.items():
                losses['s{}.{}'.format(i, name)] = value

            # refine boxes
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  img_metas)
        if proposal_cfg is None:
            return losses
        else:
            proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
                                                       bbox_pred, img_metas,
                                                       self.test_cfg)
            return losses, proposal_list
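
    # With two stages, the returned losses are keyed per stage, e.g.
    # (illustrative): 's0.loss_rpn_reg', 's1.loss_rpn_cls',
    # 's1.loss_rpn_reg'; a stage built with with_cls=False contributes no
    # cls term.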

    def simple_test_rpn(self, x, img_metas):
        """Simple forward test function."""
        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, _ = self.stages[0].get_anchors(
            featmap_sizes, img_metas, device=device)

        for i in range(self.num_stages):
            stage = self.stages[i]
            if stage.adapt_cfg['type'] == 'offset':
                offset_list = stage.anchor_offset(anchor_list,
                                                  stage.anchor_strides,
                                                  featmap_sizes)
            else:
                offset_list = None
            x, cls_score, bbox_pred = stage(x, offset_list)
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  img_metas)

        proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
                                                   bbox_pred, img_metas,
                                                   self.test_cfg)
        return proposal_list

    def aug_test_rpn(self, x, img_metas):
        """Augmented forward test function."""
        raise NotImplementedError(
            'CascadeRPNHead does not support test-time augmentation')
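
# A minimal two-stage config sketch showing how the classes above fit
# together, mirroring the Cascade RPN setup (stage 1: dilated AdaptiveConv,
# regression only; stage 2: anchor-guided offsets plus a cls branch). The
# exact values are illustrative assumptions, not the shipped config:
#
#   rpn_head = dict(
#       type='CascadeRPNHead',
#       num_stages=2,
#       stages=[
#           dict(
#               type='StageCascadeRPNHead',
#               in_channels=256,
#               feat_channels=256,
#               adapt_cfg=dict(type='dilation', dilation=3),
#               bridged_feature=True,
#               with_cls=False,
#               sampling=False),
#           dict(
#               type='StageCascadeRPNHead',
#               in_channels=256,
#               feat_channels=256,
#               adapt_cfg=dict(type='offset'),
#               bridged_feature=False,
#               with_cls=True,
#               sampling=True),
#       ],
#       train_cfg=[...],  # one train_cfg dict per stage
#       test_cfg=dict(
#           nms_pre=2000,
#           min_bbox_size=0,
#           nms=dict(type='nms', iou_threshold=0.8),
#           max_per_img=300))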