File size: 7,952 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""Mask RCNN model implementation and runtime."""

from __future__ import annotations

from typing import NamedTuple

import torch
from torch import nn

from vis4d.common.ckpt import load_model_checkpoint
from vis4d.op.base import BaseModel, ResNet
from vis4d.op.box.box2d import apply_mask, scale_and_clip_boxes
from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
from vis4d.op.detect.common import DetOut
from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut
from vis4d.op.detect.mask_rcnn import (
    Det2Mask,
    MaskOut,
    MaskRCNNHead,
    MaskRCNNHeadOut,
)
from vis4d.op.detect.rcnn import RoI2Det
from vis4d.op.fpp.fpn import FPN


class MaskDetectionOut(NamedTuple):
    """Mask detection output.

    Result type of ``MaskRCNN.forward_test``: the final detections together
    with the masks produced for them.
    """

    # Detections built as ``DetOut(boxes, scores, class_ids)`` after the boxes
    # went through ``scale_and_clip_boxes`` in ``MaskRCNN.forward_test``.
    boxes: DetOut
    # Output of ``Det2Mask`` for those detections.
    masks: MaskOut


class MaskRCNNOut(NamedTuple):
    """Mask RCNN output.

    Result type of ``MaskRCNN.forward_train``: the raw outputs of the two
    heads, without any post-processing.
    """

    # Output of the Faster RCNN head.
    boxes: FRCNNOut
    # Output of the mask head, computed in ``MaskRCNN.forward_train`` on the
    # sampled proposals whose sampled target label equals 1.
    masks: MaskRCNNHeadOut


# (regex pattern, replacement) pairs handed to ``load_model_checkpoint`` as
# ``rev_keys`` when ``MaskRCNN`` loads "mmdet://" or "bdd100k://" weights.
# The replacements target the attribute names used by ``MaskRCNN``
# (``basemodel``, ``fpn``, ``faster_rcnn_head``, ``mask_head``).
# NOTE(review): presumably the pairs are applied in list order, so that e.g.
# the ``roi_head.bbox_head.`` rule feeds the later ``roi_head.`` rule --
# confirm against ``load_model_checkpoint``.
# NOTE(review): several "." inside the patterns are unescaped and therefore
# match any character; this looks harmless for the expected key names, but
# verify before tightening.
REV_KEYS = [
    (r"^backbone\.", "basemodel."),
    (r"^rpn_head.rpn_reg\.", "rpn_head.rpn_box."),
    (r"^roi_head.bbox_head\.", "roi_head."),
    (r"^roi_head.mask_head\.", "mask_head."),
    (r"^convs\.", "mask_head.convs."),
    (r"^upsample\.", "mask_head.upsample."),
    (r"^conv_logits\.", "mask_head.conv_logits."),
    # Remaining head keys are nested under the ``faster_rcnn_head`` attribute.
    (r"^roi_head\.", "faster_rcnn_head.roi_head."),
    (r"^rpn_head\.", "faster_rcnn_head.rpn_head."),
    (r"^neck.lateral_convs\.", "fpn.inner_blocks."),
    (r"^neck.fpn_convs\.", "fpn.layer_blocks."),
    # Drop the intermediate ``.conv`` component from weight / bias keys.
    (r"\.conv.weight", ".weight"),
    (r"\.conv.bias", ".bias"),
]


class MaskRCNN(nn.Module):
    """Mask RCNN model.

    Chains a base model, an FPN, a Faster RCNN head and a mask head. In
    training mode the raw head outputs are returned; otherwise the
    detections are decoded, rescaled to the original image resolution and
    converted into masks.

    Args:
        num_classes (int): Number of classes.
        basemodel (BaseModel, optional): Base model network. Defaults to
            None. If None, a pretrained ResNet50 with three trainable
            layers is created.
        faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head.
            Defaults to None. If None, a FasterRCNNHead is created from
            num_classes.
        mask_head (MaskRCNNHead, optional): Mask RCNN head. Defaults to
            None. If None, a MaskRCNNHead is created from num_classes.
        rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN
            bounding boxes, forwarded to RoI2Det. Defaults to None.
        no_overlap (bool, optional): Whether to remove overlapping pixels
            between masks. Forwarded to Det2Mask. Defaults to False.
        weights (None | str, optional): Weights to load for model. The
            shortcut "mmdet" selects the MMDetection pre-trained
            checkpoint. Locations starting with "mmdet://" or
            "bdd100k://" are loaded with the REV_KEYS key remapping, any
            other location is loaded as is. Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        basemodel: BaseModel | None = None,
        faster_rcnn_head: FasterRCNNHead | None = None,
        mask_head: MaskRCNNHead | None = None,
        rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
        no_overlap: bool = False,
        weights: None | str = None,
    ) -> None:
        """Creates an instance of the class."""
        super().__init__()

        # Submodules are built and registered in a fixed order: base model,
        # FPN, detection head, mask head, post-processing.
        if basemodel is None:
            basemodel = ResNet(
                resnet_name="resnet50", pretrained=True, trainable_layers=3
            )
        self.basemodel = basemodel

        self.fpn = FPN(self.basemodel.out_channels[2:], 256)

        self.faster_rcnn_head = (
            FasterRCNNHead(num_classes=num_classes)
            if faster_rcnn_head is None
            else faster_rcnn_head
        )
        self.mask_head = (
            MaskRCNNHead(num_classes=num_classes)
            if mask_head is None
            else mask_head
        )

        self.transform_outs = RoI2Det(rcnn_box_decoder)
        self.det2mask = Det2Mask(no_overlap=no_overlap)

        if weights is None:
            return

        checkpoint = weights
        if checkpoint == "mmdet":
            # Expand the shortcut into the full checkpoint location.
            checkpoint = (
                "mmdet://mask_rcnn/mask_rcnn_r50_fpn_2x_coco/"
                "mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_"
                "20200505_003907-3e542a40.pth"
            )

        # Only these two schemes get their parameter names remapped.
        if checkpoint.startswith(("mmdet://", "bdd100k://")):
            load_model_checkpoint(self, checkpoint, rev_keys=REV_KEYS)
        else:
            load_model_checkpoint(self, checkpoint)

    def forward(
        self,
        images: torch.Tensor,
        input_hw: list[tuple[int, int]],
        boxes2d: None | list[torch.Tensor] = None,
        boxes2d_classes: None | list[torch.Tensor] = None,
        original_hw: None | list[tuple[int, int]] = None,
    ) -> MaskRCNNOut | MaskDetectionOut:
        """Forward pass.

        Dispatches to forward_train when the module is in training mode and
        to forward_test otherwise.

        Args:
            images (torch.Tensor): Input images.
            input_hw (list[tuple[int, int]]): Input image resolutions.
            boxes2d (None | list[torch.Tensor], optional): Bounding box
                labels. Required for training. Defaults to None.
            boxes2d_classes (None | list[torch.Tensor], optional): Class
                labels. Required for training. Defaults to None.
            original_hw (None | list[tuple[int, int]], optional): Original
                image resolutions (before padding and resizing). Required for
                testing. Defaults to None.

        Returns:
            MaskRCNNOut | MaskDetectionOut: Either raw model
                outputs (for training) or predicted outputs (for testing).
        """
        if not self.training:
            assert original_hw is not None
            return self.forward_test(images, input_hw, original_hw)

        assert boxes2d is not None and boxes2d_classes is not None
        return self.forward_train(images, input_hw, boxes2d, boxes2d_classes)

    def forward_train(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        target_boxes: list[torch.Tensor],
        target_classes: list[torch.Tensor],
    ) -> MaskRCNNOut:
        """Forward training stage.

        Args:
            images (torch.Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            target_boxes (list[torch.Tensor]): Bounding box labels.
            target_classes (list[torch.Tensor]): Class labels.

        Returns:
            MaskRCNNOut: Raw model outputs.
        """
        feats = self.fpn(self.basemodel(images))
        frcnn_out = self.faster_rcnn_head(
            feats, images_hw, target_boxes, target_classes
        )

        sampled_proposals = frcnn_out.sampled_proposals
        sampled_targets = frcnn_out.sampled_targets
        assert sampled_proposals is not None
        assert sampled_targets is not None

        # The mask head only sees proposals whose sampled label equals 1.
        keep = [torch.eq(labels, 1) for labels in sampled_targets.labels]
        pos_proposals = apply_mask(keep, sampled_proposals.boxes)[0]

        mask_out = self.mask_head(feats, pos_proposals)
        return MaskRCNNOut(boxes=frcnn_out, masks=mask_out)

    def forward_test(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
    ) -> MaskDetectionOut:
        """Forward testing stage.

        Args:
            images (torch.Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            original_hw (list[tuple[int, int]]): Original image resolutions
                (before padding and resizing).

        Returns:
            MaskDetectionOut: Predicted outputs.
        """
        feats = self.fpn(self.basemodel(images))
        frcnn_out = self.faster_rcnn_head(feats, images_hw)
        boxes, scores, class_ids = self.transform_outs(
            *frcnn_out.roi, frcnn_out.proposals.boxes, images_hw
        )

        # The mask head runs on the boxes before they are rescaled below.
        mask_out = self.mask_head(feats, boxes)

        # Rescale the boxes in place, image by image.
        for idx, per_image in enumerate(boxes):
            boxes[idx] = scale_and_clip_boxes(
                per_image, original_hw[idx], images_hw[idx]
            )

        mask_probs = [logits.sigmoid() for logits in mask_out.mask_pred]
        masks = self.det2mask(
            mask_probs, boxes, scores, class_ids, original_hw
        )
        detections = DetOut(boxes, scores, class_ids)
        return MaskDetectionOut(boxes=detections, masks=masks)