# Copyright (c) OpenMMLab. All rights reserved.
import math
import random
from typing import Iterator, Optional, Sized

import numpy as np
from mmdet.datasets.base_det_dataset import BaseDetDataset
from mmdet.datasets.base_video_dataset import BaseVideoDataset
from mmdet.registry import DATA_SAMPLERS
from mmengine.dataset import ClassBalancedDataset, ConcatDataset
from mmengine.dist import get_dist_info, sync_random_seed
from torch.utils.data import Sampler
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from ..dataset_wrappers import SeqMultiImageMixDataset


@DATA_SAMPLERS.register_module()
class HybridVideoImgSampler(Sampler):
    """Sampler that providing image-level sampling outputs for video datasets
    in tracking tasks. It could be both used in both distributed and
    non-distributed environment.
    If using the default sampler in pytorch, the subsequent data receiver will
    get one video, which is not desired in some cases:
    (Take a non-distributed environment as an example)
    1. In test mode, we want only one image is fed into the data pipeline. This
    is in consideration of memory usage since feeding the whole video commonly
    requires a large amount of memory (>=20G on MOTChallenge17 dataset), which
    is not available in some machines.
    2. In training mode, we may want to make sure all the images in one video
    are randomly sampled once in one epoch and this can not be guaranteed in
    the default sampler in pytorch.

    Args:
        dataset (Sized): Dataset used for sampling.
        seed (int, optional): random seed used to shuffle the sampler. This
            number should be identical across all processes in the distributed
            group. Defaults to None.
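
    Examples:
        A minimal sketch, assuming ``dataset`` is an already-built
        ``BaseVideoDataset`` in training mode:

        >>> sampler = HybridVideoImgSampler(dataset, seed=0)
        >>> sampler.set_epoch(0)
        >>> for video_ind, frame_ind in sampler:
        ...     pass  # each index is a (video_ind, frame_ind) tuple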
    """

    def __init__(self, dataset: Sized, seed: Optional[int] = None) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size
        self.epoch = 0
        if seed is None:
            self.seed = sync_random_seed()
        else:
            self.seed = seed

        self.dataset = dataset
        self.indices = []
        # Handle each supported dataset wrapper explicitly. In training mode,
        # ``self.indices`` becomes a flat list of (video_ind, frame_ind)
        # tuples (with plain image indices mixed in for
        # SeqMultiImageMixDataset); in test mode, it becomes a list of
        # per-rank chunks of such tuples.
        if isinstance(self.dataset, ConcatDataset):
            cat_datasets = self.dataset.datasets
            assert isinstance(
                cat_datasets[0], BaseVideoDataset
            ), f"expected BaseVideoDataset, but got {type(cat_datasets[0])}"
            self.test_mode = cat_datasets[0].test_mode
            assert not self.test_mode, (
                "'ConcatDataset' should not be used in test mode"
            )
            for dataset in cat_datasets:
                num_videos = len(dataset)
                for video_ind in range(num_videos):
                    self.indices.extend(
                        [
                            (video_ind, frame_ind)
                            for frame_ind in range(dataset.get_len_per_video(video_ind))
                        ]
                    )
        elif isinstance(self.dataset, ClassBalancedDataset):
            ori_dataset = self.dataset.dataset
            assert isinstance(
                ori_dataset, BaseVideoDataset
            ), f"expected BaseVideoDataset, but got {type(ori_dataset)}"
            self.test_mode = ori_dataset.test_mode
            assert not self.test_mode, (
                "'ClassBalancedDataset' should not be used in test mode"
            )
            video_indices = self.dataset.repeat_indices
            for index in video_indices:
                self.indices.extend(
                    [
                        (index, frame_ind)
                        for frame_ind in range(ori_dataset.get_len_per_video(index))
                    ]
                )
        elif isinstance(self.dataset, BaseVideoDataset):
            self.test_mode = self.dataset.test_mode
            num_videos = len(self.dataset)

            if self.test_mode:
                # In test mode, the images belonging to the same video must be
                # put on the same device.
                if num_videos < self.world_size:
                    raise ValueError(
                        f"only {num_videos} videos loaded,"
                        f"but {self.world_size} gpus were given."
                    )
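                # np.array_split distributes the videos as evenly as
                # possible, e.g. 5 videos over 2 ranks -> [0, 1, 2], [3, 4].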
                chunks = np.array_split(list(range(num_videos)), self.world_size)
                for videos_inds in chunks:
                    indices_chunk = []
                    for video_ind in videos_inds:
                        indices_chunk.extend(
                            [
                                (video_ind, frame_ind)
                                for frame_ind in range(
                                    self.dataset.get_len_per_video(video_ind)
                                )
                            ]
                        )
                    self.indices.append(indices_chunk)
            else:
                for video_ind in range(num_videos):
                    self.indices.extend(
                        [
                            (video_ind, frame_ind)
                            for frame_ind in range(
                                self.dataset.get_len_per_video(video_ind)
                            )
                        ]
                    )
        else:
            assert isinstance(self.dataset, SeqMultiImageMixDataset), (
                "HybridVideoImgSampler only supports BaseVideoDataset or the "
                "dataset wrappers ClassBalancedDataset, ConcatDataset and "
                f"SeqMultiImageMixDataset, but got {type(self.dataset)}"
            )
            self.test_mode = self.dataset.test_mode
            if self.test_mode:
                raise NotImplementedError(
                    "SeqMultiImageMixDataset is not supported in test mode"
                )
            else:
                assert isinstance(self.dataset.dataset, _ConcatDataset), (
                    "SeqMultiImageMixDataset is expected to wrap a "
                    "'_ConcatDataset'"
                )
                cat_datasets = self.dataset.dataset.datasets
                # Collect (video_ind, frame_ind) tuples from video datasets
                # and plain image indices from detection datasets. Both lists
                # are initialised up front so that neither is undefined when
                # only one dataset type is present.
                video_indices = []
                img_indices = []
                for dataset in cat_datasets:
                    self.test_mode = dataset.test_mode
                    assert not self.test_mode, (
                        "'ConcatDataset' should not be used in test mode"
                    )
                    if isinstance(dataset, BaseVideoDataset):
                        for video_ind in range(len(dataset)):
                            video_indices.extend(
                                [
                                    (video_ind, frame_ind)
                                    for frame_ind in range(
                                        dataset.get_len_per_video(video_ind)
                                    )
                                ]
                            )
                    elif isinstance(dataset, BaseDetDataset):
                        img_indices.extend(range(len(dataset)))
                # Interleave image and video samples to make debugging easier.
                def alternate_merge(list1, list2):
                    # Create a new list to hold the merged elements
                    merged_list = []

                    # Get the length of the shorter list
                    min_length = min(len(list1), len(list2))

                    # Append elements alternately from both lists
                    for i in range(min_length):
                        merged_list.append(list1[i])
                        merged_list.append(list2[i])

                    # Append the remaining elements from the longer list
                    if len(list1) > len(list2):
                        merged_list.extend(list1[min_length:])
                    else:
                        merged_list.extend(list2[min_length:])

                    return merged_list

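                # e.g. alternate_merge([0, 1], [(0, 0), (0, 1), (1, 0)])
                # -> [0, (0, 0), 1, (0, 1), (1, 0)]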
                self.indices = alternate_merge(img_indices, video_indices)

        if self.test_mode:
            self.num_samples = len(self.indices[self.rank])
            self.total_size = sum([len(index_list) for index_list in self.indices])
        else:
            # Round up so that every rank draws the same number of samples;
            # any shortfall is padded with repeated indices in ``__iter__``.
            self.num_samples = math.ceil(len(self.indices) / self.world_size)
            self.total_size = self.num_samples * self.world_size

    def __iter__(self) -> Iterator:
        if self.test_mode:
            # In test mode, the order of frames cannot be shuffled.
            indices = self.indices[self.rank]
        else:
            # deterministically shuffle based on epoch
            rng = random.Random(self.epoch + self.seed)
            indices = rng.sample(self.indices, len(self.indices))

            # add extra samples to make it evenly divisible
            indices += indices[: (self.total_size - len(indices))]
            assert len(indices) == self.total_size

            # subsample
            indices = indices[self.rank : self.total_size : self.world_size]
            assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Record the epoch so each epoch uses a different shuffle order."""
        self.epoch = epoch
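

# A minimal config sketch (assumptions: standard mmengine dataloader
# conventions; the dataset settings below are placeholders):
#
#     train_dataloader = dict(
#         batch_size=2,
#         num_workers=2,
#         sampler=dict(type='HybridVideoImgSampler', seed=0),
#         dataset=dict(...),  # a BaseVideoDataset or supported wrapper
#     )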