import os

import albumentations
import numpy as np
from torch.utils.data import Dataset

from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex


def _read_paths(root, listfile):
    """Return the paths listed in ``listfile``, each joined onto ``root``.

    ``listfile`` is a text file holding one relative path per line.
    """
    # Explicit encoding so the list is parsed identically regardless of the
    # platform's default locale encoding.
    with open(listfile, "r", encoding="utf-8") as f:
        relpaths = f.read().splitlines()
    return [os.path.join(root, relpath) for relpath in relpaths]


class FacesBase(Dataset):
    """Dataset wrapper that optionally restricts each example to ``keys``.

    Subclasses populate ``self.data`` (an indexable dataset whose examples
    support ``example[key]`` lookup) and ``self.keys`` (``None`` to return
    examples unchanged, or an iterable of keys to copy into a fresh dict).
    """

    def __init__(self, *args, **kwargs):
        # Positional/keyword arguments are accepted but ignored.
        super().__init__()
        self.data = None
        self.keys = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        if self.keys is None:
            return example
        return {k: example[k] for k in self.keys}


class CelebAHQTrain(FacesBase):
    """CelebAHQ training split: files listed in data/celebahqtrain.txt, loaded via NumpyPaths."""

    def __init__(self, size, keys=None):
        super().__init__()
        paths = _read_paths("data/celebahq", "data/celebahqtrain.txt")
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class CelebAHQValidation(FacesBase):
    """CelebAHQ validation split: files listed in data/celebahqvalidation.txt, loaded via NumpyPaths."""

    def __init__(self, size, keys=None):
        super().__init__()
        paths = _read_paths("data/celebahq", "data/celebahqvalidation.txt")
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class FFHQTrain(FacesBase):
    """FFHQ training split: files listed in data/ffhqtrain.txt, loaded via ImagePaths."""

    def __init__(self, size, keys=None):
        super().__init__()
        paths = _read_paths("data/ffhq", "data/ffhqtrain.txt")
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class FFHQValidation(FacesBase):
    """FFHQ validation split: files listed in data/ffhqvalidation.txt, loaded via ImagePaths."""

    def __init__(self, size, keys=None):
        super().__init__()
        paths = _read_paths("data/ffhq", "data/ffhqvalidation.txt")
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class _FacesHQBase(Dataset):
    """Shared implementation of the CelebAHQ + FFHQ concatenated datasets.

    Args:
        celeba: CelebAHQ split dataset (concatenated first, class 0).
        ffhq: FFHQ split dataset (concatenated second, class 1).
        crop: albumentations crop transform class, constructed as
            ``crop(height=crop_size, width=crop_size)``.
        crop_size: if not ``None``, every image is cropped to
            ``crop_size x crop_size``; if ``None``, no cropping is done and
            no ``cropper`` attribute is set.
        coord: if true (and ``crop_size`` is given), a per-pixel coordinate
            map is cropped alongside the image and returned under ``"coord"``.
            Ignored when ``crop_size`` is ``None``.
    """

    def __init__(self, celeba, ffhq, crop, crop_size=None, coord=False):
        super().__init__()
        self.data = ConcatDatasetWithIndex([celeba, ffhq])
        self.coord = coord
        if crop_size is not None:
            self.cropper = crop(height=crop_size, width=crop_size)
            if self.coord:
                # Register "coord" as an image-like target so Compose applies
                # the same crop window to it as to the image.
                self.cropper = albumentations.Compose(
                    [self.cropper], additional_targets={"coord": "image"}
                )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        # y is the index of the source dataset (CelebAHQ 0, FFHQ 1) as
        # returned by ConcatDatasetWithIndex.
        ex, y = self.data[i]
        # `cropper` only exists when crop_size was given; hasattr is kept so
        # the attribute's presence/absence stays unchanged for callers.
        if hasattr(self, "cropper"):
            if not self.coord:
                out = self.cropper(image=ex["image"])
                ex["image"] = out["image"]
            else:
                # assumes ex["image"] is an (h, w, channels) array — TODO confirm
                h, w, _ = ex["image"].shape
                # Flat pixel index scaled into [0, 1), shape (h, w, 1).
                coord = np.arange(h * w).reshape(h, w, 1) / (h * w)
                out = self.cropper(image=ex["image"], coord=coord)
                ex["image"] = out["image"]
                ex["coord"] = out["coord"]
        ex["class"] = y
        return ex


class FacesHQTrain(_FacesHQBase):
    # CelebAHQ [0] + FFHQ [1]
    """Training set: CelebAHQTrain + FFHQTrain, randomly cropped when ``crop_size`` is set."""

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        super().__init__(
            CelebAHQTrain(size=size, keys=keys),
            FFHQTrain(size=size, keys=keys),
            albumentations.RandomCrop,
            crop_size=crop_size,
            coord=coord,
        )


class FacesHQValidation(_FacesHQBase):
    # CelebAHQ [0] + FFHQ [1]
    """Validation set: CelebAHQValidation + FFHQValidation, center-cropped when ``crop_size`` is set."""

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        super().__init__(
            CelebAHQValidation(size=size, keys=keys),
            FFHQValidation(size=size, keys=keys),
            albumentations.CenterCrop,
            crop_size=crop_size,
            coord=coord,
        )