File size: 1,498 Bytes
c310e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import bisect
import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset


class ConcatDataset(_ConcatDataset):
    """
    Same as torch.utils.data.dataset.ConcatDataset, but exposes extra methods
    for mapping a global index to its (dataset, sample) pair and for querying
    per-image info through the wrapped datasets' ``get_img_info``.
    """

    def get_idxs(self, idx):
        """Map a global index to ``(dataset_idx, sample_idx)``.

        Negative indices count from the end of the concatenation, matching
        the parent class's ``__getitem__``.

        Raises:
            ValueError: if a negative ``idx`` has an absolute value larger
                than the total length (same message as the parent class).
            IndexError: if ``idx`` is at or past the end of the concatenation.
        """
        total = len(self)
        if idx < 0:
            if -idx > total:
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            # Normalize so bisect searches the right dataset; without this a
            # negative idx would always resolve into the first dataset.
            idx += total
        elif idx >= total:
            raise IndexError(
                "index {} out of range for ConcatDataset of length {}".format(
                    idx, total
                )
            )
        # cumulative_sizes[k] is the exclusive end offset of dataset k, so
        # bisect_right yields the first dataset whose end lies beyond idx
        # (empty datasets are skipped automatically).
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx

    def get_img_info(self, idx):
        """Return ``get_img_info(sample_idx)`` of the dataset holding ``idx``."""
        dataset_idx, sample_idx = self.get_idxs(idx)
        return self.datasets[dataset_idx].get_img_info(sample_idx)

class MixDataset(object):
    """Randomly mix samples from several datasets according to ``ratios``.

    Every ``__getitem__`` call ignores the requested index. It first picks a
    dataset with probability proportional to its entry in ``ratios``, then
    returns a uniformly random sample from that dataset.

    Args:
        datasets: sequence of sized, indexable datasets.
        ratios: non-negative sampling weights, one per dataset. They are
            normalized by their sum, so ``[0.3, 0.7]`` and ``[3, 7]`` are
            equivalent.

    Raises:
        ValueError: if ``ratios`` and ``datasets`` differ in length, any
            ratio is negative, or the ratios do not have a positive sum.
    """

    def __init__(self, datasets, ratios):
        self.datasets = datasets
        self.ratios = ratios
        if len(ratios) != len(datasets):
            raise ValueError(
                "expected one ratio per dataset, got {} ratios for {} "
                "datasets".format(len(ratios), len(datasets))
            )
        weights = np.asarray(ratios, dtype=np.float64)
        if (weights < 0).any():
            raise ValueError("ratios must be non-negative")
        total = weights.sum()
        if not total > 0:  # also rejects NaN
            raise ValueError("ratios must have a positive sum")
        self.lengths = np.array([len(d) for d in datasets], dtype=np.int64)
        # Cumulative selection boundaries in [0, 1); the final boundary (1.0)
        # is implicit. A uniform draw u selects dataset
        # bisect_right(seperate_inds, u). The attribute name (sic) is kept
        # for backward compatibility.
        self.seperate_inds = np.cumsum(weights[:-1] / total).tolist()

    def __len__(self):
        """Return the total number of samples across all datasets."""
        return int(self.lengths.sum())

    def __getitem__(self, item):
        """Return a randomly chosen sample; ``item`` is ignored.

        A dataset whose ratio is positive but which has no samples makes
        ``np.random.randint`` raise ``ValueError`` when it is selected.

        NOTE(review): draws from NumPy's global RNG. If this is used with
        multi-process data loading, confirm that each worker is seeded
        differently.
        """
        u = np.random.rand()
        ind = bisect.bisect_right(self.seperate_inds, u)
        b_ind = np.random.randint(self.lengths[ind])
        return self.datasets[ind][b_ind]