File size: 7,348 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd


class PostProcessor(nn.Module):
    """From a set of classification scores, box regression and proposals,
    computes the post-processed boxes, and applies NMS to obtain the
    final results.
    """

    def __init__(self, score_thresh=0.05, nms=0.5, detections_per_img=100, box_coder=None):
        """
        Arguments:
            score_thresh (float): minimum class probability for a detection
                to be kept before NMS
            nms (float): IoU threshold used for per-class NMS
            detections_per_img (int): cap on detections kept per image over
                all classes (values <= 0 disable the cap)
            box_coder (BoxCoder): decoder for the regression deltas; defaults
                to BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
        """
        super(PostProcessor, self).__init__()
        self.score_thresh = score_thresh
        self.nms = nms
        self.detections_per_img = detections_per_img
        if box_coder is None:
            box_coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
        self.box_coder = box_coder

    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, x, boxes):
        """
        Arguments:
            x (tuple[tensor, tensor]): x contains the class logits
                and the box_regression from the model.
            boxes (list[BoxList]): bounding boxes that are used as
                reference, one for each image

        Returns:
            results (list[BoxList]): one BoxList for each image, containing
                the extra fields labels and scores
        """
        class_logits, box_regression = x
        class_prob = F.softmax(class_logits, -1)

        # TODO think about a representation of batch of boxes
        image_shapes = [box.size for box in boxes]
        boxes_per_image = [len(box) for box in boxes]
        concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)

        # If the input proposals carry an auxiliary "cbox" BoxList (with its
        # own "scores" field), split it per image so it can be threaded
        # through prepare_boxlist/filter_results alongside the detections.
        extra_fields = [{} for box in boxes]
        if boxes[0].has_field("cbox"):
            concat_cboxes = torch.cat([a.get_field("cbox").bbox for a in boxes], dim=0)
            concat_cscores = torch.cat([a.get_field("cbox").get_field("scores") for a in boxes], dim=0)
            for cbox, cscore, extra_field in zip(
                concat_cboxes.split(boxes_per_image, dim=0), concat_cscores.split(boxes_per_image, dim=0), extra_fields
            ):
                extra_field["cbox"] = cbox
                extra_field["cscore"] = cscore

        # Decode regression deltas into absolute boxes for the whole batch.
        proposals = self.box_coder.decode(box_regression.view(sum(boxes_per_image), -1), concat_boxes)

        num_classes = class_prob.shape[1]

        # Back to per-image chunks.
        proposals = proposals.split(boxes_per_image, dim=0)
        class_prob = class_prob.split(boxes_per_image, dim=0)

        results = []
        for prob, boxes_per_img, image_shape, extra_field in zip(class_prob, proposals, image_shapes, extra_fields):
            boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape, extra_field)
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = self.filter_results(boxlist, num_classes)
            results.append(boxlist)
        return results

    def prepare_boxlist(self, boxes, scores, image_shape, extra_field=None):
        """Returns BoxList from `boxes` and adds probability scores information
        as an extra field.

        `boxes` has shape (#detections, 4 * #classes), where each row represents
        a list of predicted bounding boxes for each of the object classes in the
        dataset (including the background class). The detections in each row
        originate from the same object proposal.

        `scores` has shape (#detections, #classes), where each row represents a
        list of object detection confidence scores for each of the object
        classes in the dataset (including the background class).
        `scores[i, j]` corresponds to the box at `boxes[i, j * 4:(j + 1) * 4]`.

        `extra_field` is an optional dict of additional per-proposal fields
        (e.g. "cbox"/"cscore") copied verbatim onto the BoxList.
        """
        # NOTE: use None (not a mutable {}) as the default to avoid sharing
        # one dict instance across calls.
        if extra_field is None:
            extra_field = {}
        boxes = boxes.reshape(-1, 4)
        scores = scores.reshape(-1)
        boxlist = BoxList(boxes, image_shape, mode="xyxy")
        boxlist.add_field("scores", scores)
        for key, val in extra_field.items():
            boxlist.add_field(key, val)
        return boxlist

    def filter_results(self, boxlist, num_classes):
        """Returns bounding-box detection results by thresholding on scores and
        applying non-maximum suppression (NMS).
        """
        # unwrap the boxlist to avoid additional overhead.
        # if we had multi-class NMS, we could perform this directly on the boxlist
        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field("scores").reshape(-1, num_classes)
        if boxlist.has_field("cbox"):
            cboxes = boxlist.get_field("cbox").reshape(-1, 4)
            cscores = boxlist.get_field("cscore")
        else:
            cboxes = None

        device = scores.device
        result = []
        # Apply threshold on detection probabilities and apply NMS
        # Skip j = 0, because it's the background class
        inds_all = scores > self.score_thresh
        for j in range(1, num_classes):
            inds = inds_all[:, j].nonzero().squeeze(1)
            scores_j = scores[inds, j]
            boxes_j = boxes[inds, j * 4 : (j + 1) * 4]
            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
            boxlist_for_class.add_field("scores", scores_j)
            if cboxes is not None:
                # Keep the auxiliary boxes/scores aligned with the surviving
                # detections for this class.
                cboxes_j = cboxes[inds, :]
                cscores_j = cscores[inds]
                cbox_boxlist = BoxList(cboxes_j, boxlist.size, mode="xyxy")
                cbox_boxlist.add_field("scores", cscores_j)
                boxlist_for_class.add_field("cbox", cbox_boxlist)

            boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms, score_field="scores")
            num_labels = len(boxlist_for_class)
            boxlist_for_class.add_field("labels", torch.full((num_labels,), j, dtype=torch.int64, device=device))
            result.append(boxlist_for_class)

        result = cat_boxlist(result)
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**
        if number_of_detections > self.detections_per_img > 0:
            cls_scores = result.get_field("scores")
            # kthvalue finds the score threshold that keeps exactly
            # detections_per_img boxes (modulo ties, which are all kept).
            image_thresh, _ = torch.kthvalue(cls_scores.cpu(), number_of_detections - self.detections_per_img + 1)
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        return result


def make_roi_box_post_processor(cfg):
    """Build a PostProcessor from the MODEL.ROI_HEADS section of the config.

    Arguments:
        cfg: config node providing MODEL.ROI_HEADS.BBOX_REG_WEIGHTS,
            SCORE_THRESH, NMS, and DETECTIONS_PER_IMG.

    Returns:
        PostProcessor: the configured post-processing module.
    """
    # NOTE: the previously-read MODEL.ROI_HEADS.USE_FPN flag was never used
    # here, so the dead read has been removed.
    bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
    box_coder = BoxCoder(weights=bbox_reg_weights)

    score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH
    nms_thresh = cfg.MODEL.ROI_HEADS.NMS
    detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG

    postprocessor = PostProcessor(score_thresh, nms_thresh, detections_per_img, box_coder)
    return postprocessor