# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, List
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.structures import Instances

from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix
from densepose.structures.mesh import create_mesh

from .embed_utils import PackedCseAnnotations
from .utils import BilinearInterpolationHelper


class SoftEmbeddingLoss:
    """
    Computes losses for estimated embeddings given annotated vertices.
    Instances in a minibatch that correspond to the same mesh are grouped
    together. For each group, loss is computed as a soft cross-entropy between
    two distributions over mesh vertices, given ground truth mesh vertex ids:
     1) predicted: log-softmax of squared distances between estimated vertex
        embeddings and mesh vertex embeddings, divided by
        -EMBEDDING_DIST_GAUSS_SIGMA;
     2) target: softmax of geodesic distances from the ground truth vertex to
        all mesh vertices, divided by -GEODESIC_DIST_GAUSS_SIGMA.
    """

    def __init__(self, cfg: CfgNode):
        """
        Initialize embedding loss from config

        Args:
            cfg (CfgNode): configuration; reads
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA and
                MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA
        """
        self.embdist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA
        self.geodist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        interpolator: BilinearInterpolationHelper,
        embedder: nn.Module,
    ) -> Dict[str, torch.Tensor]:
        """
        Produces losses for estimated embeddings given annotated vertices.
        Embeddings for all the vertices of a mesh are computed by the embedder.
        Embeddings for observed pixels are estimated by a predictor.
        Losses are computed as soft cross-entropy for unnormalized scores given
        ground truth vertex IDs. Scores are based on:
         1) squared distances between estimated vertex embeddings
            and mesh vertex embeddings;
         2) geodesic distances between vertices of a mesh

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch. Not used by this implementation; kept for
                interface compatibility with other losses.
            densepose_predictor_outputs: an object of a dataclass that contains
                predictor outputs with estimated values; assumed to have the
                following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor
            interpolator (BilinearInterpolationHelper): bilinear interpolation helper
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            dict(str -> tensor): losses keyed by mesh name. Every mesh in
                `embedder.mesh_names` that has no valid annotated points in the
                batch gets a zero-valued placeholder loss (see `fake_value`).
        """
        losses = {}
        # Interpolated embeddings at all annotated points do not depend on the
        # mesh being processed, so compute them at most once (lazily, only if
        # some mesh has valid points) and slice per mesh below.
        point_embeddings = None
        for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
            mesh_id = mesh_id_tensor.item()
            mesh_name = MeshCatalog.get_mesh_name(mesh_id)
            # valid points are those that fall into estimated bbox
            # and correspond to the current mesh
            j_valid = interpolator.j_valid * (  # pyre-ignore[16]
                packed_annotations.vertex_mesh_ids_gt == mesh_id
            )
            if not torch.any(j_valid):
                continue
            if point_embeddings is None:
                point_embeddings = interpolator.extract_at_points(
                    densepose_predictor_outputs.embedding,
                    slice_fine_segm=slice(None),
                    w_ylo_xlo=interpolator.w_ylo_xlo[:, None],  # pyre-ignore[16]
                    w_ylo_xhi=interpolator.w_ylo_xhi[:, None],  # pyre-ignore[16]
                    w_yhi_xlo=interpolator.w_yhi_xlo[:, None],  # pyre-ignore[16]
                    w_yhi_xhi=interpolator.w_yhi_xhi[:, None],  # pyre-ignore[16]
                )
            # extract estimated embeddings for valid points
            # -> tensor [J, D]
            vertex_embeddings_i = normalize_embeddings(point_embeddings[j_valid, :])
            # extract vertex ids for valid points
            # -> tensor [J]
            vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]
            # embeddings for all mesh vertices
            # -> tensor [K, D]
            mesh_vertex_embeddings = embedder(mesh_name)
            # softmax values of geodesic distances for GT mesh vertices
            # -> tensor [J, K]
            mesh = create_mesh(mesh_name, mesh_vertex_embeddings.device)
            geodist_softmax_values = F.softmax(
                mesh.geodists[vertex_indices_i] / (-self.geodist_gauss_sigma), dim=1
            )
            # logsoftmax values for valid points
            # -> tensor [J, K]
            embdist_logsoftmax_values = F.log_softmax(
                squared_euclidean_distance_matrix(vertex_embeddings_i, mesh_vertex_embeddings)
                / (-self.embdist_gauss_sigma),
                dim=1,
            )
            # soft cross-entropy: sum over mesh vertices, mean over valid points
            losses[mesh_name] = (-geodist_softmax_values * embdist_logsoftmax_values).sum(1).mean()

        # meshes known to the embedder but without valid points in this batch
        # still receive a (zero-valued) loss entry
        for mesh_name in embedder.mesh_names:
            if mesh_name not in losses:
                losses[mesh_name] = self.fake_value(
                    densepose_predictor_outputs, embedder, mesh_name
                )
        return losses

    def fake_values(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> Dict[str, torch.Tensor]:
        """
        Produce zero-valued placeholder losses for every mesh in `embedder.mesh_names`.

        Args:
            densepose_predictor_outputs: predictor outputs with an `embedding` attribute
            embedder (nn.Module): module that computes vertex embeddings for meshes
        Return:
            dict(str -> tensor): placeholder loss per mesh name
        """
        losses = {}
        for mesh_name in embedder.mesh_names:
            losses[mesh_name] = self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
        return losses

    def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str):
        """
        Produce a zero scalar that depends on both the predicted embeddings and
        the mesh vertex embeddings for `mesh_name`. Presumably this keeps those
        parameters in the autograd graph when no real loss is available (e.g.
        for distributed training that expects every parameter to get a
        gradient) — NOTE(review): confirm intent.
        """
        return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0