Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| from dataclasses import dataclass | |
| from typing import Any, Optional | |
| import torch | |
| from detectron2.structures import BoxMode, Instances | |
| from .utils import AnnotationsAccumulator | |
@dataclass
class PackedCseAnnotations:
    """
    Container for CSE annotations of several instances packed into single tensors.

    Declared as a dataclass so that it can be constructed with keyword
    arguments (one per field), which is how the accumulator's `pack` builds it.
    """

    # X coordinates of all annotated points, concatenated over instances
    x_gt: torch.Tensor
    # Y coordinates of all annotated points, concatenated over instances
    y_gt: torch.Tensor
    # stacked coarse segmentations, one per instance; None if at least one
    # instance had no segmentation annotation
    coarse_segm_gt: Optional[torch.Tensor]
    # for every annotated point, the ID of the mesh its vertex belongs to
    vertex_mesh_ids_gt: torch.Tensor
    # for every annotated point, the ID of the corresponding mesh vertex
    vertex_ids_gt: torch.Tensor
    # GT boxes (XYWH) of instances that have point annotations, one row per instance
    bbox_xywh_gt: torch.Tensor
    # estimated boxes (XYWH) of instances that have point annotations, one row per instance
    bbox_xywh_est: torch.Tensor
    # for every annotated point, the index of its box counted only among
    # boxes that have point annotations
    point_bbox_with_dp_indices: torch.Tensor
    # for every annotated point, the index of its box counted among all boxes
    point_bbox_indices: torch.Tensor
    # for every box with point annotations, its index counted among all boxes
    bbox_indices: torch.Tensor
class CseAnnotationsAccumulator(AnnotationsAccumulator):
    """
    Collects CSE ground truth image by image (one batch of matched
    proposal / GT boxes per image) and can merge everything collected so far
    into single tensors via `pack`.
    """

    def __init__(self):
        # per-instance pieces; each list gets one entry per instance that
        # carries point annotations
        self.x_gt = []
        self.y_gt = []
        self.s_gt = []
        self.vertex_mesh_ids_gt = []
        self.vertex_ids_gt = []
        self.bbox_xywh_gt = []
        self.bbox_xywh_est = []
        self.point_bbox_with_dp_indices = []
        self.point_bbox_indices = []
        self.bbox_indices = []
        # running counters: boxes with point annotations / all boxes seen
        self.nxt_bbox_with_dp_index = 0
        self.nxt_bbox_index = 0

    def accumulate(self, instances_one_image: Instances):
        """
        Add the matched boxes and DensePose GT of a single image

        Args:
            instances_one_image (Instances): instances of one image; read
                fields are `proposal_boxes`, `gt_boxes` and, if present,
                `gt_densepose`
        """
        est_boxes = BoxMode.convert(
            instances_one_image.proposal_boxes.tensor.clone(),
            BoxMode.XYXY_ABS,
            BoxMode.XYWH_ABS,
        )
        gt_boxes = BoxMode.convert(
            instances_one_image.gt_boxes.tensor.clone(),
            BoxMode.XYXY_ABS,
            BoxMode.XYWH_ABS,
        )
        n_matches = len(gt_boxes)
        n_proposals = len(est_boxes)
        assert (
            n_matches == n_proposals
        ), f"Got {n_proposals} proposal boxes and {n_matches} GT boxes"
        if n_matches == 0:
            # nothing was matched on this image
            return
        densepose_gt = getattr(instances_one_image, "gt_densepose", None)
        if densepose_gt is None:
            # boxes without any densepose GT still advance the global box counter
            self.nxt_bbox_index += n_matches
            return
        for est_box, gt_box, dp_gt in zip(est_boxes, gt_boxes, densepose_gt):
            has_points = dp_gt is not None and len(dp_gt.x) > 0
            if has_points:
                # pyre-fixme[6]: For 1st argument expected `Tensor` but got `float`.
                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`.
                self._do_accumulate(gt_box, est_box, dp_gt)
            # every box is counted, whether or not it had point annotations
            self.nxt_bbox_index += 1

    def _do_accumulate(self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: Any):
        """
        Store the data of one instance that is known to have point annotations

        Args:
            box_xywh_gt (tensor): GT bounding box
            box_xywh_est (tensor): estimated bounding box
            dp_gt: GT densepose data with the following attributes:
                - x: normalized X coordinates
                - y: normalized Y coordinates
                - segm: tensor of size [S, S] with coarse segmentation
                  (optional, skipped if the attribute is missing)
                - vertex_ids: tensor with mesh vertex IDs of the points
                - mesh_id: ID of the mesh, replicated for every point
        """
        self.x_gt.append(dp_gt.x)
        self.y_gt.append(dp_gt.y)
        if hasattr(dp_gt, "segm"):
            # add a leading dimension so that instances can be concatenated
            self.s_gt.append(dp_gt.segm.unsqueeze(0))
        point_vertex_ids = dp_gt.vertex_ids
        self.vertex_ids_gt.append(point_vertex_ids)
        self.vertex_mesh_ids_gt.append(torch.full_like(point_vertex_ids, dp_gt.mesh_id))
        self.bbox_xywh_gt.append(box_xywh_gt.view(-1, 4))
        self.bbox_xywh_est.append(box_xywh_est.view(-1, 4))
        # per-point box indices: among annotated boxes and among all boxes
        dp_box_index = self.nxt_bbox_with_dp_index
        box_index = self.nxt_bbox_index
        self.point_bbox_with_dp_indices.append(torch.full_like(point_vertex_ids, dp_box_index))
        self.point_bbox_indices.append(torch.full_like(point_vertex_ids, box_index))
        self.bbox_indices.append(box_index)
        self.nxt_bbox_with_dp_index = dp_box_index + 1

    def pack(self) -> Optional[PackedCseAnnotations]:
        """
        Merge everything accumulated so far into single tensors

        Returns:
            packed annotations, or None if nothing has been accumulated
        """
        if not self.x_gt:
            # TODO:
            # returning proper empty annotations would require
            # creating empty tensors of appropriate shape and
            # type on an appropriate device;
            # we return None so far to indicate empty annotations
            return None
        # segmentation is kept only if every accumulated instance provided it
        all_instances_have_segm = len(self.s_gt) == len(self.bbox_xywh_gt)
        return PackedCseAnnotations(
            x_gt=torch.cat(self.x_gt, dim=0),
            y_gt=torch.cat(self.y_gt, dim=0),
            vertex_mesh_ids_gt=torch.cat(self.vertex_mesh_ids_gt, dim=0),
            vertex_ids_gt=torch.cat(self.vertex_ids_gt, dim=0),
            coarse_segm_gt=torch.cat(self.s_gt, dim=0) if all_instances_have_segm else None,
            bbox_xywh_gt=torch.cat(self.bbox_xywh_gt, dim=0),
            bbox_xywh_est=torch.cat(self.bbox_xywh_est, dim=0),
            point_bbox_with_dp_indices=torch.cat(self.point_bbox_with_dp_indices, dim=0),
            point_bbox_indices=torch.cat(self.point_bbox_indices, dim=0),
            bbox_indices=torch.as_tensor(
                self.bbox_indices, dtype=torch.long, device=self.x_gt[0].device
            ),
        )