# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Sequence, Union

from mmengine.dataset import BaseDataset, Compose
from mmengine.dataset import ConcatDataset as MMENGINE_CONCATDATASET

from mmocr.registry import DATASETS


@DATASETS.register_module()
class ConcatDataset(MMENGINE_CONCATDATASET):
    """A wrapper of concatenated dataset.

    Same as ``torch.utils.data.dataset.ConcatDataset``, but supports
    ``lazy_init``.

    Note:
        ``ConcatDataset`` should not inherit from ``BaseDataset``, since
        ``get_subset`` and ``get_subset_`` could produce a sub-dataset whose
        meaning is ambiguous and conflicts with the original datasets. If you
        want a sub-dataset of ``ConcatDataset``, set the ``indices`` argument
        on the wrapped datasets, which do inherit from ``BaseDataset``.

    Args:
        datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
            which will be concatenated.
        pipeline (list, optional): Processing pipeline to be applied to all
            of the concatenated datasets. Defaults to [].
        verify_meta (bool): Whether to verify the consistency of meta
            information of the concatenated datasets. Defaults to True.
        force_apply (bool): Whether to allow ``pipeline`` to override a
            non-empty pipeline already configured on a wrapped dataset.
            Defaults to False.
        lazy_init (bool, optional): Whether to defer loading annotations until
            ``full_init`` is called explicitly. Defaults to False.
    """

    def __init__(self,
                 datasets: Sequence[Union[BaseDataset, dict]],
                 pipeline: List[Union[dict, Callable]] = [],
                 verify_meta: bool = True,
                 force_apply: bool = False,
                 lazy_init: bool = False):
        self.datasets: List[BaseDataset] = []

        # Build the shared processing pipeline once; it is applied to every
        # wrapped dataset below (when non-empty).
        pipeline = Compose(pipeline)

        for i, dataset in enumerate(datasets):
            if isinstance(dataset, dict):
                self.datasets.append(DATASETS.build(dataset))
            elif isinstance(dataset, BaseDataset):
                self.datasets.append(dataset)
            else:
                raise TypeError(
                    'elements in datasets sequence should be config or '
                    f'`BaseDataset` instance, but got {type(dataset)}')
            if len(pipeline.transforms) > 0:
                # Replace the wrapped dataset's pipeline with the shared one,
                # refusing to overwrite a non-empty pipeline unless
                # `force_apply` is set.
                if len(self.datasets[-1].pipeline.transforms
                       ) > 0 and not force_apply:
                    raise ValueError(
                        f'The pipeline of dataset {i} is not empty, '
                        'please set `force_apply` to True.')
                self.datasets[-1].pipeline = pipeline

        self._metainfo = self.datasets[0].metainfo

        if verify_meta:
            # Only use metainfo of first dataset.
            for i, dataset in enumerate(self.datasets, 1):
                if self._metainfo != dataset.metainfo:
                    raise ValueError(
                        f'The meta information of the {i}-th dataset does not '
                        'match meta information of the first dataset')

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()
            self._metainfo.update(dict(cumulative_sizes=self.cumulative_sizes))
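

# Usage sketch (illustrative only, not part of the upstream module). It shows
# how this wrapper is typically built through the registry from a config dict.
# The ``OCRDataset`` type, annotation file paths and transforms below are
# hypothetical placeholders.
#
#     concat_cfg = dict(
#         type='ConcatDataset',
#         datasets=[
#             dict(type='OCRDataset', ann_file='data/train_a.json',
#                  pipeline=[]),
#             dict(type='OCRDataset', ann_file='data/train_b.json',
#                  pipeline=[]),
#         ],
#         # Applied to every wrapped dataset; raises unless each wrapped
#         # dataset's own pipeline is empty or `force_apply=True`.
#         pipeline=[dict(type='LoadImageFromFile')],
#     )
#     dataset = DATASETS.build(concat_cfg)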