File size: 6,388 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""RetinaNet model implementation and runtime."""

from __future__ import annotations

from torch import Tensor, nn

from vis4d.common.ckpt import load_model_checkpoint
from vis4d.common.typing import LossesType
from vis4d.op.base.resnet import ResNet
from vis4d.op.box.anchor import AnchorGenerator
from vis4d.op.box.box2d import scale_and_clip_boxes
from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder
from vis4d.op.box.matchers import Matcher
from vis4d.op.box.samplers import Sampler
from vis4d.op.detect.common import DetOut
from vis4d.op.detect.retinanet import (
    Dense2Det,
    RetinaNetHead,
    RetinaNetHeadLoss,
    RetinaNetOut,
)
from vis4d.op.fpp.fpn import FPN, ExtraFPNBlock

# (regex pattern, replacement) pairs passed as ``rev_keys`` to
# ``load_model_checkpoint`` when the MMDetection RetinaNet checkpoint is
# loaded, so that its key prefixes line up with the attribute names used by
# ``RetinaNet`` (``basemodel``, ``fpn``, ``retinanet_head``).
# NOTE(review): presumably applied in list order as regex substitutions on
# each checkpoint key -- confirm in ``load_model_checkpoint``. Under that
# assumption the order matters: the ``fpn.layer_blocks.{3,4}`` rules only
# match keys already rewritten by the ``neck.fpn_convs`` rule.
# Every literal dot is escaped so that ``.`` cannot match an arbitrary
# character and rewrite an unrelated key.
REV_KEYS = [
    (r"^backbone\.", "basemodel."),
    (r"^bbox_head\.", "retinanet_head."),
    (r"^neck\.lateral_convs\.", "fpn.inner_blocks."),
    (r"^neck\.fpn_convs\.", "fpn.layer_blocks."),
    (r"^fpn\.layer_blocks\.3\.", "fpn.extra_blocks.convs.0."),
    (r"^fpn\.layer_blocks\.4\.", "fpn.extra_blocks.convs.1."),
    (r"\.conv\.weight", ".weight"),
    (r"\.conv\.bias", ".bias"),
]


class RetinaNet(nn.Module):
    """RetinaNet wrapper bundling backbone, FPN, head and checkpointing."""

    def __init__(self, num_classes: int, weights: None | str = None) -> None:
        """Build all sub-modules and optionally load a checkpoint.

        Args:
            num_classes (int): Number of classes.
            weights (None | str, optional): Checkpoint to load into the
                model. The special value "mmdet" selects the MMDetection
                RetinaNet checkpoint and loads it with the REV_KEYS key
                remapping; any other string is handed to
                load_model_checkpoint unchanged. Defaults to None, in which
                case no checkpoint is loaded.
        """
        super().__init__()

        # Sub-modules are created and registered in the same order as the
        # attributes are listed here: basemodel, fpn, retinanet_head.
        backbone = ResNet("resnet50", pretrained=True, trainable_layers=3)
        self.basemodel = backbone

        extra_levels = ExtraFPNBlock(2, 2048, 256, add_extra_convs="on_input")
        self.fpn = FPN(
            backbone.out_channels[3:], 256, extra_levels, start_index=3
        )

        head = RetinaNetHead(num_classes=num_classes, in_channels=256)
        self.retinanet_head = head

        # Converts the dense head outputs into per-image detections, reusing
        # the head's own anchor generator and box decoder.
        self.transform_outs = Dense2Det(
            head.anchor_generator,
            head.box_decoder,
            num_pre_nms=1000,
            max_per_img=100,
            nms_threshold=0.5,
            score_thr=0.05,
        )

        if weights is None:
            return
        if weights == "mmdet":
            mmdet_ckpt = (
                "mmdet://retinanet/retinanet_r50_fpn_2x_coco/"
                "retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth"
            )
            load_model_checkpoint(self, mmdet_ckpt, rev_keys=REV_KEYS)
        else:
            load_model_checkpoint(self, weights)

    def forward(
        self,
        images: Tensor,
        input_hw: None | list[tuple[int, int]] = None,
        original_hw: None | list[tuple[int, int]] = None,
    ) -> RetinaNetOut | DetOut:
        """Dispatch to the training or testing forward pass.

        Args:
            images (Tensor): Input images.
            input_hw (None | list[tuple[int, int]], optional): Input image
                resolutions. Must be given when the module is not in
                training mode. Defaults to None.
            original_hw (None | list[tuple[int, int]], optional): Original
                image resolutions (before padding and resizing). Must be
                given when the module is not in training mode. Defaults to
                None.

        Returns:
            RetinaNetOut | DetOut: Raw model outputs when ``self.training``
                is set, otherwise the predicted detections.
        """
        if not self.training:
            assert input_hw is not None and original_hw is not None
            return self.forward_test(images, input_hw, original_hw)
        return self.forward_train(images)

    def forward_train(self, images: Tensor) -> RetinaNetOut:
        """Run backbone, FPN and head and return the raw head outputs.

        Args:
            images (Tensor): Input images.

        Returns:
            RetinaNetOut: Raw model outputs.
        """
        backbone_feats = self.basemodel(images)
        pyramid = self.fpn(backbone_feats)
        # Only the last five feature maps are fed to the head.
        return self.retinanet_head(pyramid[-5:])

    def forward_test(
        self,
        images: Tensor,
        images_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
    ) -> DetOut:
        """Run the model and convert the head outputs into detections.

        Args:
            images (Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            original_hw (list[tuple[int, int]]): Original image resolutions
                (before padding and resizing).

        Returns:
            DetOut: Predicted outputs, with each image's boxes passed through
                scale_and_clip_boxes using that image's original and input
                resolution.
        """
        pyramid = self.fpn(self.basemodel(images))
        # Only the last five feature maps are fed to the head.
        head_outs = self.retinanet_head(pyramid[-5:])
        boxes, scores, class_ids = self.transform_outs(
            cls_outs=head_outs.cls_score,
            reg_outs=head_outs.bbox_pred,
            images_hw=images_hw,
        )
        # Overwrite each image's boxes in place with the rescaled version.
        for idx, img_boxes in enumerate(boxes):
            boxes[idx] = scale_and_clip_boxes(
                img_boxes, original_hw[idx], images_hw[idx]
            )
        return DetOut(boxes, scores, class_ids)


class RetinaNetLoss(nn.Module):
    """RetinaNet loss module, a thin wrapper around RetinaNetHeadLoss."""

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        box_encoder: DeltaXYWHBBoxEncoder,
        box_matcher: Matcher,
        box_sampler: Sampler,
    ) -> None:
        """Create the wrapped RetinaNetHeadLoss.

        All arguments are forwarded to RetinaNetHeadLoss unchanged and in
        this order.

        Args:
            anchor_generator (AnchorGenerator): Anchor generator.
            box_encoder (DeltaXYWHBBoxEncoder): Bounding box encoder.
            box_matcher (Matcher): Bounding box matcher.
            box_sampler (Sampler): Bounding box sampler.
        """
        super().__init__()
        head_loss_args = (
            anchor_generator,
            box_encoder,
            box_matcher,
            box_sampler,
        )
        self.retinanet_loss = RetinaNetHeadLoss(*head_loss_args)

    def forward(
        self,
        outputs: RetinaNetOut,
        images_hw: list[tuple[int, int]],
        target_boxes: list[Tensor],
        target_classes: list[Tensor],
    ) -> LossesType:
        """Compute the RetinaNet losses from the raw model outputs.

        Args:
            outputs (RetinaNetOut): Raw model outputs.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            target_boxes (list[Tensor]): Bounding box labels.
            target_classes (list[Tensor]): Class labels.

        Returns:
            LossesType: Dictionary of model losses, obtained by calling
                ``_asdict()`` on the wrapped loss module's result.
        """
        cls_outs, reg_outs = outputs.cls_score, outputs.bbox_pred
        # The wrapped loss takes the boxes before the image sizes, which is
        # a different order from this method's own signature.
        loss_tuple = self.retinanet_loss(
            cls_outs, reg_outs, target_boxes, images_hw, target_classes
        )
        return loss_tuple._asdict()