stevengrove
initial commit
186701e
raw
history blame
No virus
8.05 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor
# Conversion matrix applied as ``boxes @ _XYWH2XYXY``: a row
# ``[x, y, w, h]`` becomes ``[x - w / 2, y - h / 2, x + w / 2, y + h / 2]``.
_XYWH2XYXY = torch.tensor(
    [
        [1.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 1.0],
        [-0.5, 0.0, 0.5, 0.0],
        [0.0, -0.5, 0.0, 0.5],
    ],
    dtype=torch.float32,
)
class TRTEfficientNMSop(torch.autograd.Function):
    """Stand-in for the ``TRT::EfficientNMS_TRT`` op.

    ``forward`` does not run NMS; it only fabricates random outputs with
    the expected shapes and dtypes. ``symbolic`` emits the
    ``TRT::EfficientNMS_TRT`` graph node that carries the attributes.
    """

    @staticmethod
    def forward(
        ctx,
        boxes: Tensor,
        scores: Tensor,
        background_class: int = -1,
        box_coding: int = 0,
        iou_threshold: float = 0.45,
        max_output_boxes: int = 100,
        plugin_version: str = '1',
        score_activation: int = 0,
        score_threshold: float = 0.25,
    ):
        """Return random placeholder detections.

        Only ``scores`` (batch size, class count, device) and
        ``max_output_boxes`` influence the result; every other argument
        is accepted solely to mirror the signature of ``symbolic``.

        Args:
            ctx: Autograd context (unused).
            boxes (Tensor): Candidate boxes (not read here).
            scores (Tensor): Scores of shape [N, num_boxes, num_classes].
            background_class (int): Unused here. Defaults to -1.
            box_coding (int): Unused here. Defaults to 0.
            iou_threshold (float): Unused here. Defaults to 0.45.
            max_output_boxes (int): Size of the detection axis of the
                outputs. Defaults to 100.
            plugin_version (str): Unused here. Defaults to '1'.
            score_activation (int): Unused here. Defaults to 0.
            score_threshold (float): Unused here. Defaults to 0.25.

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]:
                ``num_det`` int32 of shape [N, 1],
                ``det_boxes`` float of shape [N, max_output_boxes, 4],
                ``det_scores`` float of shape [N, max_output_boxes],
                ``det_classes`` int32 of shape [N, max_output_boxes].
        """
        batch_size, _, num_classes = scores.shape
        # Create the placeholders on the same device as the inputs so
        # they can be combined with other tensors of the model without a
        # device mismatch.
        device = scores.device
        num_det = torch.randint(
            0,
            max_output_boxes, (batch_size, 1),
            dtype=torch.int32,
            device=device)
        det_boxes = torch.randn(
            batch_size, max_output_boxes, 4, device=device)
        det_scores = torch.randn(batch_size, max_output_boxes, device=device)
        det_classes = torch.randint(
            0,
            num_classes, (batch_size, max_output_boxes),
            dtype=torch.int32,
            device=device)
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(g,
                 boxes: Tensor,
                 scores: Tensor,
                 background_class: int = -1,
                 box_coding: int = 0,
                 iou_threshold: float = 0.45,
                 max_output_boxes: int = 100,
                 plugin_version: str = '1',
                 score_activation: int = 0,
                 score_threshold: float = 0.25):
        """Emit a ``TRT::EfficientNMS_TRT`` node with four outputs.

        Each keyword is forwarded to ``g.op`` as a node attribute; the
        ``_i`` / ``_f`` / ``_s`` suffixes tag it as int, float or string.

        Returns:
            tuple: (num_det, det_boxes, det_scores, det_classes) graph
                values produced by the node.
        """
        out = g.op(
            'TRT::EfficientNMS_TRT',
            boxes,
            scores,
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
            outputs=4)
        num_det, det_boxes, det_scores, det_classes = out
        return num_det, det_boxes, det_scores, det_classes
