from glob import glob from PIL import Image from typing import Callable, Optional from torch.utils.data import DataLoader from torchvision.datasets import VisionDataset __DATASET__ = {} def register_dataset(name: str): def wrapper(cls): if __DATASET__.get(name, None): raise NameError(f"Name {name} is already registered!") __DATASET__[name] = cls return cls return wrapper def get_dataset(name: str, root: str, **kwargs): if __DATASET__.get(name, None) is None: raise NameError(f"Dataset {name} is not defined.") return __DATASET__[name](root=root, **kwargs) def get_dataloader(dataset: VisionDataset, batch_size: int, num_workers: int, train: bool): dataloader = DataLoader(dataset, batch_size, shuffle=train, num_workers=num_workers, drop_last=train) return dataloader @register_dataset(name='ffhq') class FFHQDataset(VisionDataset): def __init__(self, root: str, transforms: Optional[Callable]=None): super().__init__(root, transforms) self.fpaths = sorted(glob(root + '/**/*.png', recursive=True)) assert len(self.fpaths) > 0, "File list is empty. Check the root." def __len__(self): return len(self.fpaths) def __getitem__(self, index: int): fpath = self.fpaths[index] img = Image.open(fpath).convert('RGB') if self.transforms is not None: img = self.transforms(img) return img