D-FINE / src /nn /postprocessor /nms_postprocessor.py
developer0hye's picture
Upload 76 files
e85fecb verified
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
from typing import Dict
import torch
import torch.distributed
import torch.nn.functional as F
import torchvision
from torch import Tensor
from ...core import register
__all__ = [
"DetNMSPostProcessor",
]
@register()
class DetNMSPostProcessor(torch.nn.Module):
def __init__(
self,
iou_threshold=0.7,
score_threshold=0.01,
keep_topk=300,
box_fmt="cxcywh",
logit_fmt="sigmoid",
) -> None:
super().__init__()
self.iou_threshold = iou_threshold
self.score_threshold = score_threshold
self.keep_topk = keep_topk
self.box_fmt = box_fmt.lower()
self.logit_fmt = logit_fmt.lower()
self.logit_func = getattr(F, self.logit_fmt, None)
self.deploy_mode = False
def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor):
logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]
pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1)
values, pred_labels = torch.max(logits, dim=-1)
if self.logit_func:
pred_scores = self.logit_func(values)
else:
pred_scores = values
# TODO for onnx export
if self.deploy_mode:
blobs = {
"pred_labels": pred_labels,
"pred_boxes": pred_boxes,
"pred_scores": pred_scores,
}
return blobs
results = []
for i in range(logits.shape[0]):
score_keep = pred_scores[i] > self.score_threshold
pred_box = pred_boxes[i][score_keep]
pred_label = pred_labels[i][score_keep]
pred_score = pred_scores[i][score_keep]
keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, self.iou_threshold)
keep = keep[: self.keep_topk]
blob = {
"labels": pred_label[keep],
"boxes": pred_box[keep],
"scores": pred_score[keep],
}
results.append(blob)
return results
def deploy(
self,
):
self.eval()
self.deploy_mode = True
return self