| | |
| |
|
| | from typing import Any, List |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| | from detectron2.config import CfgNode |
| | from detectron2.structures import Instances |
| |
|
| | from densepose.data.meshes.catalog import MeshCatalog |
| | from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix |
| |
|
| | from .embed_utils import PackedCseAnnotations |
| | from .mask import extract_data_for_mask_loss_from_matches |
| |
|
| |
|
def _create_pixel_dist_matrix(grid_size: int) -> torch.Tensor:
    """
    Build the pairwise distance table for all pixels of a square grid.

    The `grid_size * grid_size` pixels are enumerated in row-major order and
    each one is represented by its integer `[row, col]` coordinates (as float).
    The coordinate table is then compared with itself by
    `squared_euclidean_distance_matrix`.

    Args:
        grid_size (int): side length of the square pixel grid
    Return:
        The result of `squared_euclidean_distance_matrix` applied to the
        coordinate table with itself - presumably a tensor of shape
        [grid_size * grid_size, grid_size * grid_size] that holds squared
        Euclidean distances between pixel pairs (TODO confirm against the helper)
    """
    coords = torch.arange(grid_size)
    # entry `i` of the coordinate table contains [row, col], where
    #   row = i // grid_size
    #   col = i % grid_size
    # The table is assembled explicitly instead of through `torch.meshgrid`:
    # calling `meshgrid` without the `indexing` argument triggers a deprecation
    # warning in recent PyTorch, while passing it fails on older releases.
    rows = coords.repeat_interleave(grid_size)
    cols = coords.repeat(grid_size)
    pix_coords = torch.stack((rows, cols), dim=-1).float()
    return squared_euclidean_distance_matrix(pix_coords, pix_coords)
