| |
| import collections |
| import copy |
| from typing import List, Optional, Sequence, Union |
|
|
| from mmengine.dataset import ConcatDataset, force_full_init |
|
|
| from mmseg.registry import DATASETS, TRANSFORMS |
|
|
|
|
@DATASETS.register_module()
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup.

    Args:
        dataset (ConcatDataset or dict): The dataset to be mixed.
        pipeline (Sequence[dict]): Sequence of transform object or
            config dict to be composed.
        skip_type_keys (list[str], optional): Sequence of type string to
            be skipped in the pipeline. Default to None.
        lazy_init (bool): Whether to defer ``full_init`` of the wrapped
            dataset until first forced access. Default to False.
    """

    def __init__(self,
                 dataset: Union[ConcatDataset, dict],
                 pipeline: Sequence[dict],
                 skip_type_keys: Optional[List[str]] = None,
                 lazy_init: bool = False) -> None:
        assert isinstance(pipeline, collections.abc.Sequence)

        # Accept either a ready-made ConcatDataset or a config dict that
        # the registry can build into one.
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, ConcatDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`ConcatDataset` instance, but got {type(dataset)}')

        if skip_type_keys is not None:
            # Generator expression: no need to materialize a list for all().
            assert all(
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys)
        self._skip_type_keys = skip_type_keys

        # Build each transform and remember its registry type string so
        # __getitem__ can skip transforms by type at runtime.
        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                self.pipeline_types.append(transform['type'])
                transform = TRANSFORMS.build(transform)
                self.pipeline.append(transform)
            else:
                # NOTE(review): the class docstring allows "transform
                # object" entries, but only config dicts are accepted here
                # — confirm whether callable transforms should be allowed.
                raise TypeError('pipeline must be a dict')

        self._metainfo = self.dataset.metainfo
        # NOTE(review): len() on the wrapped dataset here may force its
        # full initialization even when lazy_init=True — verify against
        # mmengine's ConcatDataset.
        self.num_samples = len(self.dataset)

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the multi-image-mixed dataset.

        Returns:
            dict: The meta information of multi-image-mixed dataset.
        """
        # Deep copy so callers cannot mutate the wrapper's cached metainfo.
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Loop to ``full_init`` each dataset."""
        if self._fully_initialized:
            return

        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Deep copy so pipeline transforms can mutate results freely
        # without corrupting the wrapped dataset's cached sample.
        results = copy.deepcopy(self.dataset[idx])
        for (transform, transform_type) in zip(self.pipeline,
                                               self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            # Mix-type transforms (e.g. Mosaic/MixUp) expose get_indices
            # to request extra samples; provide them under 'mix_results'.
            if hasattr(transform, 'get_indices'):
                indices = transform.get_indices(self.dataset)
                if not isinstance(indices, collections.abc.Sequence):
                    indices = [indices]
                mix_results = [
                    copy.deepcopy(self.dataset[index]) for index in indices
                ]
                results['mix_results'] = mix_results

            results = transform(results)

            # Drop the auxiliary samples once consumed; pop with a default
            # avoids the redundant membership test.
            results.pop('mix_results', None)

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys.

        It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of type
                string to be skip pipeline.
        """
        assert all(
            isinstance(skip_type_key, str)
            for skip_type_key in skip_type_keys)
        self._skip_type_keys = skip_type_keys
|
|