| |
|
|
| from typing import Any, List |
| from torch import nn |
|
|
| from detectron2.config import CfgNode |
| from detectron2.structures import Instances |
|
|
| from .cycle_pix2shape import PixToShapeCycleLoss |
| from .cycle_shape2shape import ShapeToShapeCycleLoss |
| from .embed import EmbeddingLoss |
| from .embed_utils import CseAnnotationsAccumulator |
| from .mask_or_segm import MaskOrSegmentationLoss |
| from .registry import DENSEPOSE_LOSS_REGISTRY |
| from .soft_embed import SoftEmbeddingLoss |
| from .utils import BilinearInterpolationHelper, LossDict, extract_packed_annotations_from_matches |
|
|
|
|
@DENSEPOSE_LOSS_REGISTRY.register()
class DensePoseCseLoss:
    """
    Callable that assembles the DensePose CSE training losses.

    The returned loss dictionary contains a segmentation term, one embedding
    term per mesh, and, when enabled in the configuration, a shape-to-shape
    and a pixel-to-shape cycle term. Each real loss term is scaled by a weight
    read from the configuration.
    """

    # Maps the value of `MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME` to the
    # class implementing the corresponding embedding loss.
    _EMBED_LOSS_REGISTRY = {
        EmbeddingLoss.__name__: EmbeddingLoss,
        SoftEmbeddingLoss.__name__: SoftEmbeddingLoss,
    }

    def __init__(self, cfg: CfgNode):
        """
        Initialize CSE loss from configuration options

        Args:
            cfg (CfgNode): configuration options
        """
        # Weights applied to the segmentation and embedding terms in `__call__`.
        self.w_segm = cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS
        self.w_embed = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT
        self.segm_loss = MaskOrSegmentationLoss(cfg)
        self.embed_loss = DensePoseCseLoss.create_embed_loss(cfg)
        # The cycle losses are optional; their weight and loss object are only
        # created (and only exist as attributes) when the feature is enabled.
        self.do_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_shape2shape:
            self.w_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.shape2shape_loss = ShapeToShapeCycleLoss(cfg)
        self.do_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_pix2shape:
            self.w_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.pix2shape_loss = PixToShapeCycleLoss(cfg)

    @classmethod
    def create_embed_loss(cls, cfg: CfgNode):
        """
        Instantiate the embedding loss selected by the configuration.

        Args:
            cfg (CfgNode): configuration options; the loss class is looked up
                by `cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME`
        Return:
            An instance of the selected embedding loss class, constructed
            from `cfg`
        Raises:
            KeyError: if the configured name is not a registered embedding loss
        """
        loss_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME
        try:
            loss_cls = cls._EMBED_LOSS_REGISTRY[loss_name]
        except KeyError:
            # Same exception type as a plain lookup, but tell the user which
            # option is wrong and what the valid choices are.
            raise KeyError(
                f"Unknown CSE embedding loss {loss_name!r} "
                f"(MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME); "
                f"expected one of {sorted(cls._EMBED_LOSS_REGISTRY)}"
            ) from None
        # Constructed outside the `try` so that a KeyError raised by the loss
        # constructor itself is not mistaken for an unknown loss name.
        return loss_cls(cfg)

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        embedder: nn.Module,
    ) -> LossDict:
        """
        Compute the weighted CSE losses for a batch of matched proposals.

        Args:
            proposals_with_gt (list of Instances): proposals from which packed
                annotations are extracted
            densepose_predictor_outputs: predictor outputs; must expose an
                `embedding` attribute with a 4-dimensional `shape`
            embedder (nn.Module): module passed through to the embedding and
                cycle losses
        Return:
            dict with key "loss_densepose_S", one "loss_densepose_E{meshid}"
            entry per mesh reported by the embedding loss, and optionally
            "loss_shape2shape" and "loss_pix2shape". If there are no proposals
            or no packed annotations could be extracted, the result of
            `produce_fake_losses` is returned instead.
        """
        if not len(proposals_with_gt):
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)
        accumulator = CseAnnotationsAccumulator()
        packed_annotations = extract_packed_annotations_from_matches(proposals_with_gt, accumulator)
        if packed_annotations is None:
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)
        # The last two dimensions of the embedding output give the size handed
        # to the interpolation helper.
        h, w = densepose_predictor_outputs.embedding.shape[2:]
        interpolator = BilinearInterpolationHelper.from_matches(
            packed_annotations,
            (h, w),
        )
        meshid_to_embed_losses = self.embed_loss(
            proposals_with_gt,
            densepose_predictor_outputs,
            packed_annotations,
            interpolator,
            embedder,
        )
        embed_loss_dict = {
            f"loss_densepose_E{meshid}": self.w_embed * meshid_to_embed_losses[meshid]
            for meshid in meshid_to_embed_losses
        }
        all_loss_dict = {
            "loss_densepose_S": self.w_segm
            * self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.w_shape2shape * self.shape2shape_loss(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.w_pix2shape * self.pix2shape_loss(
                proposals_with_gt, densepose_predictor_outputs, packed_annotations, embedder
            )
        return all_loss_dict

    def produce_fake_losses(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> LossDict:
        """
        Build a loss dictionary from the "fake" values of every enabled loss.

        Used by `__call__` when there is nothing to supervise (no proposals or
        no packed annotations). The result has the same kinds of keys as the
        regular output, but the values come from the `fake_value` /
        `fake_values` methods of the individual losses and are NOT scaled by
        the configured weights.
        NOTE(review): the fake values are presumably zero-valued placeholders
        that still depend on the model outputs - confirm against the
        individual loss implementations.

        Args:
            densepose_predictor_outputs: predictor outputs, forwarded to the
                individual losses
            embedder (nn.Module): module forwarded to the embedding and cycle
                losses
        Return:
            dict with key "loss_densepose_S", one "loss_densepose_E{mesh_name}"
            entry per mesh reported by the embedding loss, and optionally
            "loss_shape2shape" and "loss_pix2shape"
        """
        meshname_to_embed_losses = self.embed_loss.fake_values(
            densepose_predictor_outputs, embedder=embedder
        )
        embed_loss_dict = {
            f"loss_densepose_E{mesh_name}": meshname_to_embed_losses[mesh_name]
            for mesh_name in meshname_to_embed_losses
        }
        all_loss_dict = {
            "loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.shape2shape_loss.fake_value(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.pix2shape_loss.fake_value(
                densepose_predictor_outputs, embedder
            )
        return all_loss_dict
|
|