from typing import Optional

import torch
from torch.utils.data import Sampler, ConcatDataset


class RandomConcatSampler(Sampler):
    """Random sampler for ConcatDataset.

    At each epoch, `n_samples_per_subset` samples will be drawn from each
    subset in the ConcatDataset. If `subset_replacement` is ``True``, sampling
    within each subset will be done with replacement.

    However, it is impossible to sample data without replacement between
    epochs, unless building a stateful sampler that lives along the entire
    training phase.

    For the current implementation, the randomness of sampling is ensured no
    matter whether the sampler is recreated across epochs or not, and whether
    `torch.manual_seed()` is called or not.

    Args:
        data_source (ConcatDataset): the concatenated dataset to sample from.
        n_samples_per_subset (int): number of indices drawn from every
            sub-dataset on each pass over the sampler.
        subset_replacement (bool): draw with replacement inside each
            sub-dataset. When ``False`` and a sub-dataset is smaller than
            `n_samples_per_subset`, every element is used once and the
            remainder is padded with draws made with replacement.
        shuffle (bool): shuffle the randomly sampled indices across all
            sub-datasets.
        repeat (int): repeatedly use the sampled indices multiple times for
            training. Must be >= 1. [arXiv:1902.05509, arXiv:1901.09335]
        seed (int, optional): if given, it is passed to `torch.manual_seed()`
            (which also seeds the global RNG) and the returned generator is
            used for sampling. If ``None``, a private generator seeded from
            the current global RNG state is used instead.

    Raises:
        TypeError: if `data_source` is not a ``ConcatDataset``.
        ValueError: if `repeat` is smaller than 1.

    NOTE: Don't re-initialize the sampler with the same seed between epochs
          (will lead to repeated samples).
    NOTE: This sampler behaves differently from DistributedSampler.
          It assumes the dataset is split across ranks instead of replicated.
    TODO: Add a `set_epoch()` method to fulfill sampling without replacement
          across epochs.
          ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
    """

    def __init__(
        self,
        data_source: ConcatDataset,
        n_samples_per_subset: int,
        subset_replacement: bool = True,
        shuffle: bool = True,
        repeat: int = 1,
        seed: Optional[int] = None,
    ):
        if not isinstance(data_source, ConcatDataset):
            raise TypeError("data_source should be torch.utils.data.ConcatDataset")
        # Validate explicitly (not with `assert`, which is stripped under -O)
        # and before `repeat` is used to compute the sampler length.
        if repeat < 1:
            raise ValueError(f"repeat should be >= 1, but got {repeat}")

        self.data_source = data_source
        self.n_subset = len(self.data_source.datasets)
        self.n_samples_per_subset = n_samples_per_subset
        self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
        self.subset_replacement = subset_replacement
        self.repeat = repeat
        self.shuffle = shuffle

        if seed is None:
            # `torch.manual_seed(None)` raises TypeError, so the default
            # argument needs its own path: derive a seed from the global RNG
            # and keep it in a private generator.
            derived_seed = int(torch.empty((), dtype=torch.int64).random_().item())
            self.generator = torch.Generator()
            self.generator.manual_seed(derived_seed)
        else:
            # Kept as-is for backward compatibility: this seeds the global RNG
            # and returns the global default generator.
            self.generator = torch.manual_seed(seed)

    def __len__(self):
        return self.n_samples

    def _sample_subset(self, low: int, high: int) -> torch.Tensor:
        """Draw `n_samples_per_subset` global indices from the range [low, high)."""
        n_wanted = self.n_samples_per_subset
        if self.subset_replacement:
            return torch.randint(
                low,
                high,
                (n_wanted,),
                generator=self.generator,
                dtype=torch.int64,
            )

        # Sample without replacement: a random permutation of the subset,
        # shifted into the global index range of the ConcatDataset.
        len_subset = high - low
        permuted = torch.randperm(len_subset, generator=self.generator) + low
        if len_subset >= n_wanted:
            return permuted[:n_wanted]

        # The subset is too small: use every element once, then pad the
        # shortfall with draws made with replacement.
        padding = torch.randint(
            low,
            high,
            (n_wanted - len_subset,),
            generator=self.generator,
            dtype=torch.int64,
        )
        return torch.cat([permuted, padding])

    def __iter__(self):
        # Sample from each sub-dataset; consecutive entries of
        # `cumulative_sizes` delimit the global index range of each subset.
        per_subset = []
        low = 0
        for high in self.data_source.cumulative_sizes:
            per_subset.append(self._sample_subset(low, high))
            low = high
        indices = torch.cat(per_subset)

        if self.shuffle:
            # Shuffle the sampled indices across all subsets.
            order = torch.randperm(len(indices), generator=self.generator)
            indices = indices[order]

        # Repeat the sampled indices (can be used for RepeatAugmentation or
        # pure RepeatSampling). Each extra copy is independently re-shuffled
        # when `shuffle` is enabled.
        if self.repeat > 1:
            copies = [indices]
            for _ in range(self.repeat - 1):
                if self.shuffle:
                    order = torch.randperm(len(indices), generator=self.generator)
                    copies.append(indices[order])
                else:
                    copies.append(indices)
            indices = torch.cat(copies, 0)

        # Internal invariant: the yielded length must match `__len__`.
        assert indices.shape[0] == self.n_samples
        return iter(indices.tolist())