rlawjdghek's picture
prep (#1)
61c2d32 verified
raw
history blame
3 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
from typing import Any, List
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY
from detectron2.modeling.roi_heads.mask_head import MaskRCNNConvUpsampleHead, mask_rcnn_inference
from detectron2.projects.point_rend import ImplicitPointRendMaskHead
from detectron2.projects.point_rend.point_features import point_sample
from detectron2.projects.point_rend.point_head import roi_mask_point_loss
from detectron2.structures import Instances
from .point_utils import get_point_coords_from_point_annotation
# Public API of this module: the two point-supervised mask heads defined below.
__all__ = ["ImplicitPointRendPointSupHead", "MaskRCNNConvUpsamplePointSupHead"]
@ROI_MASK_HEAD_REGISTRY.register()
class MaskRCNNConvUpsamplePointSupHead(MaskRCNNConvUpsampleHead):
    """
    A mask head with several conv layers, plus an upsample layer (with `ConvTranspose2d`).
    Predictions are made with a final 1x1 conv layer.

    The difference with `MaskRCNNConvUpsampleHead` is that this head is trained
    with point supervision. Please use the `MaskRCNNConvUpsampleHead` if you want
    to train the model with mask supervision.
    """

    def forward(self, x, instances: List[Instances]) -> Any:
        """
        Run the mask head; compute a point-supervised loss in training.

        Args:
            x: input region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses in training (key ``"loss_mask"``).
            The predicted "instances" in inference.
        """
        x = self.layers(x)
        if self.training:
            # The head output is unpacked as 4-D (N, C, H, W); only square
            # mask predictions are accepted.
            N, _, H, W = x.shape
            assert H == W
            proposal_boxes = [inst.proposal_boxes for inst in instances]
            # Use the builtin ``sum``: ``np.sum`` over a generator is deprecated
            # in NumPy and only works via a warning-emitting fallback.
            assert N == sum(len(boxes) for boxes in proposal_boxes)
            if N == 0:
                # Zero-valued loss derived from ``x`` rather than a constant,
                # presumably so the loss stays attached to the head's outputs.
                return {"loss_mask": x.sum() * 0}
            # Training with point supervision: sample the predicted mask logits
            # at the annotated point locations and score them against the
            # annotated point labels.
            point_coords, point_labels = get_point_coords_from_point_annotation(instances)
            mask_logits = point_sample(
                x,
                point_coords,
                align_corners=False,
            )
            return {"loss_mask": roi_mask_point_loss(mask_logits, instances, point_labels)}
        else:
            # The return value of mask_rcnn_inference is discarded; it is called
            # for its effect on ``instances``, which is returned as-is.
            mask_rcnn_inference(x, instances)
            return instances
@ROI_MASK_HEAD_REGISTRY.register()
class ImplicitPointRendPointSupHead(ImplicitPointRendMaskHead):
    """
    Implicit PointRend mask head whose training points come from point annotations.

    Only ``_uniform_sample_train_points`` is overridden. Instead of drawing
    points, it returns the annotated point coordinates and labels.
    """

    def _uniform_sample_train_points(self, instances):
        """
        Return ``(point_coords, point_labels)`` for ``instances`` during training.

        Despite the method name inherited from the parent class, no sampling is
        performed here. The points are read from the instances' point annotations
        via ``get_point_coords_from_point_annotation``.
        """
        # This override is only valid while the module is in training mode.
        assert self.training
        # "gt_masks" is intentionally not consulted by this mask head; only the
        # point annotations are used.
        coords, labels = get_point_coords_from_point_annotation(instances)
        return coords, labels