# Spaces: Runtime error (page-status text captured together with the source; not part of the module)
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional, Tuple, Union | |
import torch | |
from mmcv.ops.nms import batched_nms | |
from torch import Tensor | |
from mmdet.structures.bbox import bbox_overlaps | |
from mmdet.utils import ConfigType | |
def multiclass_nms(
    multi_bboxes: Tensor,
    multi_scores: Tensor,
    score_thr: float,
    nms_cfg: ConfigType,
    max_num: int = -1,
    score_factors: Optional[Tensor] = None,
    return_inds: bool = False,
    box_dim: int = 4
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
    """NMS for multi-class bboxes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*box_dim) for per-class boxes
            or (n, box_dim) for boxes shared by all classes.
        multi_scores (Tensor): shape (n, #class+1), where the last column
            contains scores of the background class, but this will be
            ignored.
        score_thr (float): bbox threshold, bboxes with scores lower than it
            will not be considered.
        nms_cfg (Union[:obj:`ConfigDict`, dict]): a dict that contains
            the arguments of nms operations.
        max_num (int, optional): if there are more than max_num bboxes after
            NMS, only top max_num will be kept. Default to -1.
        score_factors (Tensor, optional): The factors multiplied to scores
            before applying NMS. Default to None.
        return_inds (bool, optional): Whether return the indices of kept
            bboxes. Default to False.
        box_dim (int): The dimension of boxes. Defaults to 4.

    Returns:
        Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
        (dets, labels, indices (optional)). Dets are boxes with scores,
        shape (k, 5) for the default ``box_dim`` of 4; labels and indices
        have shape (k,). Labels are 0-based. Indices refer to positions in
        the flattened (n * #class) candidate list.

    Raises:
        RuntimeError: if no candidate box is left while exporting to ONNX.
    """
    num_boxes = multi_scores.size(0)
    num_classes = multi_scores.size(1) - 1
    # The export flag is queried once; it is used in three places below.
    in_onnx_export = torch.onnx.is_in_onnx_export()

    # exclude background category
    if multi_bboxes.shape[1] > box_dim:
        bboxes = multi_bboxes.view(num_boxes, -1, box_dim)
    else:
        # Class-agnostic boxes: share the same box across every class.
        bboxes = multi_bboxes[:, None].expand(num_boxes, num_classes, box_dim)

    scores = multi_scores[:, :-1]

    labels = torch.arange(num_classes, dtype=torch.long, device=scores.device)
    labels = labels.view(1, -1).expand_as(scores)

    # Flatten to one candidate per (box, class) pair.
    bboxes = bboxes.reshape(-1, box_dim)
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)

    if not in_onnx_export:
        # NonZero not supported in TensorRT
        # remove low scoring boxes
        valid_mask = scores > score_thr

    # multiply score_factor after threshold to preserve more bboxes, improve
    # mAP by 1% for YOLOv3
    if score_factors is not None:
        # expand the shape to match original shape of score
        score_factors = score_factors.view(-1, 1).expand(
            num_boxes, num_classes)
        score_factors = score_factors.reshape(-1)
        scores = scores * score_factors

    if not in_onnx_export:
        # NonZero not supported in TensorRT
        inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
        bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
    else:
        # TensorRT NMS plugin has invalid output filled with -1
        # add dummy data to make detection output correct.
        bboxes = torch.cat([bboxes, bboxes.new_zeros(1, box_dim)], dim=0)
        scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
        labels = torch.cat([labels, labels.new_zeros(1)], dim=0)
        if return_inds:
            # No candidate was filtered out in this branch, so a candidate's
            # position in the padded tensors is its index. Without this,
            # `inds` would be unbound when indices are requested.
            inds = torch.arange(
                scores.size(0), dtype=torch.long, device=scores.device)

    if bboxes.numel() == 0:
        if in_onnx_export:
            raise RuntimeError('[ONNX Error] Can not record NMS '
                               'as it has not been executed this time')
        # Nothing survived the threshold: return empty, correctly shaped
        # results without running NMS.
        dets = torch.cat([bboxes, scores[:, None]], -1)
        if return_inds:
            return dets, labels, inds
        return dets, labels

    dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)

    if max_num > 0:
        dets = dets[:max_num]
        keep = keep[:max_num]

    if return_inds:
        return dets, labels[keep], inds[keep]
    return dets, labels[keep]