| |
|
| |
|
def _sample_fg_pixels_randperm(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
    """
    Pick foreground pixels uniformly at random, without replacement.

    Args:
        fg_mask (torch.Tensor): foreground mask of any shape; every non-zero
            entry is treated as a foreground pixel
        sample_size (int): maximum number of pixels to return; a value <= 0
            disables sampling
    Return:
        1D tensor of indices into the flattened mask. All foreground indices
        (in ascending order) are returned when sampling is disabled or when
        there are at most `sample_size` foreground pixels; otherwise
        `sample_size` distinct foreground indices in random order.
    """
    fg_pixel_indices = fg_mask.reshape((-1,)).nonzero(as_tuple=True)[0]
    # Count the foreground pixels from the index tensor itself rather than from
    # `mask.sum()`: the sum only equals the count for strictly 0/1 masks, and a
    # larger value would make `randperm` yield out-of-range positions below.
    num_pixels = fg_pixel_indices.numel()
    if (sample_size <= 0) or (num_pixels <= sample_size):
        return fg_pixel_indices
    sample_indices = torch.randperm(num_pixels, device=fg_mask.device)[:sample_size]
    return fg_pixel_indices[sample_indices]
| |
|
| |
|
def _sample_fg_pixels_multinomial(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
    """
    Pick foreground pixels at random without replacement via `multinomial`.

    Args:
        fg_mask (torch.Tensor): foreground mask of any shape; every non-zero
            entry is treated as a foreground pixel
        sample_size (int): maximum number of pixels to return; a value <= 0
            disables sampling
    Return:
        1D tensor of indices into the flattened mask. All foreground indices
        (in ascending order) are returned when sampling is disabled or when
        there are at most `sample_size` foreground pixels; otherwise
        `sample_size` distinct foreground indices drawn with equal probability.
    """
    # Binarize first: raw mask values must act neither as a pixel count (their
    # sum only equals the count for 0/1 masks) nor as sampling weights.
    is_fg = fg_mask.reshape((-1,)) != 0
    num_pixels = int(is_fg.sum().item())
    if (sample_size <= 0) or (num_pixels <= sample_size):
        return is_fg.nonzero(as_tuple=True)[0]
    # zero weight for background, equal weight for every foreground pixel
    return is_fg.float().multinomial(sample_size, replacement=False)
| |
|
| |
|
class PixToShapeCycleLoss(nn.Module):
    """
    Cycle loss for pixel-vertex correspondence

    For every instance and every considered mesh, a sample of foreground
    pixels is softly assigned to mesh vertices and the vertices are softly
    assigned back to the sampled pixels (temperature-scaled softmax over the
    embedding similarity matrix in both directions). The resulting
    pixel -> vertex -> pixel matrix is weighted elementwise by the precomputed
    pixel distance table and reduced with a `norm_p` norm.
    """

    def __init__(self, cfg: CfgNode):
        """
        Read loss settings from the config and precompute the pixel distance table.

        Args:
            cfg (CfgNode): configuration; values are read from
                `MODEL.ROI_DENSEPOSE_HEAD` (`HEATMAP_SIZE`), its `CSE` node
                (`EMBEDDERS`, `EMBED_SIZE`) and `CSE.PIX_TO_SHAPE_CYCLE_LOSS`
        """
        super().__init__()
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        cse_cfg = head_cfg.CSE
        loss_cfg = cse_cfg.PIX_TO_SHAPE_CYCLE_LOSS
        self.shape_names = list(cse_cfg.EMBEDDERS.keys())
        self.embed_size = cse_cfg.EMBED_SIZE
        self.norm_p = loss_cfg.NORM_P
        self.use_all_meshes_not_gt_only = loss_cfg.USE_ALL_MESHES_NOT_GT_ONLY
        self.num_pixels_to_sample = loss_cfg.NUM_PIXELS_TO_SAMPLE
        # NOTE: stored for completeness, not read anywhere else in this class
        self.pix_sigma = loss_cfg.PIXEL_SIGMA
        self.temperature_pix_to_vertex = loss_cfg.TEMPERATURE_PIXEL_TO_VERTEX
        self.temperature_vertex_to_pix = loss_cfg.TEMPERATURE_VERTEX_TO_PIXEL
        # plain tensor attribute (not a registered buffer); `forward` moves it
        # to the device of the embeddings when needed
        self.pixel_dists = _create_pixel_dist_matrix(head_cfg.HEATMAP_SIZE)

    def forward(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        embedder: nn.Module,
    ):
        """
        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
                * coarse_segm - passed to `extract_data_for_mask_loss_from_matches`
                  together with the proposals to obtain GT masks
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Scalar tensor: mean of the cycle losses over all (instance, mesh)
            pairs, or a zero attached to the pixel embeddings if there are none
        """
        pix_embeds = densepose_predictor_outputs.embedding
        if self.pixel_dists.device != pix_embeds.device:
            # expected to happen at most once, on the first call for a device
            self.pixel_dists = self.pixel_dists.to(device=pix_embeds.device)
        with torch.no_grad():
            mask_loss_data = extract_data_for_mask_loss_from_matches(
                proposals_with_gt, densepose_predictor_outputs.coarse_segm
            )
        # GT masks, one per instance, converted to int64
        masks_gt = mask_loss_data.masks_gt.long()
        assert len(pix_embeds) == len(masks_gt), (
            f"Number of instances with embeddings {len(pix_embeds)} != "
            f"number of instances with GT masks {len(masks_gt)}"
        )
        losses = []
        mesh_names = (
            self.shape_names
            if self.use_all_meshes_not_gt_only
            else [
                MeshCatalog.get_mesh_name(mesh_id.item())
                for mesh_id in packed_annotations.vertex_mesh_ids_gt.unique()
            ]
        )
        for pixel_embeddings, mask_gt in zip(pix_embeds, masks_gt):
            # one column of embedding values per pixel of the flattened grid;
            # does not depend on the mesh, so computed once per instance
            pixel_embeddings_flattened = pixel_embeddings.reshape((self.embed_size, -1))
            for mesh_name in mesh_names:
                mesh_vertex_embeddings = embedder(mesh_name)
                # indices of the sampled foreground pixels in the flattened grid
                # (a new random sample is drawn for every mesh)
                pixel_indices_flattened = _sample_fg_pixels_randperm(
                    mask_gt, self.num_pixels_to_sample
                )
                # distances between all pairs of sampled pixels; broadcasting a
                # column against a row of indices selects the same submatrix as
                # indexing with `torch.meshgrid(idx, idx)`, without the
                # deprecation warning of `meshgrid` called without `indexing`
                pixel_dists = self.pixel_dists[
                    pixel_indices_flattened[:, None], pixel_indices_flattened[None, :]
                ]
                # normalized embeddings of the sampled pixels, one row per pixel
                pixel_embeddings_sampled = normalize_embeddings(
                    pixel_embeddings_flattened[:, pixel_indices_flattened].T
                )
                # similarity of every sampled pixel to every mesh vertex
                sim_matrix = pixel_embeddings_sampled.mm(mesh_vertex_embeddings.T)
                # soft assignment pixel -> vertex and vertex -> pixel
                c_pix_vertex = F.softmax(sim_matrix / self.temperature_pix_to_vertex, dim=1)
                c_vertex_pix = F.softmax(sim_matrix.T / self.temperature_vertex_to_pix, dim=1)
                # pixel -> vertex -> pixel round trip
                c_cycle = c_pix_vertex.mm(c_vertex_pix)
                loss_cycle = torch.norm(pixel_dists * c_cycle, p=self.norm_p)
                losses.append(loss_cycle)

        if not losses:
            # nothing to average: return a zero that is still connected to the
            # pixel embeddings in the autograd graph
            return pix_embeds.sum() * 0
        return torch.stack(losses, dim=0).mean()

    def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module):
        """
        Produce a zero-valued scalar that depends on the vertex embeddings of
        every mesh in `embedder.mesh_names` and on the pixel embeddings.

        NOTE(review): presumably used when no GT data is available, so that
        all parameters still take part in the backward pass - confirm with callers.
        """
        losses = [embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names]
        losses.append(densepose_predictor_outputs.embedding.sum() * 0)
        return torch.mean(torch.stack(losses))
| |
|