# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, List
import torch
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.structures import Instances

from .utils import resample_data


class SegmentationLoss:
    """
    Cross-entropy segmentation loss over raw (unnormalized) scores.

    Ground truth segmentation labels are given for the bounding box of
    interest at a fixed resolution [S, S], where
    S = MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE.
    """

    def __init__(self, cfg: CfgNode):
        """
        Read the loss settings from the configuration.

        Args:
            cfg (CfgNode): configuration options; the values used are
                MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE and
                MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
        """
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        self.heatmap_size = head_cfg.HEATMAP_SIZE
        self.n_segm_chan = head_cfg.NUM_COARSE_SEGM_CHANNELS

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: Any,
    ) -> torch.Tensor:
        """
        Compute the segmentation loss as cross-entropy between estimated
        scores and ground truth labels resampled with `resample_data`.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data (not read by this method's body)
            densepose_predictor_outputs: an object of a dataclass that contains
                predictor outputs with estimated values; assumed to have the
                following attributes:
                * coarse_segm - coarse segmentation estimates, tensor of shape
                  [N, D, S, S]
            packed_annotations: packed annotations for efficient loss
                computation; the following attributes are used:
                - coarse_segm_gt
                - bbox_xywh_gt
                - bbox_xywh_est
                - bbox_indices
        Return:
            The cross-entropy loss tensor, or the value produced by
            `fake_value` when `packed_annotations.coarse_segm_gt` is None
        """
        gt_labels = packed_annotations.coarse_segm_gt
        if gt_labels is None:
            # No ground truth available: fall back to the zero-valued loss.
            return self.fake_value(densepose_predictor_outputs)

        # Keep only the estimates selected by the annotation bbox indices.
        all_scores = densepose_predictor_outputs.coarse_segm
        scores = all_scores[packed_annotations.bbox_indices]

        size = self.heatmap_size
        # Target preparation is excluded from gradient tracking.
        with torch.no_grad():
            resampled = resample_data(
                gt_labels.unsqueeze(1),
                packed_annotations.bbox_xywh_gt,
                packed_annotations.bbox_xywh_est,
                size,
                size,
                mode="nearest",
                padding_mode="zeros",
            )
            targets = resampled.squeeze(1)

        # With two segmentation channels every positive label becomes class 1.
        if self.n_segm_chan == 2:
            targets = targets > 0

        return F.cross_entropy(scores, targets.long())

    def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
        """
        Fake segmentation loss used when no suitable ground truth data
        was found in a batch. The loss has a value 0 and is primarily used to
        construct the computation graph, so that `DistributedDataParallel`
        has similar graphs on all GPUs and can perform reduction properly.

        Args:
            densepose_predictor_outputs: DensePose predictor outputs, an object
                of a dataclass that is assumed to have `coarse_segm`
                attribute
        Return:
            Zero value loss with proper computation graph
        """
        # Multiplying the sum by 0 yields a zero that still depends on
        # `coarse_segm`.
        score_total = densepose_predictor_outputs.coarse_segm.sum()
        return score_total * 0