# KyanChen's picture
# Upload 787 files
# 3e06e1c
# raw
# history blame
# 6.99 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import torch
from mmcv.ops.nms import batched_nms
from torch import Tensor
from mmdet.structures.bbox import bbox_overlaps
from mmdet.utils import ConfigType
def multiclass_nms(
    multi_bboxes: Tensor,
    multi_scores: Tensor,
    score_thr: float,
    nms_cfg: ConfigType,
    max_num: int = -1,
    score_factors: Optional[Tensor] = None,
    return_inds: bool = False,
    box_dim: int = 4
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
    """NMS for multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*box_dim) for class-specific
            boxes, or (n, box_dim) for boxes shared by all classes.
        multi_scores (Tensor): shape (n, #class+1), where the last column
            contains scores of the background class, but this will be ignored.
        score_thr (float): bbox threshold. Candidates whose raw score (before
            ``score_factors`` is applied) is not greater than it are
            discarded. The threshold is not applied during ONNX export.
        nms_cfg (Union[:obj:`ConfigDict`, dict]): a dict that contains
            the arguments of nms operations, forwarded to ``batched_nms``.
        max_num (int, optional): if there are more than max_num bboxes after
            NMS, only top max_num will be kept. Non-positive values keep all
            bboxes. Default to -1.
        score_factors (Tensor, optional): per-prior factors (n elements)
            multiplied to scores after thresholding and before NMS.
            Default to None.
        return_inds (bool, optional): Whether return the indices of kept
            bboxes. Default to False.
        box_dim (int): The dimension of boxes. Defaults to 4.

    Returns:
        Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
        (dets, labels, indices (optional)), tensors of shape
        (k, box_dim+1), (k, ) and (k, ). Dets are boxes with their scores
        appended as the last column. Labels are 0-based. Indices point into
        the flattened (n * #class) candidate list; during ONNX export the
        dummy padding candidate is reported with index -1.

    Raises:
        RuntimeError: if no candidate box is left while exporting to ONNX.
    """
    # Evaluate once so both export-dependent branches agree.
    in_onnx_export = torch.onnx.is_in_onnx_export()
    num_priors = multi_scores.size(0)
    # exclude background category (the last score column)
    num_classes = multi_scores.size(1) - 1
    if multi_bboxes.shape[1] > box_dim:
        # class-specific boxes: one box per (prior, class) pair
        bboxes = multi_bboxes.view(num_priors, -1, box_dim)
    else:
        # class-agnostic boxes: every class reuses the prior's single box
        bboxes = multi_bboxes[:, None].expand(num_priors, num_classes,
                                              box_dim)

    scores = multi_scores[:, :-1]

    labels = torch.arange(num_classes, dtype=torch.long, device=scores.device)
    labels = labels.view(1, -1).expand_as(scores)

    # flatten to one row per (prior, class) candidate
    bboxes = bboxes.reshape(-1, box_dim)
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)

    if not in_onnx_export:
        # NonZero not supported in TensorRT
        # remove low scoring boxes
        valid_mask = scores > score_thr
    # multiply score_factor after threshold to preserve more bboxes, improve
    # mAP by 1% for YOLOv3
    if score_factors is not None:
        # expand the shape to match original shape of score
        score_factors = score_factors.view(-1, 1).expand(
            num_priors, num_classes)
        score_factors = score_factors.reshape(-1)
        scores = scores * score_factors

    if not in_onnx_export:
        # NonZero not supported in TensorRT
        inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
        bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
    else:
        if return_inds:
            # Nothing is filtered during export, so every candidate keeps its
            # own flattened index; the dummy row below has no source
            # candidate and is marked with -1.
            inds = torch.arange(
                scores.size(0), dtype=torch.long, device=scores.device)
            inds = torch.cat([inds, inds.new_full((1, ), -1)], dim=0)
        # TensorRT NMS plugin has invalid output filled with -1
        # add dummy data to make detection output correct.
        bboxes = torch.cat([bboxes, bboxes.new_zeros(1, box_dim)], dim=0)
        scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
        labels = torch.cat([labels, labels.new_zeros(1)], dim=0)

    if bboxes.numel() == 0:
        if in_onnx_export:
            raise RuntimeError('[ONNX Error] Can not record NMS '
                               'as it has not been executed this time')
        dets = torch.cat([bboxes, scores[:, None]], -1)
        if return_inds:
            return dets, labels, inds
        else:
            return dets, labels

    dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)

    if max_num > 0:
        dets = dets[:max_num]
        keep = keep[:max_num]

    if return_inds:
        return dets, labels[keep], inds[keep]
    else:
        return dets, labels[keep]
