import bisect
import warnings

from torch._utils import _accumulate
from torch import randperm


class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])


class TensorDataset(Dataset):
    """Dataset wrapping data and target tensors.

    Each sample will be retrieved by indexing both tensors along the first
    dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
        target_tensor (Tensor): contains sample targets (labels).
    """

    def __init__(self, data_tensor, target_tensor):
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)


class ConcatDataset(Dataset):
    """Dataset to concatenate multiple datasets.

    This class is useful to assemble different existing datasets, possibly
    large-scale datasets, since the concatenation is done on the fly.

    Arguments:
        datasets (iterable): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        # Running total of dataset lengths, e.g. [3, 5] -> [3, 8].
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Locate the dataset containing idx, then translate idx into a
        # local index within that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes


class Subset(Dataset):
    """Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


def random_split(dataset, lengths):
    """Randomly split a dataset into non-overlapping new datasets of given
    lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (iterable): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length "
                         "of the input dataset!")

    indices = randperm(sum(lengths))
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
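
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal demonstration of
# TensorDataset, concatenation via `+` (which builds a ConcatDataset), and
# random_split. The tensor shapes and the 6/4 split below are arbitrary
# choices made for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    # Two small datasets of (sample, label) pairs.
    ds_a = TensorDataset(torch.randn(6, 3), torch.zeros(6))
    ds_b = TensorDataset(torch.randn(4, 3), torch.ones(4))

    # Dataset.__add__ delegates to ConcatDataset, so `ds_a + ds_b` behaves
    # like one dataset of length 10; index 7 maps to ds_b[7 - 6] via the
    # bisect lookup over cumulative_sizes ([6, 10]).
    combined = ds_a + ds_b
    assert len(combined) == 10
    sample, label = combined[7]
    assert float(label) == 1.0  # element 7 comes from ds_b

    # random_split draws one shared permutation and hands out Subset views,
    # so the splits are non-overlapping and together cover the dataset.
    train, val = random_split(combined, [6, 4])
    assert len(train) == 6 and len(val) == 4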