| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Union |
| |
|
| | from mmengine.structures import BaseDataElement, InstanceData, PixelData |
| |
|
| | from mmpose.structures import MultilevelPixelData |
| |
|
| |
|
class PoseDataSample(BaseDataElement):
    """The base data structure of MMPose that is used as the interface between
    modules.

    The attributes of ``PoseDataSample`` include:

        - ``gt_instances``(InstanceData): Ground truth of instances with
          keypoint annotations
        - ``gt_instance_labels``(InstanceData): Instance-level ground-truth
          labels, kept separately from ``gt_instances``
        - ``pred_instances``(InstanceData): Instances with keypoint
          predictions
        - ``gt_fields``(PixelData | MultilevelPixelData): Ground truth of
          spatial distribution annotations like keypoint heatmaps and part
          affinity fields (PAF)
        - ``pred_fields``(PixelData): Predictions of spatial distributions

    Each attribute is exposed as a property whose value is stored in a
    private field named ``_<attribute>`` (e.g. ``gt_instances`` is stored
    as ``_gt_instances``).

    Examples:
        >>> import torch
        >>> from mmengine.structures import InstanceData, PixelData
        >>> from mmpose.structures import PoseDataSample

        >>> pose_meta = dict(img_shape=(800, 1216),
        ...                  crop_size=(256, 192),
        ...                  heatmap_size=(64, 48))
        >>> gt_instances = InstanceData()
        >>> gt_instances.bboxes = torch.rand((1, 4))
        >>> gt_instances.keypoints = torch.rand((1, 17, 2))
        >>> gt_instances.keypoints_visible = torch.rand((1, 17, 1))
        >>> gt_fields = PixelData()
        >>> gt_fields.heatmaps = torch.rand((17, 64, 48))

        >>> data_sample = PoseDataSample(gt_instances=gt_instances,
        ...                              gt_fields=gt_fields,
        ...                              metainfo=pose_meta)
        >>> assert 'img_shape' in data_sample
        >>> len(data_sample.gt_instances)
        1
    """

    @property
    def gt_instances(self) -> InstanceData:
        """Ground-truth instances with keypoint annotations."""
        return self._gt_instances

    @gt_instances.setter
    def gt_instances(self, value: InstanceData):
        self.set_field(value, '_gt_instances', dtype=InstanceData)

    @gt_instances.deleter
    def gt_instances(self):
        del self._gt_instances

    @property
    def gt_instance_labels(self) -> InstanceData:
        """Instance-level ground-truth labels."""
        return self._gt_instance_labels

    @gt_instance_labels.setter
    def gt_instance_labels(self, value: InstanceData):
        self.set_field(value, '_gt_instance_labels', dtype=InstanceData)

    @gt_instance_labels.deleter
    def gt_instance_labels(self):
        del self._gt_instance_labels

    @property
    def pred_instances(self) -> InstanceData:
        """Predicted instances with keypoint predictions."""
        return self._pred_instances

    @pred_instances.setter
    def pred_instances(self, value: InstanceData):
        self.set_field(value, '_pred_instances', dtype=InstanceData)

    @pred_instances.deleter
    def pred_instances(self):
        del self._pred_instances

    @property
    def gt_fields(self) -> Union[PixelData, MultilevelPixelData]:
        """Ground-truth spatial fields (single- or multi-level)."""
        return self._gt_fields

    @gt_fields.setter
    def gt_fields(self, value: Union[PixelData, MultilevelPixelData]):
        # NOTE(review): ``dtype=type(value)`` always matches ``value`` itself,
        # so no type validation happens here; this lets both PixelData and
        # MultilevelPixelData (and anything else) through. Confirm this
        # leniency is intended before tightening it.
        self.set_field(value, '_gt_fields', dtype=type(value))

    @gt_fields.deleter
    def gt_fields(self):
        del self._gt_fields

    @property
    def pred_fields(self) -> PixelData:
        """Predicted spatial fields."""
        return self._pred_fields

    @pred_fields.setter
    def pred_fields(self, value: PixelData):
        # The backing field mirrors the property name (``_pred_fields``), like
        # every other property here. BaseDataElement.keys() only hides a
        # private ``_<name>`` field when ``<name>`` is a property, so a
        # mismatched backing name (formerly ``_pred_heatmaps``) would leak
        # into keys()/items()/to_dict() as a duplicate entry.
        self.set_field(value, '_pred_fields', dtype=PixelData)

    @pred_fields.deleter
    def pred_fields(self):
        del self._pred_fields
| |
|