Spaces:
Running
on
Zero
Running
on
Zero
"""Mask RCNN model implementation and runtime.""" | |
from __future__ import annotations | |
from typing import NamedTuple | |
import torch | |
from torch import nn | |
from vis4d.common.ckpt import load_model_checkpoint | |
from vis4d.op.base import BaseModel, ResNet | |
from vis4d.op.box.box2d import apply_mask, scale_and_clip_boxes | |
from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder | |
from vis4d.op.detect.common import DetOut | |
from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut | |
from vis4d.op.detect.mask_rcnn import ( | |
Det2Mask, | |
MaskOut, | |
MaskRCNNHead, | |
MaskRCNNHeadOut, | |
) | |
from vis4d.op.detect.rcnn import RoI2Det | |
from vis4d.op.fpp.fpn import FPN | |
class MaskDetectionOut(NamedTuple):
    """Mask detection output.

    Returned by ``MaskRCNN.forward_test``, which builds it from a ``DetOut``
    of boxes, scores and class ids plus the result of ``Det2Mask``.

    Attributes:
        boxes (DetOut): Detection output.
        masks (MaskOut): Mask output.
    """

    boxes: DetOut
    masks: MaskOut
class MaskRCNNOut(NamedTuple):
    """Mask RCNN output.

    Returned by ``MaskRCNN.forward_train``, which pairs the output of the
    Faster RCNN head with the output of the mask head.

    Attributes:
        boxes (FRCNNOut): Faster RCNN head output.
        masks (MaskRCNNHeadOut): Mask RCNN head output.
    """

    boxes: FRCNNOut
    masks: MaskRCNNHeadOut
