Spaces:
Runtime error
Runtime error
from torch.utils.data import Dataset, ConcatDataset | |
from diffab.utils.transforms import get_transform | |
# Module-level registry: maps a dataset type name to the class (or factory)
# that `register_dataset` stored under that name.
_DATASET_DICT = {}


def register_dataset(name):
    """Return a decorator that records the decorated object under ``name``.

    Args:
        name: Key under which the decorated object is stored in
            ``_DATASET_DICT``. An existing entry with the same key is
            overwritten.

    Returns:
        A decorator that stores its argument in the registry and hands the
        same object back unchanged.
    """

    def _register(dataset_cls):
        _DATASET_DICT.update({name: dataset_cls})
        return dataset_cls

    return _register
def get_dataset(cfg):
    """Instantiate the dataset described by ``cfg`` from the registry.

    Args:
        cfg: Config object that supports ``in`` membership tests and
            attribute access. ``cfg.type`` selects the registered entry;
            an optional ``cfg.transform`` is passed to ``get_transform``.

    Returns:
        Whatever the registered entry returns when called as
        ``entry(cfg, transform=transform)``.

    Raises:
        KeyError: If ``cfg.type`` is not a registered dataset type. The
            message lists the types that are registered.
    """
    # Build the transform first, matching the original evaluation order.
    transform = get_transform(cfg.transform) if 'transform' in cfg else None
    dataset_type = cfg.type
    try:
        dataset_factory = _DATASET_DICT[dataset_type]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message
        # that tells the user which names are actually available.
        raise KeyError(
            f"Unknown dataset type {dataset_type!r}. "
            f"Registered types: {list(_DATASET_DICT)}"
        ) from None
    return dataset_factory(cfg, transform=transform)
def get_concat_dataset(cfg):
    """Build every dataset listed in ``cfg.datasets`` and chain them.

    Args:
        cfg: Config object exposing a ``datasets`` iterable; each entry is
            passed to ``get_dataset``.

    Returns:
        A ``ConcatDataset`` over the built datasets, in the order they
        appear in ``cfg.datasets``.
    """
    members = []
    for sub_cfg in cfg.datasets:
        members.append(get_dataset(sub_cfg))
    return ConcatDataset(members)
class BalancedConcatDataset(Dataset):
    """Concatenation in which every sub-dataset gets the same number of slots.

    Each sub-dataset is assigned ``max_size`` consecutive indices, where
    ``max_size`` is the length of the largest sub-dataset. A sub-dataset
    shorter than ``max_size`` is indexed modulo its own length, so its
    samples repeat within its slot range.
    """

    def __init__(self, cfg, transform=None):
        """Build the sub-datasets listed in ``cfg.datasets``.

        Args:
            cfg: Config object exposing a ``datasets`` iterable; each entry
                is passed to ``get_dataset``.
            transform: Must be ``None``; accepted only so the constructor
                matches the call made by ``get_dataset``.

        Raises:
            AssertionError: If ``transform`` is not ``None``.
            ValueError: If ``cfg.datasets`` is empty (raised by ``max``).
        """
        super().__init__()
        # Explicit raise instead of an `assert` statement so the check is
        # not stripped under `python -O`; type and message are unchanged.
        if transform is not None:
            raise AssertionError('transform is not supported.')
        self.datasets = [get_dataset(d) for d in cfg.datasets]
        self.max_size = max(len(d) for d in self.datasets)

    def __len__(self):
        """Return ``max_size`` slots for each sub-dataset."""
        return self.max_size * len(self.datasets)

    def __getitem__(self, idx):
        """Return the sample stored at global position ``idx``.

        Args:
            idx: Integer index. Negative values count from the end, as with
                built-in sequences.

        Returns:
            The sample from the sub-dataset that owns slot ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        # Normalize negative indices so that ds[-1] is ds[len(ds) - 1].
        # Non in-place addition avoids mutating a caller-owned index object.
        position = idx + total if idx < 0 else idx
        if not 0 <= position < total:
            raise IndexError(
                f'index {idx} is out of range for dataset of size {total}'
            )
        dataset = self.datasets[position // self.max_size]
        # The global position (not the offset inside the slot range) is
        # wrapped by the sub-dataset length; kept as in the original mapping.
        return dataset[position % len(dataset)]