Spaces:
Paused
Paused
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset | |
from .builder import DATASETS | |
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated datasets.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but additionally
    exposes the ``CLASSES`` and ``PALETTE`` attributes of the first dataset
    on the concatenated dataset.

    Args:
        datasets (Iterable[:obj:`Dataset`]): The datasets to concatenate.
            The first one must define ``CLASSES`` and ``PALETTE``; the
            remaining datasets are not inspected for these attributes.
    """

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__(datasets)
        # The parent class materialises the argument into ``self.datasets``
        # (a list). Read the first dataset from there instead of indexing
        # the raw argument, so sized but non-indexable iterables work too.
        first = self.datasets[0]
        self.CLASSES = first.CLASSES
        self.PALETTE = first.PALETTE
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of the repeated dataset is ``times`` times the length of the
    original dataset. This is useful when the data loading time is long but
    the dataset is small. Using RepeatDataset can reduce the data loading
    time between epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated. It must define
            ``CLASSES`` and ``PALETTE`` and support ``len()``.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        self.PALETTE = dataset.PALETTE
        # Cache the original length once; it is used on every item access.
        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Get item from original dataset.

        Args:
            idx (int): Index into the repeated dataset. Negative indices
                count from the end, as for built-in sequences.

        Returns:
            The item of the wrapped dataset at ``idx % len(dataset)``.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        total = len(self)
        # Without a bounds check the modulo below maps every integer to a
        # valid item, so ``for x in ds`` / ``list(ds)`` would never stop
        # (the legacy iteration protocol ends only on IndexError), and an
        # empty dataset would fail with ZeroDivisionError instead.
        if not -total <= idx < total:
            raise IndexError(
                'index {} is out of range for RepeatDataset of length {}'
                .format(idx, total))
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """The length is multiplied by ``times``."""
        return self.times * self._ori_len