# Copyright (c) Facebook, Inc. and its affiliates.

from typing import Any, Dict, List, Tuple
import torch
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.structures import Instances

from densepose.converters.base import IntTupleBox
from densepose.data.utils import get_class_to_mesh_name_mapping
from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
from densepose.structures import DensePoseDataRelative

from .densepose_base import DensePoseBaseSampler

class DensePoseCSEBaseSampler(DensePoseBaseSampler):
    """
    Base DensePose sampler to produce DensePose data from DensePose predictions.
    Samples for each class are drawn according to some distribution over all pixels estimated
    to belong to that class.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        count_per_class: int = 8,
    ) -> None:
        """
        Constructor

        Args:
          cfg (CfgNode): the config of the model
          use_gt_categories (bool): if True, the instance category is read from
              `dataset_classes`, otherwise from `pred_classes`
          embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
          count_per_class (int): the sampler produces at most `count_per_class`
              samples for each category
        """
        # `_sample` reads `self.count_per_class`; it is forwarded to the base sampler,
        # which is expected to store it.
        super().__init__(count_per_class)
        self.embedder = embedder
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.use_gt_categories = use_gt_categories

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative from estimation results

        Args:
            instance (Instances): a single detection; must provide `pred_densepose` and
                either `dataset_classes` or `pred_classes` (see `use_gt_categories`)
            bbox_xywh (IntTupleBox): the corresponding bounding box
        Return:
            Dictionary with the keys `X_KEY`, `Y_KEY`, `VERTEX_IDS_KEY` (lists with one
            entry per sampled point) and `MESH_NAME_KEY` (the mesh name of the instance
            category). The lists are empty if the foreground mask contains no pixels.
        """
        if self.use_gt_categories:
            instance_class = instance.dataset_classes.tolist()[0]
        else:
            instance_class = instance.pred_classes.tolist()[0]
        mesh_name = self.class_to_mesh_name[instance_class]

        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.VERTEX_IDS_KEY: [],
            DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
        }

        mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
        # (row, column) coordinates of all foreground pixels
        indices = torch.nonzero(mask, as_tuple=True)
        # [D, H, W] -> [H, W, D], then keep foreground pixels only: [k, D]
        # NOTE(review): embeddings are moved to CPU here; this assumes the embedder
        # output can be compared against CPU tensors - confirm
        selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
        values = other_values[:, indices[0], indices[1]]
        k = values.shape[1]

        count = min(self.count_per_class, k)
        if count <= 0:
            return annotation

        index_sample = self._produce_index_sample(values, count)
        # distances between sampled pixel embeddings and mesh vertex embeddings;
        # the closest vertex is the one with the smallest distance
        closest_vertices = squared_euclidean_distance_matrix(
            selected_embeddings[index_sample], self.embedder(mesh_name)
        )
        closest_vertices = torch.argmin(closest_vertices, dim=1)

        # shift by half a pixel to address pixel centers
        sampled_y = indices[0][index_sample] + 0.5
        sampled_x = indices[1][index_sample] + 0.5
        # prepare / normalize data
        _, _, w, h = bbox_xywh
        x = (sampled_x / w * 256.0).cpu().tolist()
        y = (sampled_y / h * 256.0).cpu().tolist()
        # extend annotations
        annotation[DensePoseDataRelative.X_KEY].extend(x)
        annotation[DensePoseDataRelative.Y_KEY].extend(y)
        annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
        return annotation

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): an instance whose `pred_densepose` field provides
                `coarse_segm` and `embedding` tensors (presumably a
                `DensePoseEmbeddingPredictorOutput`)
            bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
            mask (torch.Tensor): shape [H, W], DensePose segmentation mask
            embeddings (torch.Tensor): a tensor of shape [D, H, W],
                DensePose CSE Embeddings
            other_values (torch.Tensor): a tensor of shape [0, H, W],
                for potential other values
        """
        densepose_output = instance.pred_densepose
        S = densepose_output.coarse_segm
        E = densepose_output.embedding
        _, _, w, h = bbox_xywh
        # resize to the bounding box size and keep the first batch entry only;
        # align_corners=False equals the default and matches `_resample_mask`
        embeddings = F.interpolate(E, size=(h, w), mode="bilinear", align_corners=False)[0]
        coarse_segm_resized = F.interpolate(
            S, size=(h, w), mode="bilinear", align_corners=False
        )[0]
        # foreground = pixels whose highest scoring channel is not channel 0
        mask = coarse_segm_resized.argmax(0) > 0
        other_values = torch.empty((0, h, w), device=E.device)
        return mask, embeddings, other_values

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - tensors of size
        (256, 256) and type `int64`.

        Args:
            output: DensePose predictor output with the following attributes:
             - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
               segmentation scores
        Return:
            Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
            where S = DensePoseDataRelative.MASK_SIZE
        """
        sz = DensePoseDataRelative.MASK_SIZE
        # resize scores, pick the best channel per pixel, drop singleton dimensions
        # and move the result to CPU
        mask = (
            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
            .argmax(dim=1)
            .long()
            .squeeze()
            .cpu()
        )
        return mask