Spaces:
Running
on
A10G
Running
on
A10G
import bisect | |
import random | |
from typing import Iterable | |
from torch.utils.data import Dataset, IterableDataset | |
class ConcatRepeatDataset(Dataset):
    """Concatenation of map-style datasets, each virtually repeated.

    Dataset ``i`` contributes ``len(datasets[i]) * repeats[i]`` consecutive
    indices to the combined index space. No samples are copied: an index
    that falls inside a repeated region is wrapped back onto the underlying
    dataset with a modulo.

    Attributes:
        datasets: The wrapped datasets, in concatenation order.
        repeats: Repeat count for each dataset, aligned with ``datasets``.
        cumulative_sizes: Running totals of ``len(dataset) * repeat``;
            entry ``i`` is the first global index *after* dataset ``i``.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    # Must be a staticmethod: it takes no ``self`` and is invoked as
    # ``self.cumsum(datasets, repeats)`` from ``__init__``.
    @staticmethod
    def cumsum(sequence, repeats):
        """Return the running totals of ``len(dataset) * repeat``.

        Args:
            sequence: Sized datasets, in concatenation order.
            repeats: Repeat count for each dataset in ``sequence``.

        Returns:
            A list whose ``i``-th entry is the combined (repeated) length of
            datasets ``0..i``.
        """
        totals, running = [], 0
        for dataset, repeat in zip(sequence, repeats):
            running += len(dataset) * repeat
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        """Build the concatenated view.

        Args:
            datasets: Non-empty iterable of map-style datasets.
            repeats: One repeat count per dataset.

        Raises:
            AssertionError: If ``datasets`` is empty, if its length differs
                from ``repeats``, or if any dataset is an ``IterableDataset``.
        """
        super().__init__()

        self.datasets = list(datasets)
        self.repeats = repeats

        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"

        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"

        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        """Return the total number of (repeated) samples."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at global index ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            # Rebind rather than ``+=`` so a mutable index object passed by
            # the caller is never modified in place.
            idx = idx + total
        if idx < 0 or idx >= total:
            raise IndexError("index out of range for ConcatRepeatDataset")

        # bisect_right finds the first dataset whose cumulative end is > idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        dataset = self.datasets[dataset_idx]
        # Wrap the offset so repeated regions map back onto real samples.
        return dataset[sample_idx % len(dataset)]