import collections
import copy
import random
from typing import Sequence, Union

import numpy as np
from mmdet.datasets.base_det_dataset import BaseDetDataset
from mmdet.datasets.base_video_dataset import BaseVideoDataset
from mmdet.registry import DATASETS, TRANSFORMS
from mmengine.dataset import BaseDataset, force_full_init

from .rsconcat_dataset import RandomSampleJointVideoConcatDataset


@DATASETS.register_module(force=True)
class SeqMultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup. For the augmentation pipeline of mixed image data,
    the `get_indexes` method needs to be provided to obtain the image
    indexes, and you can set `skip_flags` to change the pipeline running
    process. At the same time, we provide the `dynamic_scale` parameter
    to dynamically change the output image size.

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be mixed.
        pipeline (Sequence[dict]): Sequence of transform object or
            config dict to be composed.
        dynamic_scale (tuple[int], optional): The image scale can be changed
            dynamically. Default to None. It is deprecated.
        skip_type_keys (list[str], optional): Sequence of type string to
            be skip pipeline. Default to None.
        max_refetch (int): The maximum number of retry iterations for getting
            valid results from the pipeline. If the number of iterations is
            greater than `max_refetch`, but results is still None, then the
            iteration is terminated and raise the error. Default: 15.
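
    Example:
        A minimal, illustrative config sketch; the wrapped dataset config and
        the transform names are placeholders rather than requirements of this
        module::

            train_dataset = dict(
                type="SeqMultiImageMixDataset",
                dataset=dict(type="MyConcatDataset"),  # placeholder config
                pipeline=[
                    dict(type="SeqMosaic"),  # mix transform with get_indexes
                    dict(type="SeqMixUp"),
                ],
                max_refetch=15,
            )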
    """

    def __init__(
        self,
        dataset: Union[BaseDataset, dict],
        pipeline: Sequence[dict],
        skip_type_keys: Union[Sequence[str], None] = None,
        max_refetch: int = 15,
        lazy_init: bool = False,
    ) -> None:
        assert isinstance(pipeline, collections.abc.Sequence)
        if skip_type_keys is not None:
            assert all(
                [isinstance(skip_type_key, str) for skip_type_key in skip_type_keys]
            )
        self._skip_type_keys = skip_type_keys

        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                self.pipeline_types.append(transform["type"])
                transform = TRANSFORMS.build(transform)
                self.pipeline.append(transform)
            else:
                raise TypeError("each element of pipeline must be a config dict")

        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                "dataset should be a config dict or a `BaseDataset` "
                f"instance, but got {type(dataset)}"
            )

        self._metainfo = self.dataset.metainfo
        if hasattr(self.dataset, "flag"):
            self.flag = self.dataset.flag
        self.num_samples = len(self.dataset)
        self.max_refetch = max_refetch

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

        self.generate_indices()

    def generate_indices(self):
        """Build the flat list of sample indices over all wrapped datasets."""
        cat_datasets = self.dataset.datasets
        video_indices = []
        img_indices = []
        for dataset in cat_datasets:
            self.test_mode = dataset.test_mode
            assert (
                not self.test_mode
            ), "'ConcatDataset' should not exist in test mode"
            if isinstance(dataset, BaseVideoDataset):
                num_videos = len(dataset)
                for video_ind in range(num_videos):
                    video_indices.extend(
                        [
                            (video_ind, frame_ind)
                            for frame_ind in range(dataset.get_len_per_video(video_ind))
                        ]
                    )
            elif isinstance(dataset, BaseDetDataset):
                num_imgs = len(dataset)
                for img_ind in range(num_imgs):
                    img_indices.append(img_ind)
        # Interleave image and video samples so that mixed batches are easier
        # to inspect while debugging.
        def alternate_merge(list1, list2):
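            """Interleave two lists element by element, appending the tail of
            the longer list. For example, alternate_merge([i0, i1], [v0, v1, v2])
            returns [i0, v0, i1, v1, v2].
            """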
            # Create a new list to hold the merged elements
            merged_list = []

            # Get the length of the shorter list
            min_length = min(len(list1), len(list2))

            # Append elements alternately from both lists
            for i in range(min_length):
                merged_list.append(list1[i])
                merged_list.append(list2[i])

            # Append the remaining elements from the longer list
            if len(list1) > len(list2):
                merged_list.extend(list1[min_length:])
            else:
                merged_list.extend(list2[min_length:])

            return merged_list

        self.indices = alternate_merge(img_indices, video_indices)

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the multi-image-mixed dataset.

        Returns:
            dict: The meta information of multi-image-mixed dataset.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Loop to ``full_init`` each dataset."""
        if self._fully_initialized:
            return

        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def get_transform_indexes(self, transform, results, t_type="SeqMosaic"):
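        """Fetch auxiliary mix samples for ``transform`` and attach them to
        ``results`` under the key that corresponds to ``t_type``."""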
        num_samples = len(results["img_id"])
        for i in range(self.max_refetch):
            # Make sure the results passed the loading pipeline
            # of the original dataset is not None.
            indexes = transform.get_indexes(self.dataset)
            if not isinstance(indexes, collections.abc.Sequence):
                indexes = [indexes]
            mix_results = [copy.deepcopy(self.dataset[index]) for index in indexes]
            if None not in mix_results:
                if t_type == "SeqMosaic":
                    results["mosaic_mix_results"] = [mix_results] * num_samples
                elif t_type == "SeqMixUp":
                    results["mixup_mix_results"] = [mix_results] * num_samples
                elif t_type == "SeqCopyPaste":
                    results["copypaste_mix_results"] = [mix_results] * num_samples
                return results
        else:
            raise RuntimeError(
                "The loading pipeline of the original dataset always "
                "returns None. Please check the correctness of the dataset "
                "and its pipeline."
            )

    @force_full_init
    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
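        # Load the sample at ``idx``, then run every transform in the
        # mixed-image pipeline. Transforms exposing ``get_indexes`` first
        # fetch auxiliary mix samples; each stage retries up to
        # ``max_refetch`` times before raising.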
        while True:
            results = copy.deepcopy(self.dataset[idx])

            for (transform, transform_type) in zip(self.pipeline, self.pipeline_types):
                if (
                    self._skip_type_keys is not None
                    and transform_type in self._skip_type_keys
                ):
                    continue
                if transform_type == "MasaTransformBroadcaster":
                    for sub_transform in transform.transforms:
                        if hasattr(sub_transform, "get_indexes"):
                            sub_transform_type = type(sub_transform).__name__
                            results = self.get_transform_indexes(
                                sub_transform, results, sub_transform_type
                            )

                elif hasattr(transform, "get_indexes"):
                    for i in range(self.max_refetch):
                        # Make sure the results passed the loading pipeline
                        # of the original dataset is not None.
                        indexes = transform.get_indexes(self.dataset)
                        if not isinstance(indexes, collections.abc.Sequence):
                            indexes = [indexes]
                        mix_results = [
                            copy.deepcopy(self.dataset[index]) for index in indexes
                        ]
                        if None not in mix_results:
                            results["mix_results"] = mix_results
                            break
                    else:
                        raise RuntimeError(
                            "The loading pipeline of the original dataset "
                            "always returns None. Please check the correctness "
                            "of the dataset and its pipeline."
                        )

                for i in range(self.max_refetch):
                    # To confirm the results passed the training pipeline
                    # of the wrapper is not None.
                    try:
                        updated_results = transform(copy.deepcopy(results))
                    except Exception as e:
                        print(
                            "Error occurred while running pipeline",
                            f"{transform} with error: {e}",
                        )
                        # print('Empty instances due to augmentation, re-sampling...')
                        idx = self._rand_another(idx)
                        continue
                    if updated_results is not None:
                        results = updated_results
                        break
                else:
                    raise RuntimeError(
                        "The training pipeline of the dataset wrapper "
                        "always returns None. Please check the correctness "
                        "of the dataset and its pipeline."
                    )

                if "mosaic_mix_results" in results:
                    results.pop("mosaic_mix_results")

                if "mixup_mix_results" in results:
                    results.pop("mixup_mix_results")

                if "copypaste_mix_results" in results:
                    results.pop("copypaste_mix_results")

            return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of transform type
                strings whose transforms are skipped in the pipeline.
        """
        assert all([isinstance(skip_type_key, str) for skip_type_key in skip_type_keys])
        self._skip_type_keys = skip_type_keys

    def _rand_another(self, idx):
        """Get another random sample index to retry with."""
        # ``self.indices`` mixes plain ints (image samples) with
        # (video_ind, frame_ind) tuples (video samples), which
        # ``np.random.choice`` cannot handle, so draw a random position
        # instead.
        return self.indices[np.random.randint(len(self.indices))]


@DATASETS.register_module()
class SeqRandomMultiImageVideoMixDataset(SeqMultiImageMixDataset):
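    """A mixed-image dataset wrapper that randomly alternates between a video
    pipeline and an image pipeline.

    For each fetched sample, a video sample processed by ``video_pipeline`` is
    drawn with probability ``video_sample_ratio``; otherwise an image sample
    is processed by the inherited image ``pipeline``. The wrapped dataset must
    be a ``RandomSampleJointVideoConcatDataset``.
    """
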
    def __init__(
        self, video_pipeline: Sequence[dict], video_sample_ratio=0.5, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.video_pipeline = []
        self.video_pipeline_types = []
        for transform in video_pipeline:
            if isinstance(transform, dict):
                self.video_pipeline_types.append(transform["type"])
                transform = TRANSFORMS.build(transform)
                self.video_pipeline.append(transform)
            else:
                raise TypeError("each element of video_pipeline must be a config dict")

        self.video_sample_ratio = video_sample_ratio
        assert isinstance(self.dataset, RandomSampleJointVideoConcatDataset)

    @force_full_init
    def get_transform_indexes(
        self, transform, results, sample_video, t_type="SeqMosaic"
    ):
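        """Fetch auxiliary mix samples from the video branch (index 0) or the
        image branch (index 1) of the wrapped joint dataset, depending on
        ``sample_video``."""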
        num_samples = len(results["img_id"])
        for i in range(self.max_refetch):
            # Make sure the results passed the loading pipeline
            # of the original dataset is not None.

            indexes = transform.get_indexes(self.dataset.datasets[0])
            if not isinstance(indexes, collections.abc.Sequence):
                indexes = [indexes]
            if sample_video:
                # The constant indexes 0 and 1 select the video and image
                # branches of ``RandomSampleJointVideoConcatDataset``, which
                # is expected to draw a fresh random sample on each access.
                mix_results = [copy.deepcopy(self.dataset[0]) for _ in indexes]
            else:
                mix_results = [copy.deepcopy(self.dataset[1]) for _ in indexes]

            if None not in mix_results:
                if t_type == "SeqMosaic":
                    results["mosaic_mix_results"] = [mix_results] * num_samples
                elif t_type == "SeqMixUp":
                    results["mixup_mix_results"] = [mix_results] * num_samples
                elif t_type == "SeqCopyPaste":
                    results["copypaste_mix_results"] = [mix_results] * num_samples
                return results
        else:
            raise RuntimeError(
                "The loading pipeline of the original dataset always "
                "returns None. Please check the correctness of the dataset "
                "and its pipeline."
            )

    def __getitem__(self, idx):

        while True:
            sample_video = random.random() < self.video_sample_ratio
            if sample_video:
                results = copy.deepcopy(self.dataset[0])
                pipeline = self.video_pipeline
                pipeline_types = self.video_pipeline_types
            else:
                results = copy.deepcopy(self.dataset[1])
                pipeline = self.pipeline
                pipeline_types = self.pipeline_types
                # if results['img_id'][0] != results['img_id'][1]:
                #     self.update_skip_type_keys(['SeqMosaic', 'SeqMixUp'])
                # else:
                #     self._skip_type_keys = None

            for transform, transform_type in zip(pipeline, pipeline_types):
                if (
                    self._skip_type_keys is not None
                    and transform_type in self._skip_type_keys
                ):
                    continue
                if transform_type == "MasaTransformBroadcaster":
                    for sub_transform in transform.transforms:
                        if hasattr(sub_transform, "get_indexes"):
                            sub_transform_type = type(sub_transform).__name__
                            results = self.get_transform_indexes(
                                sub_transform, results, sample_video, sub_transform_type
                            )

                elif hasattr(transform, "get_indexes"):
                    for i in range(self.max_refetch):
                        # Make sure the results passed the loading pipeline
                        # of the original dataset is not None.
                        indexes = transform.get_indexes(self.dataset)
                        if not isinstance(indexes, collections.abc.Sequence):
                            indexes = [indexes]
                        mix_results = [
                            copy.deepcopy(self.dataset[index]) for index in indexes
                        ]
                        if None not in mix_results:
                            results["mix_results"] = mix_results
                            break
                    else:
                        raise RuntimeError(
                            "The loading pipeline of the original dataset "
                            "always returns None. Please check the correctness "
                            "of the dataset and its pipeline."
                        )

                for i in range(self.max_refetch):
                    # To confirm the results passed the training pipeline
                    # of the wrapper is not None.
                    try:
                        updated_results = transform(copy.deepcopy(results))
                    except Exception as e:
                        print(
                            "Error occurred while running pipeline",
                            f"{transform} with error: {e}",
                        )
                        # print('Empty instances due to augmentation, re-sampling...')
                        # idx = self._rand_another(idx)
                        continue
                    if updated_results is not None:
                        results = updated_results
                        break
                else:
                    raise RuntimeError(
                        "The training pipeline of the dataset wrapper "
                        "always returns None. Please check the correctness "
                        "of the dataset and its pipeline."
                    )

                if "mosaic_mix_results" in results:
                    results.pop("mosaic_mix_results")

                if "mixup_mix_results" in results:
                    results.pop("mixup_mix_results")

                if "copypaste_mix_results" in results:
                    results.pop("copypaste_mix_results")

            return results
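

# A minimal, illustrative usage sketch (not part of the original module): the
# wrapped dataset config and transform names below are placeholder assumptions
# following the mmengine registry convention; a real config would reference
# actual annotation files.
if __name__ == "__main__":
    example_cfg = dict(
        type="SeqRandomMultiImageVideoMixDataset",
        dataset=dict(type="RandomSampleJointVideoConcatDataset"),  # placeholder
        pipeline=[dict(type="SeqMosaic"), dict(type="SeqMixUp")],  # image branch
        video_pipeline=[dict(type="SeqMosaic")],  # video branch
        video_sample_ratio=0.5,
    )
    # With a complete config, the wrapper would be built via the registry:
    #     mix_dataset = DATASETS.build(example_cfg)
    print(example_cfg)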