Spaces:
Build error
Build error
import os | |
import numpy as np | |
import albumentations | |
from torch.utils.data import Dataset | |
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex | |
class FacesBase(Dataset):
    """Dataset wrapper that can restrict each example to a chosen set of keys.

    Subclasses in this file assign ``self.data`` (an indexable collection of
    examples) and ``self.keys`` (the keys to keep, or ``None`` to hand the
    example back unchanged).
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not used here.
        super().__init__()
        self.keys = None
        self.data = None

    def __len__(self):
        """Return the number of examples in the wrapped data."""
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i``, filtered down to ``self.keys`` when set.

        With ``self.keys`` unset the underlying example object itself is
        returned; otherwise a new dict holding only the requested keys is
        built.
        """
        record = self.data[i]
        if self.keys is None:
            return record
        return {name: record[name] for name in self.keys}
class CelebAHQTrain(FacesBase):
    """CelebAHQ training split, loaded through ``NumpyPaths``.

    Relative paths are read, one per line, from ``data/celebahqtrain.txt``
    and joined onto the ``data/celebahq`` directory.

    Args:
        size: Forwarded to ``NumpyPaths`` as ``size``.
        keys: Optional keys that each returned example is restricted to.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_root = "data/celebahq"
        with open("data/celebahqtrain.txt", "r") as listing:
            entries = listing.read().splitlines()
        full_paths = []
        for entry in entries:
            full_paths.append(os.path.join(image_root, entry))
        # random_crop is always disabled at this level.
        self.data = NumpyPaths(paths=full_paths, size=size, random_crop=False)
        self.keys = keys
class CelebAHQValidation(FacesBase):
    """CelebAHQ validation split, loaded through ``NumpyPaths``.

    Relative paths are read, one per line, from
    ``data/celebahqvalidation.txt`` and joined onto the ``data/celebahq``
    directory.

    Args:
        size: Forwarded to ``NumpyPaths`` as ``size``.
        keys: Optional keys that each returned example is restricted to.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_root = "data/celebahq"
        with open("data/celebahqvalidation.txt", "r") as listing:
            entries = listing.read().splitlines()
        full_paths = []
        for entry in entries:
            full_paths.append(os.path.join(image_root, entry))
        # random_crop is always disabled at this level.
        self.data = NumpyPaths(paths=full_paths, size=size, random_crop=False)
        self.keys = keys
class FFHQTrain(FacesBase):
    """FFHQ training split, loaded through ``ImagePaths``.

    Relative paths are read, one per line, from ``data/ffhqtrain.txt`` and
    joined onto the ``data/ffhq`` directory.

    Args:
        size: Forwarded to ``ImagePaths`` as ``size``.
        keys: Optional keys that each returned example is restricted to.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_root = "data/ffhq"
        with open("data/ffhqtrain.txt", "r") as listing:
            entries = listing.read().splitlines()
        full_paths = []
        for entry in entries:
            full_paths.append(os.path.join(image_root, entry))
        # random_crop is always disabled at this level.
        self.data = ImagePaths(paths=full_paths, size=size, random_crop=False)
        self.keys = keys
class FFHQValidation(FacesBase):
    """FFHQ validation split, loaded through ``ImagePaths``.

    Relative paths are read, one per line, from ``data/ffhqvalidation.txt``
    and joined onto the ``data/ffhq`` directory.

    Args:
        size: Forwarded to ``ImagePaths`` as ``size``.
        keys: Optional keys that each returned example is restricted to.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_root = "data/ffhq"
        with open("data/ffhqvalidation.txt", "r") as listing:
            entries = listing.read().splitlines()
        full_paths = []
        for entry in entries:
            full_paths.append(os.path.join(image_root, entry))
        # random_crop is always disabled at this level.
        self.data = ImagePaths(paths=full_paths, size=size, random_crop=False)
        self.keys = keys