def fast_nms(
    multi_bboxes: Tensor,
    multi_scores: Tensor,
    multi_coeffs: Tensor,
    score_thr: float,
    iou_thr: float,
    top_k: int,
    max_num: int = -1
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
    """Fast NMS in `YOLACT <https://arxiv.org/abs/1904.02689>`_.

    Fast NMS allows already-removed detections to suppress other detections so
    that every instance can be decided to be kept or discarded in parallel,
    which is not possible in traditional NMS. This relaxation allows us to
    implement Fast NMS entirely in standard GPU-accelerated matrix operations.

    Args:
        multi_bboxes (Tensor): shape (n, 4); one box per detection, shared
            by all classes.
        multi_scores (Tensor): shape (n, #class+1), where the last column
            contains scores of the background class, but this will be ignored.
        multi_coeffs (Tensor): shape (n, coeffs_dim); one coefficient row per
            detection.
        score_thr (float): bbox threshold, bboxes with scores lower than it
            will not be considered.
        iou_thr (float): IoU threshold to be considered as conflicted.
        top_k (int): if there are more than top_k bboxes before NMS,
            only top top_k will be kept.
        max_num (int): if there are more than max_num bboxes after NMS,
            only top max_num will be kept. If -1, keep all the bboxes.
            Default: -1.

    Returns:
        Tuple[Tensor, Tensor, Tensor]:
        (dets, labels, coefficients), tensors of shape (k, 5), (k,),
        and (k, coeffs_dim). Dets are boxes with scores.
        Labels are 0-based. All three are empty (k == 0) when there are no
        candidates to process.
    """
    scores = multi_scores[:, :-1].t()  # [#class, n]
    scores, idx = scores.sort(1, descending=True)

    idx = idx[:, :top_k].contiguous()
    scores = scores[:, :top_k]  # [#class, topk]
    num_classes, num_dets = idx.size()

    if idx.numel() == 0:
        # No detections, no foreground classes, or top_k == 0. The reshapes
        # with -1 and the max-reduction below are undefined on zero-sized
        # tensors, so return empty results with the usual shapes and dtypes.
        empty_dets = torch.cat(
            [multi_bboxes.new_zeros((0, 4)),
             multi_scores.new_zeros((0, 1))],
            dim=1)
        empty_labels = torch.zeros(
            (0, ), dtype=torch.long, device=multi_bboxes.device)
        empty_coeffs = multi_coeffs.new_zeros((0, multi_coeffs.size(-1)))
        return empty_dets, empty_labels, empty_coeffs

    # Gather the per-class top-k boxes and coefficients.
    flat_idx = idx.view(-1)
    boxes = multi_bboxes[flat_idx, :].view(num_classes, num_dets, 4)
    coeffs = multi_coeffs[flat_idx, :].view(num_classes, num_dets, -1)

    iou = bbox_overlaps(boxes, boxes)  # [#class, topk, topk]
    # Keep only the upper triangle so each box is compared solely against
    # higher-scored boxes of the same class (rows are sorted by score).
    iou.triu_(diagonal=1)
    iou_max, _ = iou.max(dim=1)

    # Now just filter out the ones higher than the threshold
    keep = iou_max <= iou_thr

    # Second thresholding introduces 0.2 mAP gain at negligible time cost
    keep &= scores > score_thr

    # Assign each kept detection to its corresponding class
    classes = torch.arange(
        num_classes, device=boxes.device)[:, None].expand_as(keep)
    classes = classes[keep]

    boxes = boxes[keep]
    coeffs = coeffs[keep]
    scores = scores[keep]

    # Only keep the top max_num highest scores across all classes
    scores, idx = scores.sort(0, descending=True)
    if max_num > 0:
        idx = idx[:max_num]
        scores = scores[:max_num]

    classes = classes[idx]
    boxes = boxes[idx]
    coeffs = coeffs[idx]

    cls_dets = torch.cat([boxes, scores[:, None]], dim=1)
    return cls_dets, classes, coeffs