| | |
| |
|
| | import random |
| | from typing import Optional, Tuple |
| | import torch |
| | from torch.nn import functional as F |
| |
|
| | from detectron2.config import CfgNode |
| | from detectron2.structures import Instances |
| |
|
| | from densepose.converters.base import IntTupleBox |
| |
|
| | from .densepose_cse_base import DensePoseCSEBaseSampler |
| |
|
| |
|
class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler):
    """
    Samples DensePose data from DensePose predictions.
    Samples for each class are drawn using confidence value estimates.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        confidence_channel: str,
        count_per_class: int = 8,
        search_count_multiplier: Optional[float] = None,
        search_proportion: Optional[float] = None,
    ):
        """
        Constructor

        Args:
            cfg (CfgNode): the config of the model
            use_gt_categories (bool): forwarded unchanged to
                `DensePoseCSEBaseSampler.__init__`
            embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
            confidence_channel (str): name of the attribute of
                `instance.pred_densepose` holding the confidences used for
                sampling (required, no default); possible values:
                    "coarse_segm_confidence": confidences for coarse segmentation
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category (default: 8)
            search_count_multiplier (float or None): if not None, the total number
                of the most confident estimates of a given class to consider is
                defined as `min(max(search_count_multiplier * count, count), N)`,
                where `count` is the number of samples requested for the class
                and `N` is the total number of estimates of the class; cannot be
                specified together with `search_proportion` (default: None)
            search_proportion (float or None): if not None, the total number of
                the most confident estimates of a given class to consider is
                defined as `min(max(search_proportion * N, count), N)`,
                where `count` is the number of samples requested for the class
                and `N` is the total number of estimates of the class; cannot be
                specified together with `search_count_multiplier` (default: None)
        """
        super().__init__(cfg, use_gt_categories, embedder, count_per_class)
        self.confidence_channel = confidence_channel
        self.search_count_multiplier = search_count_multiplier
        self.search_proportion = search_proportion
        # The two search strategies are mutually exclusive.
        assert (search_count_multiplier is None) or (search_proportion is None), (
            f"Cannot specify both search_count_multiplier (={search_count_multiplier}) "
            f"and search_proportion (={search_proportion})"
        )

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Produce a sample of indices to select data based on confidences

        Args:
            values (torch.Tensor): a tensor with at least 2 dims whose axis 1 has
                length k (k: number of points labeled with part_id); only
                `values[0]` is used to rank the points
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int) or torch.Tensor: indices of values (along axis 1) selected
                as a sample; a plain list `[0, ..., k-1]` when `count == k`,
                otherwise a tensor of indices returned by `torch.sort`
        """
        k = values.shape[1]
        if k == count:
            # Every point is needed, so no ranking is required.
            index_sample = list(range(k))
        else:
            # Rank points by `values[0]` in ascending order, keep the
            # `search_count` points with the largest values and draw `count`
            # of them uniformly at random.
            _, sorted_confidence_indices = torch.sort(values[0])
            if self.search_count_multiplier is not None:
                # Clamp from below by `count`: a multiplier < 1 would otherwise
                # make the search pool smaller than the sample size and
                # `random.sample` would raise ValueError.
                search_count = min(
                    max(int(count * self.search_count_multiplier), count), k
                )
            elif self.search_proportion is not None:
                search_count = min(max(int(k * self.search_proportion), count), k)
            else:
                search_count = min(count, k)
            sample_from_top = random.sample(range(search_count), count)
            index_sample = sorted_confidence_indices[-search_count:][sample_from_top]
        return index_sample

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): an instance of
                `DensePoseEmbeddingPredictorOutputWithConfidences`
            bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
            mask (torch.Tensor): shape [H, W], DensePose segmentation mask
            embeddings (torch.Tensor): a tensor of shape [D, H, W],
                DensePose CSE Embeddings
            other_values (torch.Tensor): a tensor of shape [1, H, W], DensePose
                CSE confidence from `self.confidence_channel`, resized to the
                box size and moved to CPU
        """
        _, _, w, h = bbox_xywh
        densepose_output = instance.pred_densepose
        mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh)
        # Resize the confidence map to the box size (h, w). `[0]` drops the
        # leading axis; presumably the per-instance batch axis of size 1 —
        # TODO confirm.
        other_values = F.interpolate(
            getattr(densepose_output, self.confidence_channel),
            size=(h, w),
            mode="bilinear",
            # Explicit value of the bilinear default, to avoid relying on it.
            align_corners=False,
        )[0].cpu()
        return mask, embeddings, other_values
| |
|