Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 4,761 Bytes
			
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, List
from torch import nn
from detectron2.config import CfgNode
from detectron2.structures import Instances
from .cycle_pix2shape import PixToShapeCycleLoss
from .cycle_shape2shape import ShapeToShapeCycleLoss
from .embed import EmbeddingLoss
from .embed_utils import CseAnnotationsAccumulator
from .mask_or_segm import MaskOrSegmentationLoss
from .registry import DENSEPOSE_LOSS_REGISTRY
from .soft_embed import SoftEmbeddingLoss
from .utils import BilinearInterpolationHelper, LossDict, extract_packed_annotations_from_matches
@DENSEPOSE_LOSS_REGISTRY.register()
class DensePoseCseLoss:
    """
    Callable that assembles the dictionary of DensePose CSE training losses.

    The returned dictionary always contains a segmentation term
    ("loss_densepose_S") and one embedding term per entry produced by the
    configured embedding loss ("loss_densepose_E<key>"). The shape-to-shape
    and pixel-to-shape cycle terms ("loss_shape2shape", "loss_pix2shape")
    are present only when they are enabled in the configuration.
    """

    # Embedding loss implementations selectable through
    # cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME, keyed by class name
    _EMBED_LOSS_REGISTRY = {
        loss_cls.__name__: loss_cls for loss_cls in (EmbeddingLoss, SoftEmbeddingLoss)
    }

    def __init__(self, cfg: CfgNode):
        """
        Read the loss weights from the configuration and build the loss terms
        Args:
            cfg (CfgNode): configuration options; values are taken from
                cfg.MODEL.ROI_DENSEPOSE_HEAD and its CSE sub-node
        """
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        self.w_segm = head_cfg.INDEX_WEIGHTS
        cse_cfg = head_cfg.CSE
        self.w_embed = cse_cfg.EMBED_LOSS_WEIGHT
        self.segm_loss = MaskOrSegmentationLoss(cfg)
        self.embed_loss = DensePoseCseLoss.create_embed_loss(cfg)

        # Cycle losses are optional: the weight and the loss object are
        # only created when the corresponding ENABLED flag is set
        shape2shape_cfg = cse_cfg.SHAPE_TO_SHAPE_CYCLE_LOSS
        self.do_shape2shape = shape2shape_cfg.ENABLED
        if self.do_shape2shape:
            self.w_shape2shape = shape2shape_cfg.WEIGHT
            self.shape2shape_loss = ShapeToShapeCycleLoss(cfg)

        pix2shape_cfg = cse_cfg.PIX_TO_SHAPE_CYCLE_LOSS
        self.do_pix2shape = pix2shape_cfg.ENABLED
        if self.do_pix2shape:
            self.w_pix2shape = pix2shape_cfg.WEIGHT
            self.pix2shape_loss = PixToShapeCycleLoss(cfg)

    @classmethod
    def create_embed_loss(cls, cfg: CfgNode):
        """
        Instantiate the embedding loss named by
        cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME
        Args:
            cfg (CfgNode): configuration options, also passed to the
                constructor of the selected loss
        Return:
            An instance of the selected embedding loss class; a name that is
            missing from the local lookup table raises KeyError
        """
        # A class-local table is used instead of a global registry, since
        # embedding losses are currently local and not used anywhere else
        loss_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME
        loss_cls = cls._EMBED_LOSS_REGISTRY[loss_name]
        return loss_cls(cfg)

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        embedder: nn.Module,
    ) -> LossDict:
        """
        Compute the weighted loss terms for a batch of matched proposals
        Args:
            proposals_with_gt (list of Instances): proposals from which packed
                annotations are extracted
            densepose_predictor_outputs: predictor outputs; the `embedding`
                attribute is read here and its shape from index 2 onward is
                unpacked into two sizes
            embedder (nn.Module): module forwarded to the embedding and
                cycle losses
        Return:
            dict mapping loss names to weighted loss values; the values from
            `produce_fake_losses` are returned instead if there are no
            proposals or no packed annotations could be extracted
        """
        if len(proposals_with_gt) == 0:
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)

        packed_annotations = extract_packed_annotations_from_matches(
            proposals_with_gt, CseAnnotationsAccumulator()
        )
        if packed_annotations is None:
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)

        height, width = densepose_predictor_outputs.embedding.shape[2:]
        interpolator = BilinearInterpolationHelper.from_matches(
            packed_annotations, (height, width)
        )

        # Embedding terms are evaluated before the segmentation term
        per_mesh_losses = self.embed_loss(
            proposals_with_gt,
            densepose_predictor_outputs,
            packed_annotations,
            interpolator,
            embedder,
        )
        weighted_embed_terms = {
            f"loss_densepose_E{mesh_id}": self.w_embed * mesh_loss
            for mesh_id, mesh_loss in per_mesh_losses.items()
        }
        segm_term = self.w_segm * self.segm_loss(
            proposals_with_gt, densepose_predictor_outputs, packed_annotations
        )

        # The segmentation entry comes first, followed by the embedding entries
        losses = {"loss_densepose_S": segm_term}
        losses.update(weighted_embed_terms)

        if self.do_shape2shape:
            shape2shape_value = self.shape2shape_loss(embedder)
            losses["loss_shape2shape"] = self.w_shape2shape * shape2shape_value
        if self.do_pix2shape:
            pix2shape_value = self.pix2shape_loss(
                proposals_with_gt, densepose_predictor_outputs, packed_annotations, embedder
            )
            losses["loss_pix2shape"] = self.w_pix2shape * pix2shape_value
        return losses

    def produce_fake_losses(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> LossDict:
        """
        Build a loss dictionary with the same keys as `__call__` from the
        placeholder values supplied by each loss term
        Args:
            densepose_predictor_outputs: predictor outputs, forwarded to the
                `fake_value` / `fake_values` methods of the loss terms
            embedder (nn.Module): module forwarded to the embedding and
                cycle losses
        Return:
            dict mapping loss names to the placeholder values; no loss
            weights are applied here
        """
        fake_embed_values = self.embed_loss.fake_values(
            densepose_predictor_outputs, embedder=embedder
        )
        fake_embed_terms = {
            f"loss_densepose_E{mesh_name}": fake_value
            for mesh_name, fake_value in fake_embed_values.items()
        }

        losses = {"loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs)}
        losses.update(fake_embed_terms)

        if self.do_shape2shape:
            losses["loss_shape2shape"] = self.shape2shape_loss.fake_value(embedder)
        if self.do_pix2shape:
            losses["loss_pix2shape"] = self.pix2shape_loss.fake_value(
                densepose_predictor_outputs, embedder
            )
        return losses
 | 
