# Robert001's picture
# first commit
# b334e29
# raw
# history blame
# No virus
# 6.9 kB
import copy
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv import ConfigDict
from mmcv.cnn import normal_init
from mmcv.ops import nms
from ..builder import HEADS
from .guided_anchor_head import GuidedAnchorHead
from .rpn_test_mixin import RPNTestMixin
@HEADS.register_module()
class GARPNHead(RPNTestMixin, GuidedAnchorHead):
    """Guided-Anchor-based RPN head.

    Puts a 3x3 convolution followed by ReLU in front of the
    ``GuidedAnchorHead`` prediction branches, renames the loss entries to
    RPN-style keys and turns per-level predictions into proposals.
    """

    def __init__(self, in_channels, **kwargs):
        """Construct the head.

        Args:
            in_channels: Forwarded to ``GuidedAnchorHead.__init__`` as the
                second positional argument.
            **kwargs: Forwarded unchanged to ``GuidedAnchorHead.__init__``.
        """
        # The first positional argument of the parent is fixed to 1
        # (presumably the number of classes: RPN only predicts objectness).
        super().__init__(1, in_channels, **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        # Shared 3x3 conv applied before the parent's prediction branches.
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        super()._init_layers()

    def init_weights(self):
        """Initialize weights of the head."""
        normal_init(self.rpn_conv, std=0.01)
        super().init_weights()

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x: Feature map of one level.

        Returns:
            tuple: ``(cls_score, bbox_pred, shape_pred, loc_pred)`` as
            produced by ``GuidedAnchorHead.forward_single`` on the
            convolved and ReLU-activated feature.
        """
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        (cls_score, bbox_pred, shape_pred,
         loc_pred) = super().forward_single(x)
        return cls_score, bbox_pred, shape_pred, loc_pred

    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute the losses and expose them under RPN-specific names.

        All arguments are forwarded to ``GuidedAnchorHead.loss``.

        Returns:
            dict: ``loss_rpn_cls``, ``loss_rpn_bbox``, ``loss_anchor_shape``
            and ``loss_anchor_loc``, taken from the parent's ``loss_cls``,
            ``loss_bbox``, ``loss_shape`` and ``loss_loc`` entries.
        """
        losses = super().loss(
            cls_scores,
            bbox_preds,
            shape_preds,
            loc_preds,
            gt_bboxes,
            # Sixth positional slot of the parent (presumably gt_labels,
            # which a class-agnostic RPN does not have - TODO confirm).
            None,
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)
        return dict(
            loss_rpn_cls=losses['loss_cls'],
            loss_rpn_bbox=losses['loss_bbox'],
            loss_anchor_shape=losses['loss_shape'],
            loss_anchor_loc=losses['loss_loc'])

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           mlvl_anchors,
                           mlvl_masks,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Turn the multi-level predictions of one image into proposals.

        Args:
            cls_scores: Per-level classification scores; each entry is
                permuted with ``(1, 2, 0)`` and flattened here.
            bbox_preds: Per-level box deltas; each entry is permuted with
                ``(1, 2, 0)`` and reshaped to ``(-1, 4)`` here.
            mlvl_anchors: Per-level anchors, already filtered by the
                location mask beforehand.
            mlvl_masks: Per-level masks used to index the flattened scores
                and box deltas.
            img_shape: Passed to the bbox coder as ``max_shape``.
            scale_factor: Not used by this implementation.
            cfg: Test config; ``self.test_cfg`` is used when it is None.
                The legacy keys ``nms_thr`` and ``max_num`` are still
                accepted and trigger a warning.
            rescale: Not used by this implementation.

        Returns:
            Tensor: Proposals with five columns - four box coordinates
            followed by the score. An empty ``(0, 5)`` tensor is returned
            when no level keeps any location.
        """
        cfg = self.test_cfg if cfg is None else cfg
        # Work on a copy: the legacy-key handling below mutates cfg.
        cfg = copy.deepcopy(cfg)

        # deprecate arguments warning
        if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
            warnings.warn(
                'In rpn_proposal or test_cfg, '
                'nms_thr has been moved to a dict named nms as '
                'iou_threshold, max_num has been renamed as max_per_img, '
                'name of original arguments and the way to specify '
                'iou_threshold of NMS will be deprecated.')
        if 'nms' not in cfg:
            cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
        if 'max_num' in cfg:
            if 'max_per_img' in cfg:
                assert cfg.max_num == cfg.max_per_img, f'You ' \
                    f'set max_num and max_per_img at the same time, ' \
                    f'but get {cfg.max_num} ' \
                    f'and {cfg.max_per_img} respectively. ' \
                    'Please delete max_num which will be deprecated.'
            else:
                cfg.max_per_img = cfg.max_num
        if 'nms_thr' in cfg:
            assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
                f'iou_threshold in nms and ' \
                f'nms_thr at the same time, but get ' \
                f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
                f' respectively. Please delete the ' \
                f'nms_thr which will be deprecated.'

        assert cfg.nms.get('type', 'nms') == 'nms', 'GARPNHead only support ' \
            'naive nms.'

        mlvl_proposals = []
        for rpn_cls_score, rpn_bbox_pred, anchors, mask in zip(
                cls_scores, bbox_preds, mlvl_anchors, mlvl_masks):
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            # if no location is kept, end.
            if mask.sum() == 0:
                continue
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0
                # BG cat_id: num_class
                # Select the foreground column as a 1-D tensor; slicing
                # with [:, :-1] would keep a trailing dimension of size 1
                # and make the top-k below run over that dimension.
                scores = rpn_cls_score.softmax(dim=1)[:, 0]
            # filter scores, bbox_pred w.r.t. mask.
            # anchors are filtered in get_anchors() beforehand.
            scores = scores[mask]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(
                -1, 4)[mask, :]
            if scores.dim() == 0:
                rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
                anchors = anchors.unsqueeze(0)
                scores = scores.unsqueeze(0)
            # filter anchors, bbox_pred, scores w.r.t. scores
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]
            # get proposals w.r.t. anchors and rpn_bbox_pred
            proposals = self.bbox_coder.decode(
                anchors, rpn_bbox_pred, max_shape=img_shape)
            # filter out too small bboxes
            if cfg.min_bbox_size > 0:
                w = proposals[:, 2] - proposals[:, 0]
                h = proposals[:, 3] - proposals[:, 1]
                # A boolean mask keeps the results 2-D / 1-D even when
                # exactly one box survives; nonzero().squeeze() would
                # collapse that case to a 0-dim index and drop a dimension.
                valid_mask = (w >= cfg.min_bbox_size) & (
                    h >= cfg.min_bbox_size)
                if not valid_mask.all():
                    proposals = proposals[valid_mask]
                    scores = scores[valid_mask]
            # NMS in current level
            proposals, _ = nms(proposals, scores, cfg.nms.iou_threshold)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)

        if not mlvl_proposals:
            # Every level was skipped: torch.cat() rejects an empty list,
            # so return an empty set of proposals instead.
            if len(cls_scores) > 0:
                return cls_scores[0].new_zeros((0, 5))
            return torch.zeros((0, 5))

        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.get('nms_across_levels', False):
            # NMS across multi levels
            proposals, _ = nms(proposals[:, :4], proposals[:, -1],
                               cfg.nms.iou_threshold)
            proposals = proposals[:cfg.max_per_img, :]
        else:
            # Keep the max_per_img highest-scoring proposals.
            scores = proposals[:, 4]
            num = min(cfg.max_per_img, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals