Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,478 Bytes
e85fecb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
from typing import Dict
import torch
import torch.distributed
import torch.nn.functional as F
import torchvision
from torch import Tensor
from ...core import register
# Names exported when this module is star-imported (``from ... import *``).
__all__ = [
"DetNMSPostProcessor",
]
@register()
class DetNMSPostProcessor(torch.nn.Module):
    """Convert raw detector outputs into detections, using per-class NMS.

    Boxes are converted to ``xyxy``, multiplied by ``orig_target_sizes``,
    scored, filtered by ``score_threshold``, passed through
    ``torchvision.ops.batched_nms`` (labels are used as the NMS groups) and
    finally truncated to ``keep_topk`` entries per image.

    Args:
        iou_threshold: IoU threshold handed to ``batched_nms``.
        score_threshold: Predictions whose score is not strictly greater
            than this value are discarded before NMS.
        keep_topk: Maximum number of NMS survivors kept per image.
        box_fmt: Format of ``outputs["pred_boxes"]``; lower-cased and used as
            ``in_fmt`` for ``torchvision.ops.box_convert``.
        logit_fmt: Name of a ``torch.nn.functional`` function that maps
            logits to scores (lower-cased). If no attribute of that name
            exists, the raw logits are used as scores.
    """

    # Activations that normalise over the class dimension. They have to see
    # the full logits with an explicit ``dim``; every other activation is
    # treated as element-wise and applied to the per-query maximum only.
    _CLASS_DIM_FUNCS = frozenset({"softmax", "log_softmax"})

    def __init__(
        self,
        iou_threshold: float = 0.7,
        score_threshold: float = 0.01,
        keep_topk: int = 300,
        box_fmt: str = "cxcywh",
        logit_fmt: str = "sigmoid",
    ) -> None:
        super().__init__()
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        self.keep_topk = keep_topk
        self.box_fmt = box_fmt.lower()
        self.logit_fmt = logit_fmt.lower()
        # None when torch.nn.functional has no attribute with that name.
        self.logit_func = getattr(F, self.logit_fmt, None)
        self.deploy_mode = False

    def _score_and_label(self, logits: Tensor):
        """Return ``(scores, labels)`` reduced over the last dim of ``logits``."""
        values, labels = torch.max(logits, dim=-1)
        if not self.logit_func:
            return values, labels
        if self.logit_fmt in self._CLASS_DIM_FUNCS:
            # Calling softmax/log_softmax on the already-reduced maxima would
            # normalise across queries (implicit dim) instead of classes, so
            # normalise the full logits first and reduce afterwards.
            scores, labels = torch.max(self.logit_func(logits, dim=-1), dim=-1)
            return scores, labels
        return self.logit_func(values), labels

    def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor):
        """Post-process one batch of predictions.

        Args:
            outputs: Mapping that provides ``"pred_logits"`` and
                ``"pred_boxes"``.
            orig_target_sizes: Per-image scale factors; assumed to be shaped
                (batch, 2) so that ``repeat(1, 2)`` yields one factor per
                box coordinate -- TODO confirm layout with the caller.

        Returns:
            In deploy mode, a dict with the batched ``"pred_labels"``,
            ``"pred_boxes"`` and ``"pred_scores"`` tensors (no thresholding
            or NMS applied). Otherwise a list with one dict per image holding
            ``"labels"``, ``"boxes"`` and ``"scores"`` of the kept detections.
        """
        logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]

        pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
        # repeat(1, 2) duplicates each size row to cover all four coordinates;
        # unsqueeze(1) lets it broadcast over the predictions of an image.
        pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        pred_scores, pred_labels = self._score_and_label(logits)

        # TODO for onnx export
        if self.deploy_mode:
            return {
                "pred_labels": pred_labels,
                "pred_boxes": pred_boxes,
                "pred_scores": pred_scores,
            }

        results = []
        for img_boxes, img_labels, img_scores in zip(pred_boxes, pred_labels, pred_scores):
            confident = img_scores > self.score_threshold
            img_boxes = img_boxes[confident]
            img_labels = img_labels[confident]
            img_scores = img_scores[confident]

            # Labels act as group ids, so boxes of different classes never
            # suppress each other.
            keep = torchvision.ops.batched_nms(
                img_boxes, img_scores, img_labels, self.iou_threshold
            )
            keep = keep[: self.keep_topk]

            results.append(
                {
                    "labels": img_labels[keep],
                    "boxes": img_boxes[keep],
                    "scores": img_scores[keep],
                }
            )
        return results

    def deploy(
        self,
    ):
        """Switch to eval mode, enable the deploy output path and return ``self``."""
        self.eval()
        self.deploy_mode = True
        return self
|