class TRTbatchedNMSop(torch.autograd.Function):
    """TensorRT NMS operation.

    ``forward`` does not run NMS; it only fabricates random outputs with
    the expected shapes and dtypes. ``symbolic`` emits the
    ``TRT::BatchedNMSDynamic_TRT`` graph node that carries the attributes.
    """

    @staticmethod
    def forward(
        ctx,
        boxes: Tensor,
        scores: Tensor,
        plugin_version: str = '1',
        shareLocation: int = 1,
        backgroundLabelId: int = -1,
        numClasses: int = 80,
        topK: int = 1000,
        keepTopK: int = 100,
        scoreThreshold: float = 0.25,
        iouThreshold: float = 0.45,
        isNormalized: int = 0,
        clipBoxes: int = 0,
        scoreBits: int = 16,
        caffeSemantics: int = 1,
    ):
        """Return random placeholder detections.

        Only ``scores`` (batch size, class count, device) and
        ``keepTopK`` influence the result. The ``numClasses`` argument is
        ignored here: the class count is read from the last axis of
        ``scores``. All remaining arguments only mirror ``symbolic``.

        Args:
            ctx: Autograd context (unused).
            boxes (Tensor): Candidate boxes (not read here).
            scores (Tensor): Scores of shape [N, num_boxes, num_classes].
            keepTopK (int): Size of the detection axis of the outputs.
                Defaults to 100.

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]:
                ``num_det`` int32 of shape [N, 1],
                ``det_boxes`` float of shape [N, keepTopK, 4],
                ``det_scores`` float of shape [N, keepTopK],
                ``det_classes`` float of shape [N, keepTopK].
        """
        batch_size, _, num_classes = scores.shape
        # Create the placeholders on the same device as the inputs so
        # they can be combined with other tensors of the model without a
        # device mismatch.
        device = scores.device
        num_det = torch.randint(
            0, keepTopK, (batch_size, 1), dtype=torch.int32, device=device)
        det_boxes = torch.randn(batch_size, keepTopK, 4, device=device)
        det_scores = torch.randn(batch_size, keepTopK, device=device)
        # Class ids are drawn as integers and then returned as floats.
        det_classes = torch.randint(
            0, num_classes, (batch_size, keepTopK), device=device).float()
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(
        g,
        boxes: Tensor,
        scores: Tensor,
        plugin_version: str = '1',
        shareLocation: int = 1,
        backgroundLabelId: int = -1,
        numClasses: int = 80,
        topK: int = 1000,
        keepTopK: int = 100,
        scoreThreshold: float = 0.25,
        iouThreshold: float = 0.45,
        isNormalized: int = 0,
        clipBoxes: int = 0,
        scoreBits: int = 16,
        caffeSemantics: int = 1,
    ):
        """Emit a ``TRT::BatchedNMSDynamic_TRT`` node with four outputs.

        Each keyword is forwarded to ``g.op`` as a node attribute; the
        ``_i`` / ``_f`` / ``_s`` suffixes tag it as int, float or string.

        Returns:
            tuple: (num_det, det_boxes, det_scores, det_classes) graph
                values produced by the node.
        """
        out = g.op(
            'TRT::BatchedNMSDynamic_TRT',
            boxes,
            scores,
            shareLocation_i=shareLocation,
            plugin_version_s=plugin_version,
            backgroundLabelId_i=backgroundLabelId,
            numClasses_i=numClasses,
            topK_i=topK,
            keepTopK_i=keepTopK,
            scoreThreshold_f=scoreThreshold,
            iouThreshold_f=iouThreshold,
            isNormalized_i=isNormalized,
            clipBoxes_i=clipBoxes,
            scoreBits_i=scoreBits,
            caffeSemantics_i=caffeSemantics,
            outputs=4)
        num_det, det_boxes, det_scores, det_classes = out
        return num_det, det_boxes, det_scores, det_classes