class FacesHQTrain(Dataset):
    """Training set concatenating ``CelebAHQTrain`` and ``FFHQTrain``.

    Each item is the example dict from the underlying dataset with a
    ``"class"`` entry added, taken from the second element returned by
    ``ConcatDatasetWithIndex``.

    Args:
        size: Forwarded to both underlying datasets.
        keys: Forwarded to both underlying datasets.
        crop_size: When given, a square ``RandomCrop`` of this side length is
            applied to ``ex["image"]``. When ``None`` no cropping happens.
        coord: When true *and* ``crop_size`` is given, a coordinate map is
            cropped together with the image and stored under ``"coord"``.
            Without ``crop_size`` this flag has no effect.
    """

    # CelebAHQ [0] + FFHQ [1]
    def __init__(self, size, keys=None, crop_size=None, coord=False):
        celebahq = CelebAHQTrain(size=size, keys=keys)
        ffhq = FFHQTrain(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([celebahq, ffhq])
        self.coord = coord
        if crop_size is None:
            # No ``cropper`` attribute is created; __getitem__ checks for it.
            return
        crop = albumentations.RandomCrop(height=crop_size, width=crop_size)
        if self.coord:
            # Register "coord" as an extra image-like target so that it gets
            # the same crop as the image.
            crop = albumentations.Compose(
                [crop], additional_targets={"coord": "image"}
            )
        self.cropper = crop

    def __len__(self):
        """Return the length of the concatenated dataset."""
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i`` with optional crop/coord and a ``"class"`` entry."""
        sample, source_idx = self.data[i]
        if hasattr(self, "cropper"):
            image = sample["image"]
            if self.coord:
                height, width, _ = image.shape
                n_pixels = height * width
                # Row-major pixel index scaled by the pixel count, shape (h, w, 1).
                coord = np.arange(n_pixels).reshape(height, width, 1) / n_pixels
                cropped = self.cropper(image=image, coord=coord)
                sample["image"] = cropped["image"]
                sample["coord"] = cropped["coord"]
            else:
                cropped = self.cropper(image=image)
                sample["image"] = cropped["image"]
        sample["class"] = source_idx
        return sample
class FacesHQValidation(Dataset):
    """Validation set concatenating ``CelebAHQValidation`` and ``FFHQValidation``.

    Each item is the example dict from the underlying dataset with a
    ``"class"`` entry added, taken from the second element returned by
    ``ConcatDatasetWithIndex``.

    Args:
        size: Forwarded to both underlying datasets.
        keys: Forwarded to both underlying datasets.
        crop_size: When given, a square ``CenterCrop`` of this side length is
            applied to ``ex["image"]``. When ``None`` no cropping happens.
        coord: When true *and* ``crop_size`` is given, a coordinate map is
            cropped together with the image and stored under ``"coord"``.
            Without ``crop_size`` this flag has no effect.
    """

    # CelebAHQ [0] + FFHQ [1]
    def __init__(self, size, keys=None, crop_size=None, coord=False):
        celebahq = CelebAHQValidation(size=size, keys=keys)
        ffhq = FFHQValidation(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([celebahq, ffhq])
        self.coord = coord
        if crop_size is None:
            # No ``cropper`` attribute is created; __getitem__ checks for it.
            return
        crop = albumentations.CenterCrop(height=crop_size, width=crop_size)
        if self.coord:
            # Register "coord" as an extra image-like target so that it gets
            # the same crop as the image.
            crop = albumentations.Compose(
                [crop], additional_targets={"coord": "image"}
            )
        self.cropper = crop

    def __len__(self):
        """Return the length of the concatenated dataset."""
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i`` with optional crop/coord and a ``"class"`` entry."""
        sample, source_idx = self.data[i]
        if hasattr(self, "cropper"):
            image = sample["image"]
            if self.coord:
                height, width, _ = image.shape
                n_pixels = height * width
                # Row-major pixel index scaled by the pixel count, shape (h, w, 1).
                coord = np.arange(n_pixels).reshape(height, width, 1) / n_pixels
                cropped = self.cropper(image=image, coord=coord)
                sample["image"] = cropped["image"]
                sample["coord"] = cropped["coord"]
            else:
                cropped = self.cropper(image=image)
                sample["image"] = cropped["image"]
        sample["class"] = source_idx
        return sample