# Copyright (c) Facebook, Inc. and its affiliates. from dataclasses import dataclass from typing import Any, Optional, Tuple import torch @dataclass class DensePoseChartResult: """ DensePose results for chart-based methods represented by labels and inner coordinates (U, V) of individual charts. Each chart is a 2D manifold that has an associated label and is parameterized by two coordinates U and V. Both U and V take values in [0, 1]. Thus the results are represented by two tensors: - labels (tensor [H, W] of long): contains estimated label for each pixel of the detection bounding box of size (H, W) - uv (tensor [2, H, W] of float): contains estimated U and V coordinates for each pixel of the detection bounding box of size (H, W) """ labels: torch.Tensor uv: torch.Tensor def to(self, device: torch.device): """ Transfers all tensors to the given device """ labels = self.labels.to(device) uv = self.uv.to(device) return DensePoseChartResult(labels=labels, uv=uv) @dataclass class DensePoseChartResultWithConfidences: """ We add confidence values to DensePoseChartResult Thus the results are represented by two tensors: - labels (tensor [H, W] of long): contains estimated label for each pixel of the detection bounding box of size (H, W) - uv (tensor [2, H, W] of float): contains estimated U and V coordinates for each pixel of the detection bounding box of size (H, W) Plus one [H, W] tensor of float for each confidence type """ labels: torch.Tensor uv: torch.Tensor sigma_1: Optional[torch.Tensor] = None sigma_2: Optional[torch.Tensor] = None kappa_u: Optional[torch.Tensor] = None kappa_v: Optional[torch.Tensor] = None fine_segm_confidence: Optional[torch.Tensor] = None coarse_segm_confidence: Optional[torch.Tensor] = None def to(self, device: torch.device): """ Transfers all tensors to the given device, except if their value is None """ def to_device_if_tensor(var: Any): if isinstance(var, torch.Tensor): return var.to(device) return var return DensePoseChartResultWithConfidences( labels=self.labels.to(device), uv=self.uv.to(device), sigma_1=to_device_if_tensor(self.sigma_1), sigma_2=to_device_if_tensor(self.sigma_2), kappa_u=to_device_if_tensor(self.kappa_u), kappa_v=to_device_if_tensor(self.kappa_v), fine_segm_confidence=to_device_if_tensor(self.fine_segm_confidence), coarse_segm_confidence=to_device_if_tensor(self.coarse_segm_confidence), ) @dataclass class DensePoseChartResultQuantized: """ DensePose results for chart-based methods represented by labels and quantized inner coordinates (U, V) of individual charts. Each chart is a 2D manifold that has an associated label and is parameterized by two coordinates U and V. Both U and V take values in [0, 1]. Quantized coordinates Uq and Vq have uint8 values which are obtained as: Uq = U * 255 (hence 0 <= Uq <= 255) Vq = V * 255 (hence 0 <= Vq <= 255) Thus the results are represented by one tensor: - labels_uv_uint8 (tensor [3, H, W] of uint8): contains estimated label and quantized coordinates Uq and Vq for each pixel of the detection bounding box of size (H, W) """ labels_uv_uint8: torch.Tensor def to(self, device: torch.device): """ Transfers all tensors to the given device """ labels_uv_uint8 = self.labels_uv_uint8.to(device) return DensePoseChartResultQuantized(labels_uv_uint8=labels_uv_uint8) @dataclass class DensePoseChartResultCompressed: """ DensePose results for chart-based methods represented by a PNG-encoded string. The tensor of quantized DensePose results of size [3, H, W] is considered as an image with 3 color channels. PNG compression is applied and the result is stored as a Base64-encoded string. The following attributes are defined: - shape_chw (tuple of 3 int): contains shape of the result tensor (number of channels, height, width) - labels_uv_str (str): contains Base64-encoded results tensor of size [3, H, W] compressed with PNG compression methods """ shape_chw: Tuple[int, int, int] labels_uv_str: str def quantize_densepose_chart_result(result: DensePoseChartResult) -> DensePoseChartResultQuantized: """ Applies quantization to DensePose chart-based result. Args: result (DensePoseChartResult): DensePose chart-based result Return: Quantized DensePose chart-based result (DensePoseChartResultQuantized) """ h, w = result.labels.shape labels_uv_uint8 = torch.zeros([3, h, w], dtype=torch.uint8, device=result.labels.device) labels_uv_uint8[0] = result.labels labels_uv_uint8[1:] = (result.uv * 255).clamp(0, 255).byte() return DensePoseChartResultQuantized(labels_uv_uint8=labels_uv_uint8) def compress_quantized_densepose_chart_result( result: DensePoseChartResultQuantized, ) -> DensePoseChartResultCompressed: """ Compresses quantized DensePose chart-based result Args: result (DensePoseChartResultQuantized): quantized DensePose chart-based result Return: Compressed DensePose chart-based result (DensePoseChartResultCompressed) """ import base64 import numpy as np from io import BytesIO from PIL import Image labels_uv_uint8_np_chw = result.labels_uv_uint8.cpu().numpy() labels_uv_uint8_np_hwc = np.moveaxis(labels_uv_uint8_np_chw, 0, -1) im = Image.fromarray(labels_uv_uint8_np_hwc) fstream = BytesIO() im.save(fstream, format="png", optimize=True) labels_uv_str = base64.encodebytes(fstream.getvalue()).decode() shape_chw = labels_uv_uint8_np_chw.shape return DensePoseChartResultCompressed(labels_uv_str=labels_uv_str, shape_chw=shape_chw) def decompress_compressed_densepose_chart_result( result: DensePoseChartResultCompressed, ) -> DensePoseChartResultQuantized: """ Decompresses DensePose chart-based result encoded into a base64 string Args: result (DensePoseChartResultCompressed): compressed DensePose chart result Return: Quantized DensePose chart-based result (DensePoseChartResultQuantized) """ import base64 import numpy as np from io import BytesIO from PIL import Image fstream = BytesIO(base64.decodebytes(result.labels_uv_str.encode())) im = Image.open(fstream) labels_uv_uint8_np_chw = np.moveaxis(np.array(im, dtype=np.uint8), -1, 0) return DensePoseChartResultQuantized( labels_uv_uint8=torch.from_numpy(labels_uv_uint8_np_chw.reshape(result.shape_chw)) )