Spaces:
Runtime error
Runtime error
File size: 4,435 Bytes
2ae34e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import copy
from typing import List, Optional, Sequence, Union
from mmengine.dataset import ConcatDataset, force_full_init
from mmseg.registry import DATASETS, TRANSFORMS
@DATASETS.register_module()
class MultiImageMixDataset:
"""A wrapper of multiple images mixed dataset.
Suitable for training on multiple images mixed data augmentation like
mosaic and mixup.
Args:
dataset (ConcatDataset or dict): The dataset to be mixed.
pipeline (Sequence[dict]): Sequence of transform object or
config dict to be composed.
skip_type_keys (list[str], optional): Sequence of type string to
be skip pipeline. Default to None.
"""
def __init__(self,
dataset: Union[ConcatDataset, dict],
pipeline: Sequence[dict],
skip_type_keys: Optional[List[str]] = None,
lazy_init: bool = False) -> None:
assert isinstance(pipeline, collections.abc.Sequence)
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
elif isinstance(dataset, ConcatDataset):
self.dataset = dataset
else:
raise TypeError(
'elements in datasets sequence should be config or '
f'`ConcatDataset` instance, but got {type(dataset)}')
if skip_type_keys is not None:
assert all([
isinstance(skip_type_key, str)
for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys
self.pipeline = []
self.pipeline_types = []
for transform in pipeline:
if isinstance(transform, dict):
self.pipeline_types.append(transform['type'])
transform = TRANSFORMS.build(transform)
self.pipeline.append(transform)
else:
raise TypeError('pipeline must be a dict')
self._metainfo = self.dataset.metainfo
self.num_samples = len(self.dataset)
self._fully_initialized = False
if not lazy_init:
self.full_init()
@property
def metainfo(self) -> dict:
"""Get the meta information of the multi-image-mixed dataset.
Returns:
dict: The meta information of multi-image-mixed dataset.
"""
return copy.deepcopy(self._metainfo)
def full_init(self):
"""Loop to ``full_init`` each dataset."""
if self._fully_initialized:
return
self.dataset.full_init()
self._ori_len = len(self.dataset)
self._fully_initialized = True
@force_full_init
def get_data_info(self, idx: int) -> dict:
"""Get annotation by index.
Args:
idx (int): Global index of ``ConcatDataset``.
Returns:
dict: The idx-th annotation of the datasets.
"""
return self.dataset.get_data_info(idx)
@force_full_init
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
results = copy.deepcopy(self.dataset[idx])
for (transform, transform_type) in zip(self.pipeline,
self.pipeline_types):
if self._skip_type_keys is not None and \
transform_type in self._skip_type_keys:
continue
if hasattr(transform, 'get_indices'):
indices = transform.get_indices(self.dataset)
if not isinstance(indices, collections.abc.Sequence):
indices = [indices]
mix_results = [
copy.deepcopy(self.dataset[index]) for index in indices
]
results['mix_results'] = mix_results
results = transform(results)
if 'mix_results' in results:
results.pop('mix_results')
return results
def update_skip_type_keys(self, skip_type_keys):
"""Update skip_type_keys.
It is called by an external hook.
Args:
skip_type_keys (list[str], optional): Sequence of type
string to be skip pipeline.
"""
assert all([
isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys
|