File size: 4,392 Bytes
a80d6bb
 
 
 
 
c74a070
a80d6bb
 
c74a070
a80d6bb
 
 
 
 
 
 
 
 
 
 
c74a070
 
 
 
 
 
 
 
 
 
a80d6bb
 
c74a070
a80d6bb
 
 
 
 
 
 
 
 
c74a070
a80d6bb
 
c74a070
a80d6bb
 
 
 
c74a070
a80d6bb
 
c74a070
 
 
 
 
 
 
a80d6bb
 
 
 
c74a070
 
 
 
 
 
 
 
 
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
a80d6bb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from torch.utils.data import Sampler, ConcatDataset


class RandomConcatSampler(Sampler):
    """Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
    in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
    However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.

    For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
    Args:
        shuffle (bool): shuffle the random sampled indices across all sub-datsets.
        repeat (int): repeatedly use the sampled indices multiple times for training.
            [arXiv:1902.05509, arXiv:1901.09335]
    NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
    NOTE: This sampler behaves differently with DistributedSampler.
          It assume the dataset is splitted across ranks instead of replicated.
    TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
          ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
    """

    def __init__(
        self,
        data_source: ConcatDataset,
        n_samples_per_subset: int,
        subset_replacement: bool = True,
        shuffle: bool = True,
        repeat: int = 1,
        seed: int = None,
    ):
        if not isinstance(data_source, ConcatDataset):
            raise TypeError("data_source should be torch.utils.data.ConcatDataset")

        self.data_source = data_source
        self.n_subset = len(self.data_source.datasets)
        self.n_samples_per_subset = n_samples_per_subset
        self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
        self.subset_replacement = subset_replacement
        self.repeat = repeat
        self.shuffle = shuffle
        self.generator = torch.manual_seed(seed)
        assert self.repeat >= 1

    def __len__(self):
        return self.n_samples

    def __iter__(self):
        indices = []
        # sample from each sub-dataset
        for d_idx in range(self.n_subset):
            low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
            high = self.data_source.cumulative_sizes[d_idx]
            if self.subset_replacement:
                rand_tensor = torch.randint(
                    low,
                    high,
                    (self.n_samples_per_subset,),
                    generator=self.generator,
                    dtype=torch.int64,
                )
            else:  # sample without replacement
                len_subset = len(self.data_source.datasets[d_idx])
                rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
                if len_subset >= self.n_samples_per_subset:
                    rand_tensor = rand_tensor[: self.n_samples_per_subset]
                else:  # padding with replacement
                    rand_tensor_replacement = torch.randint(
                        low,
                        high,
                        (self.n_samples_per_subset - len_subset,),
                        generator=self.generator,
                        dtype=torch.int64,
                    )
                    rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
            indices.append(rand_tensor)
        indices = torch.cat(indices)
        if self.shuffle:  # shuffle the sampled dataset (from multiple subsets)
            rand_tensor = torch.randperm(len(indices), generator=self.generator)
            indices = indices[rand_tensor]

        # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
        if self.repeat > 1:
            repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
            if self.shuffle:
                _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
                repeat_indices = map(_choice, repeat_indices)
            indices = torch.cat([indices, *repeat_indices], 0)

        assert indices.shape[0] == self.n_samples
        return iter(indices.tolist())