IDM-VTON
update IDM-VTON Demo
938e515
raw
history blame
No virus
4.76 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, List
from torch import nn
from detectron2.config import CfgNode
from detectron2.structures import Instances
from .cycle_pix2shape import PixToShapeCycleLoss
from .cycle_shape2shape import ShapeToShapeCycleLoss
from .embed import EmbeddingLoss
from .embed_utils import CseAnnotationsAccumulator
from .mask_or_segm import MaskOrSegmentationLoss
from .registry import DENSEPOSE_LOSS_REGISTRY
from .soft_embed import SoftEmbeddingLoss
from .utils import BilinearInterpolationHelper, LossDict, extract_packed_annotations_from_matches
@DENSEPOSE_LOSS_REGISTRY.register()
class DensePoseCseLoss:
    """
    DensePose CSE loss.

    For a batch of proposals matched to ground truth, combines:
      * a segmentation loss (``MaskOrSegmentationLoss``) weighted by ``w_segm``;
      * one embedding loss per mesh (as returned by the configured embedding
        loss), each weighted by ``w_embed``;
      * optionally, a shape-to-shape cycle loss and a pixel-to-shape cycle loss,
        each enabled and weighted via the ``CSE`` config node.

    When there are no proposals or no packed annotations, placeholder losses
    with the same keys are produced by ``produce_fake_losses`` instead.
    """

    # Local name -> class mapping for the embedding loss. Keys are the class
    # names, so EMBED_LOSS_NAME in the config must match a class name exactly.
    _EMBED_LOSS_REGISTRY = {
        EmbeddingLoss.__name__: EmbeddingLoss,
        SoftEmbeddingLoss.__name__: SoftEmbeddingLoss,
    }

    def __init__(self, cfg: CfgNode):
        """
        Initialize CSE loss from configuration options.

        Args:
            cfg (CfgNode): configuration options. Reads
                MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS (used as the segmentation
                loss weight), MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT,
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME and the
                SHAPE_TO_SHAPE_CYCLE_LOSS / PIX_TO_SHAPE_CYCLE_LOSS sub-nodes
                (ENABLED, WEIGHT).
        """
        self.w_segm = cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS
        self.w_embed = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT
        self.segm_loss = MaskOrSegmentationLoss(cfg)
        self.embed_loss = DensePoseCseLoss.create_embed_loss(cfg)
        # Cycle-loss weights/modules only exist when the corresponding loss is
        # enabled; they are only read behind the matching do_* flag.
        self.do_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_shape2shape:
            self.w_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.shape2shape_loss = ShapeToShapeCycleLoss(cfg)
        self.do_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_pix2shape:
            self.w_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.pix2shape_loss = PixToShapeCycleLoss(cfg)

    @classmethod
    def create_embed_loss(cls, cfg: CfgNode):
        """
        Instantiate the embedding loss named by
        MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME.

        Args:
            cfg (CfgNode): configuration options, passed on to the loss class.
        Returns:
            An instance of the selected embedding loss class.
        Raises:
            KeyError: if the configured name is not in ``_EMBED_LOSS_REGISTRY``.
        """
        # registry not used here, since embedding losses are currently local
        # and are not used anywhere else
        loss_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME
        try:
            loss_cls = cls._EMBED_LOSS_REGISTRY[loss_name]
        except KeyError:
            raise KeyError(
                f"Unknown CSE embedding loss {loss_name!r}; "
                f"available: {sorted(cls._EMBED_LOSS_REGISTRY)}"
            ) from None
        # Constructed outside the try so a KeyError raised by the loss's own
        # __init__ is not misreported as an unknown loss name.
        return loss_cls(cfg)

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        embedder: nn.Module,
    ) -> LossDict:
        """
        Compute the CSE losses.

        Args:
            proposals_with_gt (List[Instances]): proposals matched with ground truth.
            densepose_predictor_outputs: predictor outputs; must expose an
                ``embedding`` tensor with exactly 4 dims, whose last two are
                used as (H, W) — presumably (N, D, H, W), confirm with predictor.
            embedder (nn.Module): mesh vertex embedder passed to the losses.
        Returns:
            Dict with "loss_densepose_S", one "loss_densepose_E{meshid}" per
            mesh reported by the embedding loss, plus "loss_shape2shape" and
            "loss_pix2shape" when those losses are enabled. Falls back to
            ``produce_fake_losses`` when there are no proposals or no packed
            annotations.
        """
        if not proposals_with_gt:
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)

        accumulator = CseAnnotationsAccumulator()
        packed_annotations = extract_packed_annotations_from_matches(proposals_with_gt, accumulator)
        if packed_annotations is None:
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)

        h, w = densepose_predictor_outputs.embedding.shape[2:]
        interpolator = BilinearInterpolationHelper.from_matches(
            packed_annotations,
            (h, w),
        )

        meshid_to_embed_losses = self.embed_loss(
            proposals_with_gt,
            densepose_predictor_outputs,
            packed_annotations,
            interpolator,
            embedder,
        )
        embed_loss_dict = {
            f"loss_densepose_E{meshid}": self.w_embed * meshid_to_embed_losses[meshid]
            for meshid in meshid_to_embed_losses
        }
        all_loss_dict = {
            "loss_densepose_S": self.w_segm
            * self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.w_shape2shape * self.shape2shape_loss(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.w_pix2shape * self.pix2shape_loss(
                proposals_with_gt, densepose_predictor_outputs, packed_annotations, embedder
            )
        return all_loss_dict

    def produce_fake_losses(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> LossDict:
        """
        Produce placeholder losses with the same keys as ``__call__``.

        Used when there is nothing to supervise. The values come from each
        loss's ``fake_value``/``fake_values`` and, unlike ``__call__``, are not
        multiplied by the configured weights. Presumably they are zero-valued
        terms that keep parameters in the graph — NOTE(review): confirm in the
        individual loss implementations.

        Args:
            densepose_predictor_outputs: predictor outputs.
            embedder (nn.Module): mesh vertex embedder.
        Returns:
            Dict with the same keys that ``__call__`` would produce.
        """
        meshname_to_embed_losses = self.embed_loss.fake_values(
            densepose_predictor_outputs, embedder=embedder
        )
        embed_loss_dict = {
            f"loss_densepose_E{mesh_name}": meshname_to_embed_losses[mesh_name]
            for mesh_name in meshname_to_embed_losses
        }
        all_loss_dict = {
            "loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.shape2shape_loss.fake_value(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.pix2shape_loss.fake_value(
                densepose_predictor_outputs, embedder
            )
        return all_loss_dict