# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
from mmengine.dataset import BaseDataset
from mmengine.registry import build_from_cfg

from mmpose.registry import DATASETS
from .datasets.utils import parse_pose_metainfo


class CombinedDataset(BaseDataset):
| """A wrapper of combined dataset. | |
| Args: | |
| metainfo (dict): The meta information of combined dataset. | |
| datasets (list): The configs of datasets to be combined. | |
| pipeline (list, optional): Processing pipeline. Defaults to []. | |
| sample_ratio_factor (list, optional): A list of sampling ratio | |
| factors for each dataset. Defaults to None | |
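
    Example::

        >>> # A minimal usage sketch; ``coco_cfg`` and ``mpii_cfg`` are
        >>> # placeholder sub-dataset configs, not taken from this file.
        >>> dataset = CombinedDataset(
        ...     metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
        ...     datasets=[coco_cfg, mpii_cfg],
        ...     pipeline=[],
        ...     sample_ratio_factor=[1.0, 0.5])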
| """ | |

    def __init__(self,
                 metainfo: dict,
                 datasets: list,
                 pipeline: List[Union[dict, Callable]] = [],
                 sample_ratio_factor: Optional[List[float]] = None,
                 dataset_ratio_factor: Optional[List[float]] = None,
                 keypoints_mapping: Optional[List[Dict]] = None,
                 **kwargs):

        self.datasets = []
        self.resample = sample_ratio_factor is not None

        self.keypoints_mapping = keypoints_mapping
        self.num_joints = None
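        # The unified keypoint space must be large enough to hold the
        # highest target index of any mapping, e.g. a mapping entry {0: 5}
        # requires at least 6 joints.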
        if self.keypoints_mapping is not None:
            self.num_joints = 0
            for mapping in self.keypoints_mapping:
                self.num_joints = max(self.num_joints,
                                      max(mapping.values()) + 1)

        for cfg in datasets:
            dataset = build_from_cfg(cfg, DATASETS)
            self.datasets.append(dataset)

        # For each dataset, keep a random subset whose size is determined
        # by the corresponding entry of ``dataset_ratio_factor``.
        if dataset_ratio_factor is not None:
            for i, dataset in enumerate(self.datasets):
                dataset_len = len(dataset)
                random_subset = np.random.choice(
                    dataset_len,
                    int(dataset_len * dataset_ratio_factor[i]),
                    replace=False,
                )
                self.datasets[i] = dataset.get_subset(
                    random_subset.flatten().tolist())

        self._lens = [len(dataset) for dataset in self.datasets]
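
        # ``sample_ratio_factor`` rescales the *virtual* length of each
        # sub-dataset; virtual indices are mapped back to real samples in
        # ``_get_subset_index``.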
        if self.resample:
            assert len(sample_ratio_factor) == len(datasets), f'the length ' \
                f'of `sample_ratio_factor` {len(sample_ratio_factor)} does ' \
                f'not match the length of `datasets` {len(datasets)}'
            assert min(sample_ratio_factor) >= 0.0, 'the ratio values in ' \
                '`sample_ratio_factor` should not be negative.'
            self._lens_ori = self._lens
            self._lens = [
                round(l * sample_ratio_factor[i])
                for i, l in enumerate(self._lens_ori)
            ]

        self._len = sum(self._lens)

        super(CombinedDataset, self).__init__(pipeline=pipeline, **kwargs)
        self._metainfo = parse_pose_metainfo(metainfo)

        print('CombinedDataset initialized\n\tlen: {}\n\tlens: {}'.format(
            self._len, self._lens))

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the combined dataset."""
        return deepcopy(self._metainfo)

    def __len__(self) -> int:
        return self._len

    def _get_subset_index(self, index: int) -> Tuple[int, int]:
        """Given a data sample's global index, return the index of the
        sub-dataset the data sample belongs to, and the local index within
        that sub-dataset.

        Args:
            index (int): The global data sample index.

        Returns:
            tuple[int, int]:
            - subset_index (int): The index of the sub-dataset
            - local_index (int): The index of the data sample within
              the sub-dataset
        """
        if index >= len(self) or index < -len(self):
            raise ValueError(
                f'index({index}) is out of bounds for dataset with '
                f'length({len(self)}).')

        if index < 0:
            index = index + len(self)

        subset_index = 0
        while index >= self._lens[subset_index]:
            index -= self._lens[subset_index]
            subset_index += 1
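
        # When resampling, map the virtual index back to a real one by
        # scaling with ``gap = len_ori / len_virtual`` and jittering
        # uniformly within the corresponding window, so every real sample
        # remains reachable.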
        if self.resample:
            gap = (self._lens_ori[subset_index] -
                   1e-4) / self._lens[subset_index]
            index = round(gap * index + np.random.rand() * gap - 0.5)

        return subset_index, index

    def prepare_data(self, idx: int) -> Any:
        """Get data processed by ``self.pipeline``. The source dataset is
        determined by the index.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        data_info = self.get_data_info(idx)

        # The assignment of 'dataset' should not be performed within the
        # `get_data_info` function. Otherwise, it can lead to the mixed
        # data augmentation process getting stuck.
        data_info['dataset'] = self

        return self.pipeline(data_info)

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``CombinedDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        subset_idx, sample_idx = self._get_subset_index(idx)
        # Get data sample processed by ``subset.pipeline``
        data_info = self.datasets[subset_idx][sample_idx]
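
        # Drop any 'dataset' reference carried over from the sub-dataset so
        # that ``prepare_data`` can attach this ``CombinedDataset`` instead.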
        if 'dataset' in data_info:
            data_info.pop('dataset')

        # Add metainfo items that are required in the pipeline and the model
        metainfo_keys = [
            'upper_body_ids', 'lower_body_ids', 'flip_pairs',
            'dataset_keypoint_weights', 'flip_indices'
        ]
        for key in metainfo_keys:
            data_info[key] = deepcopy(self._metainfo[key])

        # Map keypoints into the unified keypoint space defined by
        # ``keypoints_mapping``
        if self.keypoints_mapping is not None:
            mapping = self.keypoints_mapping[subset_idx]
            keypoints = data_info['keypoints']
            N, K, D = keypoints.shape
            keypoints_visibility = data_info.get('keypoints_visibility',
                                                 np.zeros((N, K)))
            keypoints_visible = data_info.get('keypoints_visible',
                                              np.zeros((N, K)))
            # Allocate target arrays in the unified layout, keeping the
            # source coordinate dim D rather than hard-coding 2.
            mapped_keypoints = np.zeros((N, self.num_joints, D))
            mapped_visibility = np.zeros((N, self.num_joints))
            mapped_visible = np.zeros((N, self.num_joints))
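            # ``map_idx`` is an (M, 2) array of (source_index, target_index)
            # pairs; fancy indexing then scatters the source keypoints into
            # the unified layout in one vectorized step.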
            map_idx = np.stack(
                [list(mapping.keys()), list(mapping.values())], axis=1)
            mapped_keypoints[:, map_idx[:, 1], :] = \
                keypoints[:, map_idx[:, 0], :]
            mapped_visibility[:, map_idx[:, 1]] = \
                keypoints_visibility[:, map_idx[:, 0]]
            mapped_visible[:, map_idx[:, 1]] = \
                keypoints_visible[:, map_idx[:, 0]]

            data_info['keypoints'] = mapped_keypoints
            data_info['keypoints_visibility'] = mapped_visibility
            data_info['keypoints_visible'] = mapped_visible

        return data_info

    def full_init(self):
        """Fully initialize all sub-datasets."""
        if self._fully_initialized:
            return
        for dataset in self.datasets:
            dataset.full_init()
        self._fully_initialized = True
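

if __name__ == '__main__':
    # Minimal, self-contained sketch of the keypoint-mapping step performed
    # in ``get_data_info``, using synthetic data only (no sub-dataset
    # configs are built; run via ``python -m`` so the relative imports
    # resolve). Source joints 0 and 1 land in unified slots 2 and 0
    # respectively; slot 1 stays zero-filled.
    mapping = {0: 2, 1: 0}
    num_joints = max(mapping.values()) + 1  # -> 3 unified joints

    keypoints = np.array([[[10., 20.], [30., 40.]]])  # shape (1, 2, 2)
    visible = np.ones((1, 2))

    map_idx = np.stack([list(mapping.keys()), list(mapping.values())], axis=1)
    mapped_kpts = np.zeros((1, num_joints, 2))
    mapped_vis = np.zeros((1, num_joints))
    mapped_kpts[:, map_idx[:, 1], :] = keypoints[:, map_idx[:, 0], :]
    mapped_vis[:, map_idx[:, 1]] = visible[:, map_idx[:, 0]]

    print(mapped_kpts)  # slot 2 holds (10., 20.), slot 0 holds (30., 40.)
    print(mapped_vis)   # [[1., 0., 1.]]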