# IDM-VTON
# update IDM-VTON Demo
# 938e515
# raw history blame
# No virus
# 6.33 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, List
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix
from densepose.structures.mesh import create_mesh
from .embed_utils import PackedCseAnnotations
from .utils import BilinearInterpolationHelper
class SoftEmbeddingLoss:
    """
    Computes losses for estimated embeddings given annotated vertices.
    Instances in a minibatch that correspond to the same mesh are grouped
    together. For each group, loss is computed as cross-entropy for
    unnormalized scores given ground truth mesh vertex ids.
    Scores are based on:
    1) squared distances between estimated vertex embeddings
       and mesh vertex embeddings;
    2) geodesic distances between vertices of a mesh
    """

    def __init__(self, cfg: CfgNode):
        """
        Initialize embedding loss from config

        Args:
            cfg (CfgNode): configuration; the following keys are read:
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA and
                MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA
        """
        # divisor (temperature) applied to squared embedding distances before log-softmax
        self.embdist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA
        # divisor (temperature) applied to geodesic distances before softmax
        self.geodist_gauss_sigma = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        interpolator: BilinearInterpolationHelper,
        embedder: nn.Module,
    ) -> Dict[str, torch.Tensor]:
        """
        Produces losses for estimated embeddings given annotated vertices.
        Embeddings for all the vertices of a mesh are computed by the embedder.
        Embeddings for observed pixels are estimated by a predictor.
        Losses are computed as cross-entropy for unnormalized scores given
        ground truth vertex IDs. Scores are based on:
        1) squared distances between estimated vertex embeddings
           and mesh vertex embeddings;
        2) geodesic distances between vertices of a mesh

        For every mesh with at least one valid annotated point, the loss is
            mean_j sum_k -softmax_k(-geodist[v_j, k] / geodist_sigma)
                         * log_softmax_k(-||e_j - E_k||^2 / embdist_sigma)
        where v_j is the ground truth vertex of point j, e_j is the estimated
        (normalized) embedding of point j and E_k is the embedding of mesh
        vertex k produced by the embedder.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch. Not read by this loss (presumably kept so
                all DensePose losses share one call signature).
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor
            interpolator (BilinearInterpolationHelper): bilinear interpolation helper
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            dict(str -> tensor): losses keyed by mesh name. Every name in
            `embedder.mesh_names` is present; meshes without valid annotated
            points get a zero-valued loss from `fake_value`.
        """
        losses: Dict[str, torch.Tensor] = {}
        # Embeddings interpolated at every annotated point (one row per point,
        # rows selected by `j_valid` below). This does not depend on the mesh,
        # so it is computed once, lazily, on the first mesh with valid points.
        point_embeddings = None
        for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
            mesh_id = mesh_id_tensor.item()
            mesh_name = MeshCatalog.get_mesh_name(mesh_id)
            # valid points are those that fall into estimated bbox
            # and correspond to the current mesh
            j_valid = interpolator.j_valid * (  # pyre-ignore[16]
                packed_annotations.vertex_mesh_ids_gt == mesh_id
            )
            if not torch.any(j_valid):
                continue
            if point_embeddings is None:
                point_embeddings = interpolator.extract_at_points(
                    densepose_predictor_outputs.embedding,
                    slice_fine_segm=slice(None),
                    w_ylo_xlo=interpolator.w_ylo_xlo[:, None],  # pyre-ignore[16]
                    w_ylo_xhi=interpolator.w_ylo_xhi[:, None],  # pyre-ignore[16]
                    w_yhi_xlo=interpolator.w_yhi_xlo[:, None],  # pyre-ignore[16]
                    w_yhi_xhi=interpolator.w_yhi_xhi[:, None],  # pyre-ignore[16]
                )
            # estimated embeddings for valid points
            # -> tensor [J, D]
            vertex_embeddings_i = normalize_embeddings(point_embeddings[j_valid, :])
            # ground truth vertex ids for valid points
            # -> tensor [J]
            vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]
            # embeddings for all mesh vertices
            # -> tensor [K, D]
            mesh_vertex_embeddings = embedder(mesh_name)
            # target distribution: softmax of negated, scaled geodesic distances
            # from each GT vertex to every mesh vertex
            # -> tensor [J, K]
            mesh = create_mesh(mesh_name, mesh_vertex_embeddings.device)
            geodist_softmax_values = F.softmax(
                mesh.geodists[vertex_indices_i] / (-self.geodist_gauss_sigma), dim=1
            )
            # predicted log-distribution: log-softmax of negated, scaled squared
            # embedding distances from each point to every mesh vertex
            # -> tensor [J, K]
            embdist_logsoftmax_values = F.log_softmax(
                squared_euclidean_distance_matrix(vertex_embeddings_i, mesh_vertex_embeddings)
                / (-self.embdist_gauss_sigma),
                dim=1,
            )
            # soft cross-entropy per point, averaged over valid points
            losses[mesh_name] = (-geodist_softmax_values * embdist_logsoftmax_values).sum(1).mean()
        # every mesh known to the embedder gets an entry, even without annotations
        for mesh_name in embedder.mesh_names:
            if mesh_name not in losses:
                losses[mesh_name] = self.fake_value(
                    densepose_predictor_outputs, embedder, mesh_name
                )
        return losses

    def fake_values(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> Dict[str, torch.Tensor]:
        """
        Zero-valued losses for every mesh in `embedder.mesh_names`.

        Args:
            densepose_predictor_outputs: predictor outputs with an `embedding` attribute
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            dict(str -> tensor): one `fake_value` loss per mesh name
        """
        losses = {}
        for mesh_name in embedder.mesh_names:
            losses[mesh_name] = self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
        return losses

    def fake_value(
        self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str
    ) -> torch.Tensor:
        """
        Zero-valued loss that still depends on the predicted embeddings and on
        the embedder output for `mesh_name`, so backward reaches both with zero
        gradient (presumably to avoid "unused parameter" issues in distributed
        training — confirm against the training setup).

        Args:
            densepose_predictor_outputs: predictor outputs with an `embedding` attribute
            embedder (nn.Module): module that computes vertex embeddings for different meshes
            mesh_name (str): name of the mesh whose embeddings are touched
        Return:
            tensor: scalar zero attached to the autograd graph
        """
        return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0