File size: 8,396 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import math
import torch
import torch.nn.functional as F
from torch import nn

from .inference import make_atss_postprocessor
from .loss import make_atss_loss_evaluator

from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.layers import Scale, DFConv2d, DYReLU, SELayer
from .anchor_generator import make_anchor_generator_complex


class BoxCoder(object):
    """Encode/decode box regression deltas relative to anchor boxes.

    Boxes are rows of ``(x1, y1, x2, y2)``. Widths and heights use the legacy
    "+1" pixel convention (``x2 - x1 + 1``) in both directions. As a result,
    ``decode(encode(gt, anchors), anchors)`` recovers ``gt`` (up to float
    error and the ``bbox_xform_clip`` limit).

    Args:
        cfg: config object; stored on the instance but not read by this class.
        weights: ``(wx, wy, ww, wh)`` scale factors applied to the
            ``(dx, dy, dw, dh)`` deltas. The default matches the values this
            coder has always used.
        bbox_xform_clip: upper bound applied to the decoded log-space deltas
            ``dw``/``dh`` before ``torch.exp``, to avoid overflow.

    Raises:
        ValueError: if ``weights`` does not have exactly four entries.
    """

    def __init__(self, cfg, weights=(10.0, 10.0, 5.0, 5.0), bbox_xform_clip=math.log(1000.0 / 16)):
        self.cfg = cfg
        weights = tuple(weights)
        if len(weights) != 4:
            raise ValueError("weights must have exactly 4 entries (wx, wy, ww, wh), got {}".format(len(weights)))
        # Shared by encode() and decode() so the two directions always agree.
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, gt_boxes, anchors):
        """Compute regression targets that map ``anchors`` onto ``gt_boxes``.

        Args:
            gt_boxes: (N, 4) tensor of target boxes, paired row-wise with ``anchors``.
            anchors: (N, 4) tensor of anchor boxes.

        Returns:
            (N, 4) tensor of weighted ``(dx, dy, dw, dh)`` deltas.
        """
        TO_REMOVE = 1  # legacy "+1" width/height convention; must match decode()
        ex_widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
        ex_heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
        ex_ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
        ex_ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2

        gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + TO_REMOVE
        gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + TO_REMOVE
        gt_ctr_x = (gt_boxes[:, 2] + gt_boxes[:, 0]) / 2
        gt_ctr_y = (gt_boxes[:, 3] + gt_boxes[:, 1]) / 2

        wx, wy, ww, wh = self.weights
        targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
        targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
        targets_dw = ww * torch.log(gt_widths / ex_widths)
        targets_dh = wh * torch.log(gt_heights / ex_heights)
        return torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)

    def decode(self, preds, anchors):
        """Apply predicted deltas to ``anchors`` and return the decoded boxes.

        Args:
            preds: (N, 4*K) tensor of weighted deltas, laid out as interleaved
                ``(dx, dy, dw, dh)`` groups (one group per box per anchor).
            anchors: (N, 4) tensor of anchor boxes; it is cast to ``preds.dtype``.

        Returns:
            Tensor shaped like ``preds`` holding ``(x1, y1, x2, y2)`` groups.
        """
        anchors = anchors.to(preds.dtype)

        TO_REMOVE = 1  # legacy "+1" width/height convention; must match encode()
        widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
        heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
        ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
        ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2

        wx, wy, ww, wh = self.weights
        dx = preds[:, 0::4] / wx
        dy = preds[:, 1::4] / wy
        dw = preds[:, 2::4] / ww
        dh = preds[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        # The "- 1" undoes the "+1" width convention so corners round-trip.
        pred_boxes = torch.zeros_like(preds)
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * (pred_w - 1)
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * (pred_h - 1)
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * (pred_w - 1)
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * (pred_h - 1)

        return pred_boxes


class ATSSHead(torch.nn.Module):
    """ATSS prediction head shared across all feature levels.

    Two parallel conv towers are built from the ``cfg.MODEL.ATSS.*`` options:
    one for classification and one for boxes. The box tower feeds both the
    box-regression predictor and the centerness predictor.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        atss_cfg = cfg.MODEL.ATSS
        rpn_cfg = cfg.MODEL.RPN

        num_classes = atss_cfg.NUM_CLASSES - 1
        num_anchors = len(rpn_cfg.ASPECT_RATIOS) * rpn_cfg.SCALES_PER_OCTAVE
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        channels = atss_cfg.CHANNELS
        use_gn = atss_cfg.USE_GN
        use_bn = atss_cfg.USE_BN
        use_dcn_in_tower = atss_cfg.USE_DFCONV
        use_dyrelu = atss_cfg.USE_DYRELU
        use_se = atss_cfg.USE_SE
        num_convs = atss_cfg.NUM_CONVS

        def make_stage(conv_cls, stage_in):
            # One tower stage: 3x3 conv, optional GN / BN / SE, then the activation.
            stage = [conv_cls(stage_in, channels, kernel_size=3, stride=1, padding=1, bias=True)]
            if use_gn:
                stage.append(nn.GroupNorm(32, channels))
            if use_bn:
                stage.append(nn.BatchNorm2d(channels))
            if use_se:
                stage.append(SELayer(channels))
            stage.append(DYReLU(channels, channels) if use_dyrelu else nn.ReLU())
            return stage

        cls_layers = []
        box_layers = []
        for depth in range(num_convs):
            # Only the deepest stage may use a deformable conv.
            conv_cls = DFConv2d if (use_dcn_in_tower and depth == num_convs - 1) else nn.Conv2d
            stage_in = channels if depth else in_channels
            # cls and bbox stages are created alternately at each depth; layer
            # constructors draw from the global RNG, so this order fixes the
            # seeded initial weights.
            cls_layers.extend(make_stage(conv_cls, stage_in))
            box_layers.extend(make_stage(conv_cls, stage_in))

        self.cls_tower = nn.Sequential(*cls_layers)
        self.bbox_tower = nn.Sequential(*box_layers)
        self.cls_logits = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        self.bbox_pred = nn.Conv2d(channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        self.centerness = nn.Conv2d(channels, num_anchors * 1, kernel_size=3, stride=1, padding=1)

        # Every nn.Conv2d in the towers and predictors: N(0, 0.01) weights, zero bias.
        head_parts = (self.cls_tower, self.bbox_tower, self.cls_logits, self.bbox_pred, self.centerness)
        for part in head_parts:
            for module in part.modules():
                if isinstance(module, nn.Conv2d):
                    torch.nn.init.normal_(module.weight, std=0.01)
                    torch.nn.init.constant_(module.bias, 0)

        # Focal-loss prior: start the classifier at probability PRIOR_PROB.
        prior_prob = atss_cfg.PRIOR_PROB
        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_prob) / prior_prob))

        # One learnable regression scale per level; forward() indexes this, so at most 5 levels.
        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])

    def forward(self, x):
        """Run the head on each feature map in ``x``.

        Args:
            x: iterable of feature maps; presumably one per FPN level (at most
                5, since ``self.scales`` has 5 entries) — confirm with caller.

        Returns:
            ``(logits, bbox_reg, centerness)``: three lists with one tensor per
            input level, with ``num_anchors * num_classes``,
            ``num_anchors * 4`` and ``num_anchors`` channels respectively.
        """
        logits, bbox_reg, centerness = [], [], []
        for level, feature in enumerate(x):
            cls_feat = self.cls_tower(feature)
            box_feat = self.bbox_tower(feature)
            logits.append(self.cls_logits(cls_feat))
            bbox_reg.append(self.scales[level](self.bbox_pred(box_feat)))
            centerness.append(self.centerness(box_feat))
        return logits, bbox_reg, centerness


class ATSSModule(torch.nn.Module):
    """ATSS detection module.

    It runs the ATSS head, generates anchors, and then takes one of two paths:
    the training path (losses, plus optional proposals) or the inference path
    (post-processed boxes).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.head = ATSSHead(cfg)
        # One coder instance is shared by the loss and both post-processors.
        coder = BoxCoder(cfg)
        self.loss_evaluator = make_atss_loss_evaluator(cfg, coder)
        self.box_selector_train = make_atss_postprocessor(cfg, coder, is_train=True)
        self.box_selector_test = make_atss_postprocessor(cfg, coder, is_train=False)
        self.anchor_generator = make_anchor_generator_complex(cfg)

    def forward(self, images, features, targets=None):
        """Run the head and anchor generator, then dispatch on ``self.training``.

        Args:
            images: passed to the anchor generator.
            features: per-level feature maps for the head and anchor generator.
            targets: ground truth; only used on the training path.

        Returns:
            ``(boxes, losses)``. In training, ``boxes`` is ``None`` when
            ``cfg.MODEL.RPN_ONLY`` is set, otherwise a per-image proposal list.
            At inference, ``boxes`` is the post-processed output and
            ``losses`` is ``{}``.
        """
        box_cls, box_regression, centerness = self.head(features)
        anchors = self.anchor_generator(images, features)
        if not self.training:
            return self._forward_test(box_cls, box_regression, centerness, anchors)
        return self._forward_train(box_cls, box_regression, centerness, targets, anchors)

    def _forward_train(self, box_cls, box_regression, centerness, targets, anchors):
        """Compute ATSS losses and, unless RPN_ONLY, build per-image proposals.

        ``anchors`` holds one entry per image; each entry is a list of BoxLists
        that ``cat_boxlist`` merges (presumably one BoxList per level).
        """
        cls_loss, reg_loss, ctr_loss = self.loss_evaluator(
            box_cls, box_regression, centerness, targets, anchors
        )
        losses = {
            "loss_cls": cls_loss,
            "loss_reg": reg_loss,
            "loss_centerness": ctr_loss,
        }
        if self.cfg.MODEL.RPN_ONLY:
            return None, losses

        # NOTE(review): not wrapped in torch.no_grad(); whether proposals are
        # detached depends on the post-processor — confirm.
        detections = self.box_selector_train(box_cls, box_regression, centerness, anchors)
        proposals = []
        for det, image_anchors in zip(detections, anchors):
            merged_anchors = cat_boxlist(image_anchors)
            visible = torch.ones(det.bbox.shape[0], dtype=torch.bool, device=det.bbox.device)
            det.add_field("visibility", visible)
            # Drop detection-only fields. NOTE(review): presumably this leaves
            # the same field set as the anchors, as cat_boxlist needs — confirm.
            for field_name in ("scores", "labels"):
                del det.extra_fields[field_name]
            proposals.append(cat_boxlist([det, merged_anchors]))
        return proposals, losses

    def _forward_test(self, box_cls, box_regression, centerness, anchors):
        """Post-process head outputs into boxes; no losses at inference."""
        detections = self.box_selector_test(box_cls, box_regression, centerness, anchors)
        return detections, {}