Vincentqyw
fix: roma
8b973ee
raw
history blame
4.39 kB
import torch
from torch.utils.data import Sampler, ConcatDataset
class RandomConcatSampler(Sampler):
"""Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
Args:
shuffle (bool): shuffle the random sampled indices across all sub-datsets.
repeat (int): repeatedly use the sampled indices multiple times for training.
[arXiv:1902.05509, arXiv:1901.09335]
NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
NOTE: This sampler behaves differently with DistributedSampler.
It assume the dataset is splitted across ranks instead of replicated.
TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
"""
def __init__(
self,
data_source: ConcatDataset,
n_samples_per_subset: int,
subset_replacement: bool = True,
shuffle: bool = True,
repeat: int = 1,
seed: int = None,
):
if not isinstance(data_source, ConcatDataset):
raise TypeError("data_source should be torch.utils.data.ConcatDataset")
self.data_source = data_source
self.n_subset = len(self.data_source.datasets)
self.n_samples_per_subset = n_samples_per_subset
self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
self.subset_replacement = subset_replacement
self.repeat = repeat
self.shuffle = shuffle
self.generator = torch.manual_seed(seed)
assert self.repeat >= 1
def __len__(self):
return self.n_samples
def __iter__(self):
indices = []
# sample from each sub-dataset
for d_idx in range(self.n_subset):
low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
high = self.data_source.cumulative_sizes[d_idx]
if self.subset_replacement:
rand_tensor = torch.randint(
low,
high,
(self.n_samples_per_subset,),
generator=self.generator,
dtype=torch.int64,
)
else: # sample without replacement
len_subset = len(self.data_source.datasets[d_idx])
rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
if len_subset >= self.n_samples_per_subset:
rand_tensor = rand_tensor[: self.n_samples_per_subset]
else: # padding with replacement
rand_tensor_replacement = torch.randint(
low,
high,
(self.n_samples_per_subset - len_subset,),
generator=self.generator,
dtype=torch.int64,
)
rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
indices.append(rand_tensor)
indices = torch.cat(indices)
if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
rand_tensor = torch.randperm(len(indices), generator=self.generator)
indices = indices[rand_tensor]
# repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
if self.repeat > 1:
repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
if self.shuffle:
_choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
repeat_indices = map(_choice, repeat_indices)
indices = torch.cat([indices, *repeat_indices], 0)
assert indices.shape[0] == self.n_samples
return iter(indices.tolist())