def _efficient_nms(
    boxes: Tensor,
    scores: Tensor,
    max_output_boxes_per_class: int = 1000,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Run NMS through ``TRTEfficientNMSop``.

    Args:
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Accepted for signature
            compatibility; not forwarded to the op. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): Score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Accepted for signature compatibility; not
            forwarded to the op. Defaults to -1.
        keep_top_k (int): Number of boxes kept after nms; passed to the
            op as ``max_output_boxes``. Defaults to 100.
        box_coding (int): Bounding boxes format for nms.
            0 means [x1, y1, x2, y2], 1 means [x, y, w, h].
            Defaults to 0.

    Returns:
        tuple[Tensor, Tensor, Tensor, Tensor]:
            (num_det, det_boxes, det_scores, det_classes),
            `num_det` of shape [N, 1]
            `det_boxes` of shape [N, num_det, 4]
            `det_scores` of shape [N, num_det]
            `det_classes` of shape [N, num_det]
    """
    # Settings this wrapper pins to fixed values; the names follow the
    # parameters of ``TRTEfficientNMSop.forward``.
    background_class = -1
    plugin_version = '1'
    score_activation = 0

    # Arguments are passed positionally, in the op's parameter order.
    outputs = TRTEfficientNMSop.apply(
        boxes,
        scores,
        background_class,
        box_coding,
        iou_threshold,
        keep_top_k,
        plugin_version,
        score_activation,
        score_threshold,
    )
    num_det, det_boxes, det_scores, det_classes = outputs
    return num_det, det_boxes, det_scores, det_classes
def _batched_nms(
    boxes: Tensor,
    scores: Tensor,
    max_output_boxes_per_class: int = 1000,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Wrapper for `batched_nms` with TensorRT.

    Args:
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Accepted for signature
            compatibility; not forwarded to the op. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Values above 4096 are clamped to 4096; a non-positive value
            means "no explicit limit" and is also mapped to 4096.
            Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to 100.
        box_coding (int): Bounding boxes format for nms.
            Defaults to 0 means [x1, y1 ,x2, y2].
            Set to 1 means [x, y, w, h].

    Returns:
        tuple[Tensor, Tensor, Tensor, Tensor]:
            (num_det, det_boxes, det_scores, det_classes),
            `num_det` of shape [N, 1]
            `det_boxes` of shape [N, num_det, 4]
            `det_scores` of shape [N, num_det]
            `det_classes` of shape [N, num_det], converted to int32
    """
    # Upper bound this wrapper applies to ``topK``
    # (presumably the plugin's limit -- confirm against the plugin docs).
    max_top_k = 4096

    if box_coding == 1:
        # [x, y, w, h] -> [x - w/2, y - h/2, x + w/2, y + h/2]
        boxes = boxes @ (_XYWH2XYXY.to(boxes.device))
    # The op is fed 4-D boxes; add a singleton axis at position 2 when
    # the input is not already 4-D.
    boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
    _, _, numClasses = scores.shape

    # ``min(pre_top_k, 4096)`` alone would forward the "no limit"
    # sentinel (-1) as a negative ``topK``; use the cap instead.
    top_k = max_top_k if pre_top_k <= 0 else min(pre_top_k, max_top_k)

    num_det, det_boxes, det_scores, det_classes = TRTbatchedNMSop.apply(
        boxes, scores, '1', 1, -1, int(numClasses), top_k, keep_top_k,
        score_threshold, iou_threshold, 0, 0, 16, 1)

    # The op yields class ids as floats; callers get integers.
    det_classes = det_classes.int()
    return num_det, det_boxes, det_scores, det_classes
def efficient_nms(*args, **kwargs):
    """Public entry point that forwards every argument to `_efficient_nms`.

    Returns:
        Whatever `_efficient_nms` returns for the given arguments.
    """
    detections = _efficient_nms(*args, **kwargs)
    return detections
def batched_nms(*args, **kwargs):
    """Public entry point that forwards every argument to `_batched_nms`.

    Returns:
        Whatever `_batched_nms` returns for the given arguments.
    """
    detections = _batched_nms(*args, **kwargs)
    return detections