# (regex pattern, replacement) pairs passed as ``rev_keys`` to
# ``load_model_checkpoint`` when loading "mmdet://" or "bdd100k://" weights,
# to rename checkpoint keys to this module's attribute names.
# NOTE(review): the order is kept as-is because later rules can match the
# output of earlier ones (e.g. "^roi_head\.bbox_head\." yields "roi_head.",
# which "^roi_head\." then prefixes) -- this assumes the loader applies the
# rules sequentially; confirm against load_model_checkpoint.
# Every "." that stands for a literal dot is escaped so a pattern cannot
# match an unrelated key that merely has some other character there.
REV_KEYS = [
    (r"^backbone\.", "basemodel."),
    (r"^rpn_head\.rpn_reg\.", "rpn_head.rpn_box."),
    (r"^roi_head\.bbox_head\.", "roi_head."),
    (r"^roi_head\.mask_head\.", "mask_head."),
    (r"^convs\.", "mask_head.convs."),
    (r"^upsample\.", "mask_head.upsample."),
    (r"^conv_logits\.", "mask_head.conv_logits."),
    (r"^roi_head\.", "faster_rcnn_head.roi_head."),
    (r"^rpn_head\.", "faster_rcnn_head.rpn_head."),
    (r"^neck\.lateral_convs\.", "fpn.inner_blocks."),
    (r"^neck\.fpn_convs\.", "fpn.layer_blocks."),
    (r"\.conv\.weight", ".weight"),
    (r"\.conv\.bias", ".bias"),
]
class MaskRCNN(nn.Module):
    """Mask RCNN model.

    Chains a base model, an FPN, a Faster RCNN head and a mask head. In
    training mode ``forward`` returns the raw head outputs; otherwise it
    returns post-processed detections and masks.

    Args:
        num_classes (int): Number of classes.
        basemodel (BaseModel, optional): Base model network. Defaults to
            None. If None, will use ResNet50.
        faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head.
            Defaults to None. If None, will use default FasterRCNNHead.
        mask_head (MaskRCNNHead, optional): Mask RCNN head. Defaults to
            None. If None, will use default MaskRCNNHead.
        rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN
            bounding boxes, forwarded to RoI2Det. Defaults to None.
        no_overlap (bool, optional): Whether to remove overlapping pixels
            between masks. Defaults to False.
        weights (None | str, optional): Weights to load for model. If set
            to "mmdet", will load MMDetection pre-trained weights.
            Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        basemodel: BaseModel | None = None,
        faster_rcnn_head: FasterRCNNHead | None = None,
        mask_head: MaskRCNNHead | None = None,
        rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
        no_overlap: bool = False,
        weights: None | str = None,
    ) -> None:
        """Creates an instance of the class."""
        super().__init__()
        self.basemodel = (
            ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3)
            if basemodel is None
            else basemodel
        )

        # The FPN is built from the base model's out_channels with the first
        # two entries skipped.
        # NOTE(review): 256 is presumably the FPN output width -- confirm.
        self.fpn = FPN(self.basemodel.out_channels[2:], 256)

        if faster_rcnn_head is None:
            self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes)
        else:
            self.faster_rcnn_head = faster_rcnn_head

        if mask_head is None:
            self.mask_head = MaskRCNNHead(num_classes=num_classes)
        else:
            self.mask_head = mask_head

        self.transform_outs = RoI2Det(rcnn_box_decoder)
        self.det2mask = Det2Mask(no_overlap=no_overlap)

        if weights is not None:
            # "mmdet" is a shorthand for the default MMDetection checkpoint.
            if weights == "mmdet":
                weights = (
                    "mmdet://mask_rcnn/mask_rcnn_r50_fpn_2x_coco/"
                    "mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_"
                    "20200505_003907-3e542a40.pth"
                )
            # Checkpoints from these sources are loaded with the REV_KEYS
            # renaming rules; any other path is loaded unchanged.
            if weights.startswith(("mmdet://", "bdd100k://")):
                load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
            else:
                load_model_checkpoint(self, weights)

    def forward(
        self,
        images: torch.Tensor,
        input_hw: list[tuple[int, int]],
        boxes2d: None | list[torch.Tensor] = None,
        boxes2d_classes: None | list[torch.Tensor] = None,
        original_hw: None | list[tuple[int, int]] = None,
    ) -> MaskRCNNOut | MaskDetectionOut:
        """Forward pass.

        Dispatches to ``forward_train`` when the module is in training mode
        and to ``forward_test`` otherwise.

        Args:
            images (torch.Tensor): Input images.
            input_hw (list[tuple[int, int]]): Input image resolutions.
            boxes2d (None | list[torch.Tensor], optional): Bounding box
                labels. Required for training. Defaults to None.
            boxes2d_classes (None | list[torch.Tensor], optional): Class
                labels. Required for training. Defaults to None.
            original_hw (None | list[tuple[int, int]], optional): Original
                image resolutions (before padding and resizing). Required for
                testing. Defaults to None.

        Returns:
            MaskRCNNOut | MaskDetectionOut: Either raw model
                outputs (for training) or predicted outputs (for testing).

        Raises:
            AssertionError: If the arguments required for the current mode
                are missing.
        """
        if self.training:
            assert (
                boxes2d is not None and boxes2d_classes is not None
            ), "boxes2d and boxes2d_classes are required for training."
            return self.forward_train(
                images, input_hw, boxes2d, boxes2d_classes
            )
        assert original_hw is not None, "original_hw is required for testing."
        return self.forward_test(images, input_hw, original_hw)

    def forward_train(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        target_boxes: list[torch.Tensor],
        target_classes: list[torch.Tensor],
    ) -> MaskRCNNOut:
        """Forward training stage.

        Args:
            images (torch.Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            target_boxes (list[torch.Tensor]): Bounding box labels.
            target_classes (list[torch.Tensor]): Class labels.

        Returns:
            MaskRCNNOut: Raw model outputs.

        Raises:
            AssertionError: If the Faster RCNN head returned no sampled
                proposals or no sampled targets.
        """
        features = self.fpn(self.basemodel(images))
        outputs = self.faster_rcnn_head(
            features, images_hw, target_boxes, target_classes
        )
        assert (
            outputs.sampled_proposals is not None
        ), "Faster RCNN head returned no sampled proposals."
        assert (
            outputs.sampled_targets is not None
        ), "Faster RCNN head returned no sampled targets."

        # Only the sampled proposals whose sampled target label equals 1 are
        # passed on to the mask head.
        # NOTE(review): label 1 presumably marks positive samples -- confirm.
        pos_proposals = apply_mask(
            [torch.eq(label, 1) for label in outputs.sampled_targets.labels],
            outputs.sampled_proposals.boxes,
        )[0]
        mask_outs = self.mask_head(features, pos_proposals)
        return MaskRCNNOut(outputs, mask_outs)

    def forward_test(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
    ) -> MaskDetectionOut:
        """Forward testing stage.

        Args:
            images (torch.Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            original_hw (list[tuple[int, int]]): Original image resolutions
                (before padding and resizing).

        Returns:
            MaskDetectionOut: Predicted outputs.
        """
        features = self.fpn(self.basemodel(images))
        outs = self.faster_rcnn_head(features, images_hw)
        boxes, scores, class_ids = self.transform_outs(
            *outs.roi, outs.proposals.boxes, images_hw
        )

        # The mask head runs on the boxes as produced for the input
        # resolution, i.e. before they are rescaled below.
        mask_outs = self.mask_head(features, boxes)

        # Replace each image's boxes in place with the result of
        # scale_and_clip_boxes for that image's original/input resolutions.
        for i, boxs in enumerate(boxes):
            boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i])

        mask_preds = [m.sigmoid() for m in mask_outs.mask_pred]
        masks = self.det2mask(
            mask_preds, boxes, scores, class_ids, original_hw
        )
        return MaskDetectionOut(DetOut(boxes, scores, class_ids), masks)