import bisect
import warnings

from torch import randperm
from torch._utils import _accumulate

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    which supports integer indexing in the range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
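
# Illustrative sketch (not part of the module): a minimal concrete subclass
# only needs ``__len__`` and ``__getitem__``; the data here is made up.
#
#     class SquaresDataset(Dataset):
#         def __init__(self, n):
#             self.n = n
#
#         def __len__(self):
#             return self.n
#
#         def __getitem__(self, index):
#             return index ** 2
#
#     ds = SquaresDataset(5)
#     len(ds)   # 5
#     ds[3]     # 9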

class TensorDataset(Dataset):
    """Dataset wrapping data and target tensors.

    Each sample will be retrieved by indexing both tensors along the first
    dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
        target_tensor (Tensor): contains sample targets (labels).
    """

    def __init__(self, data_tensor, target_tensor):
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)
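
# Illustrative usage sketch (shapes and values are assumed; requires torch):
#
#     import torch
#     data = torch.randn(4, 2)              # 4 samples, 2 features each
#     targets = torch.tensor([0, 1, 0, 1])  # one label per sample
#     ds = TensorDataset(data, targets)
#     len(ds)   # 4
#     ds[2]     # (data[2], targets[2])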

class ConcatDataset(Dataset):
    """Dataset to concatenate multiple datasets.

    Useful for assembling different existing datasets, possibly large-scale
    ones, since the concatenation is done on the fly.

    Arguments:
        datasets (iterable): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        # Running total of the dataset lengths, e.g. [3, 5] -> [3, 8].
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Find which dataset contains ``idx``, then map ``idx`` to an offset
        # within that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
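
# Illustrative sketch of the index arithmetic (datasets are made up):
#
#     a = TensorDataset(torch.zeros(3, 2), torch.zeros(3))
#     b = TensorDataset(torch.ones(5, 2), torch.ones(5))
#     cat = a + b              # equivalent to ConcatDataset([a, b])
#     cat.cumulative_sizes     # [3, 8]
#     len(cat)                 # 8
#     cat[4]                   # b[1], since 4 - 3 = 1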

class Subset(Dataset):
    """Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): the whole dataset
        indices (sequence): indices in the whole set selected for the subset
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
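
# Illustrative sketch (the index list is made up):
#
#     base = TensorDataset(torch.arange(10).unsqueeze(1), torch.arange(10))
#     sub = Subset(base, [7, 2, 5])
#     len(sub)   # 3
#     sub[0]     # base[7]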

def random_split(dataset, lengths):
    """Randomly split a dataset into non-overlapping new datasets of given
    lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (iterable): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths))
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
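
# Illustrative usage sketch (the 8/2 split is made up):
#
#     base = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
#     train, val = random_split(base, [8, 2])
#     len(train), len(val)   # (8, 2)
#
# Each returned Subset indexes into ``dataset`` through a disjoint slice of
# the same random permutation, which is why the splits cannot overlap.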