def fast_nms(
    multi_bboxes: Tensor,
    multi_scores: Tensor,
    multi_coeffs: Tensor,
    score_thr: float,
    iou_thr: float,
    top_k: int,
    max_num: int = -1
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
    """Fast NMS in `YOLACT <https://arxiv.org/abs/1904.02689>`_.

    Fast NMS allows already-removed detections to suppress other detections so
    that every instance can be decided to be kept or discarded in parallel,
    which is not possible in traditional NMS. This relaxation allows us to
    implement Fast NMS entirely in standard GPU-accelerated matrix operations.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) for class-specific boxes,
            or (n, 4) for boxes shared by all classes.
        multi_scores (Tensor): shape (n, #class+1), where the last column
            contains scores of the background class, but this will be ignored.
        multi_coeffs (Tensor): shape (n, coeffs_dim). The coefficients of a
            candidate are shared by all of its classes.
        score_thr (float): bbox threshold, bboxes with scores not greater
            than it are discarded.
        iou_thr (float): IoU threshold to be considered as conflicted.
        top_k (int): per class, only the top_k highest-scoring candidates
            take part in NMS.
        max_num (int): if there are more than max_num bboxes after NMS,
            only top max_num will be kept. If -1, keep all the bboxes.
            Default: -1.

    Returns:
        Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
        (dets, labels, coefficients), tensors of shape (k, 5), (k, ),
        and (k, coeffs_dim). Dets are boxes with their scores appended as the
        last column, sorted by descending score. Labels are 0-based.
        Empty tensors are returned when there are no candidates.
    """
    # Rank candidates independently for every (non-background) class.
    scores = multi_scores[:, :-1].t()  # [#class, n]
    scores, idx = scores.sort(1, descending=True)
    idx = idx[:, :top_k].contiguous()
    scores = scores[:, :top_k]  # [#class, topk]
    num_classes, num_dets = idx.size()

    if idx.numel() == 0:
        # No candidates (n == 0, top_k == 0 or no classes): the reshapes and
        # the ``iou.max`` reduction below cannot handle empty dimensions.
        coeffs_dim = multi_coeffs.shape[1:].numel()
        return (multi_bboxes.new_zeros((0, 5)), idx.new_zeros((0, )),
                multi_coeffs.new_zeros((0, coeffs_dim)))

    if multi_bboxes.size(1) > 4:
        # Class-specific boxes: class c uses the c-th box of each candidate.
        per_class_bboxes = multi_bboxes.reshape(
            multi_bboxes.size(0), multi_bboxes.size(1) // 4, 4)
        class_inds = torch.arange(
            num_classes, device=idx.device)[:, None].expand_as(idx)
        boxes = per_class_bboxes[idx, class_inds]  # [#class, topk, 4]
    else:
        boxes = multi_bboxes[idx.view(-1), :].view(num_classes, num_dets, 4)
    coeffs = multi_coeffs[idx.view(-1), :].view(num_classes, num_dets, -1)

    iou = bbox_overlaps(boxes, boxes)  # [#class, topk, topk]
    # Keep only overlaps with higher-scoring boxes (strict upper triangle),
    # so each box is compared against the ones ranked before it.
    iou.triu_(diagonal=1)
    iou_max, _ = iou.max(dim=1)

    # Now just filter out the ones higher than the threshold
    keep = iou_max <= iou_thr

    # Second thresholding introduces 0.2 mAP gain at negligible time cost
    keep &= scores > score_thr

    # Assign each kept detection to its corresponding class
    classes = torch.arange(
        num_classes, device=boxes.device)[:, None].expand_as(keep)
    classes = classes[keep]

    boxes = boxes[keep]
    coeffs = coeffs[keep]
    scores = scores[keep]

    # Only keep the top max_num highest scores across all classes
    scores, idx = scores.sort(0, descending=True)
    if max_num > 0:
        idx = idx[:max_num]
        scores = scores[:max_num]

    classes = classes[idx]
    boxes = boxes[idx]
    coeffs = coeffs[idx]

    cls_dets = torch.cat([boxes, scores[:, None]], dim=1)
    return cls_dets, classes, coeffs