Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
import numpy as np | |
from typing import Any, List | |
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY | |
from detectron2.modeling.roi_heads.mask_head import MaskRCNNConvUpsampleHead, mask_rcnn_inference | |
from detectron2.projects.point_rend import ImplicitPointRendMaskHead | |
from detectron2.projects.point_rend.point_features import point_sample | |
from detectron2.projects.point_rend.point_head import roi_mask_point_loss | |
from detectron2.structures import Instances | |
from .point_utils import get_point_coords_from_point_annotation | |
__all__ = [ | |
"ImplicitPointRendPointSupHead", | |
"MaskRCNNConvUpsamplePointSupHead", | |
] | |
class MaskRCNNConvUpsamplePointSupHead(MaskRCNNConvUpsampleHead):
    """
    A mask head with several conv layers, plus an upsample layer (with `ConvTranspose2d`).
    Predictions are made with a final 1x1 conv layer.

    The difference with `MaskRCNNConvUpsampleHead` is that this head is trained
    with point supervision. Please use the `MaskRCNNConvUpsampleHead` if you want
    to train the model with mask supervision.
    """

    def forward(self, x, instances: List[Instances]) -> Any:
        """
        Run the head on region features and either compute the point-supervised
        mask loss (training) or attach mask predictions to `instances` (inference).

        Args:
            x: input region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses in training. The predicted "instances" in inference.
        """
        x = self.layers(x)

        if not self.training:
            # The return value of `mask_rcnn_inference` is not used; the same
            # `instances` list that was passed in is handed back to the caller.
            mask_rcnn_inference(x, instances)
            return instances

        # The 4-way unpack requires `x` to be 4-dimensional here; the channel
        # count is not needed.
        num_rois, _, height, width = x.shape
        assert height == width, "mask logits are expected to be square"

        proposal_boxes = [inst.proposal_boxes for inst in instances]
        # Builtin `sum` instead of `np.sum`: passing a generator to `np.sum` is
        # deprecated by NumPy and only works through a fallback to builtin `sum`.
        num_boxes = sum(len(boxes) for boxes in proposal_boxes)
        assert num_rois == num_boxes, "one mask prediction is expected per proposal box"

        if num_rois == 0:
            # Zero-valued loss derived from `x` itself (presumably so the loss
            # stays connected to the head's output when there are no proposals).
            return {"loss_mask": x.sum() * 0}

        # Training with point supervision
        point_coords, point_labels = get_point_coords_from_point_annotation(instances)
        mask_logits = point_sample(
            x,
            point_coords,
            align_corners=False,
        )
        return {"loss_mask": roi_mask_point_loss(mask_logits, instances, point_labels)}
class ImplicitPointRendPointSupHead(ImplicitPointRendMaskHead):
    """
    Variant of `ImplicitPointRendMaskHead` whose training points come from
    point annotations rather than from "gt_masks".
    """

    def _uniform_sample_train_points(self, instances):
        """
        Return the `(point_coords, point_labels)` pair produced by
        `get_point_coords_from_point_annotation` for `instances`.

        Only valid in training mode (asserted).
        """
        assert self.training
        # Please keep in mind that "gt_masks" is not used in this mask head.
        coords, labels = get_point_coords_from_point_annotation(instances)
        return coords, labels