"""

Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)

Copyright(c) 2023 lyuwenyu. All Rights Reserved.
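
Illustrative note (not part of the upstream RT-DETR file): the post-processor in
this module converts boxes to xyxy, scales them by the original image size,
drops low scores, and applies class-aware NMS. The block below is a minimal
pure-Python sketch of those steps, assuming boxes are normalized to [0, 1]. The
helpers `cxcywh_to_xyxy`, `iou`, and `class_aware_nms` are hypothetical names
used only for illustration; the module itself relies on
`torchvision.ops.box_convert` and `torchvision.ops.batched_nms`.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def cxcywh_to_xyxy(box: Sequence[float], width: float, height: float) -> Box:
    # Convert a normalized (cx, cy, w, h) box to pixel-space (x1, y1, x2, y2).
    cx, cy, w, h = box
    return (
        (cx - w / 2) * width,
        (cy - h / 2) * height,
        (cx + w / 2) * width,
        (cy + h / 2) * height,
    )


def iou(a: Box, b: Box) -> float:
    # Intersection over union of two xyxy boxes.
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def class_aware_nms(
    boxes: Sequence[Box],
    scores: Sequence[float],
    labels: Sequence[int],
    iou_threshold: float = 0.7,
    score_threshold: float = 0.01,
    keep_topk: int = 300,
) -> List[int]:
    # Visit candidates above the score threshold, highest score first.
    order = sorted(
        (i for i, s in enumerate(scores) if s > score_threshold),
        key=lambda i: scores[i],
        reverse=True,
    )
    kept: List[int] = []
    for i in order:
        # A box is suppressed only by an already-kept box with the same label.
        if all(
            labels[i] != labels[j] or iou(boxes[i], boxes[j]) <= iou_threshold
            for j in kept
        ):
            kept.append(i)
    return kept[:keep_topk]
```

Unlike the module, which indexes into the score-filtered tensors, this sketch
returns indices into the unfiltered input lists.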

"""

from typing import Dict

import torch
import torch.distributed
import torch.nn.functional as F
import torchvision
from torch import Tensor

from ...core import register

__all__ = [
    "DetNMSPostProcessor",
]


@register()
class DetNMSPostProcessor(torch.nn.Module):
    def __init__(
        self,
        iou_threshold=0.7,
        score_threshold=0.01,
        keep_topk=300,
        box_fmt="cxcywh",
        logit_fmt="sigmoid",
    ) -> None:
        super().__init__()
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        self.keep_topk = keep_topk
        self.box_fmt = box_fmt.lower()
        self.logit_fmt = logit_fmt.lower()
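        # Look up the activation by name in torch.nn.functional (e.g. F.sigmoid);
        # an unknown name yields None, and the raw logits are used as scores.
        self.logit_func = getattr(F, self.logit_fmt, None)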
        self.deploy_mode = False

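    def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor):
        logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]
        # Convert to xyxy, then scale by each image's size. The repeat(1, 2)
        # implies orig_target_sizes has shape [batch, 2] in (width, height) order.
        pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
        pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        # Keep only the highest-scoring class for each query.
        values, pred_labels = torch.max(logits, dim=-1)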

        if self.logit_func:
            pred_scores = self.logit_func(values)
        else:
            pred_scores = values

        # TODO for onnx export: deploy mode returns the batched tensors and skips NMS
        if self.deploy_mode:
            blobs = {
                "pred_labels": pred_labels,
                "pred_boxes": pred_boxes,
                "pred_scores": pred_scores,
            }
            return blobs

        results = []
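        for i in range(logits.shape[0]):
            # Drop low-confidence predictions before NMS.
            score_keep = pred_scores[i] > self.score_threshold
            pred_box = pred_boxes[i][score_keep]
            pred_label = pred_labels[i][score_keep]
            pred_score = pred_scores[i][score_keep]

            # Class-aware NMS. Indices come back sorted by descending score,
            # so the slice keeps the keep_topk highest-scoring detections.
            keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, self.iou_threshold)
            keep = keep[: self.keep_topk]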

            blob = {
                "labels": pred_label[keep],
                "boxes": pred_box[keep],
                "scores": pred_score[keep],
            }

            results.append(blob)

        return results

    def deploy(self):
        self.eval()
        self.deploy_mode = True
        return self