import numpy as np
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset


class Invert:
    """Return ``1 - x``.

    Applied after ``transforms.ToTensor()`` (values in [0, 1]) this inverts
    pixel intensities, e.g. black-on-white glyphs become white-on-black.
    """

    def __call__(self, x):
        return 1 - x


class Gray:
    """Keep only the first slice along dim 0, preserving that dimension.

    For a channel-first tensor such as ``ToTensor()`` output (C, H, W) this
    keeps just the first channel, giving shape (1, H, W).
    """

    def __call__(self, x):
        return x[0:1]


def _resize_crop_to_tensor(size, *extra):
    """Build a Compose: resize shorter side to ``size``, center-crop to
    ``size`` x ``size``, convert to tensor, then apply any ``extra`` transforms.
    """
    # ``Resize`` replaces the deprecated (and since removed) ``Scale``; with an
    # int argument both match the shorter image edge to ``size``.
    return transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        *extra,
    ])


def _load_quickdraw():
    """Load the QuickDraw teapot array as a ``TensorDataset`` of (x, x) pairs."""
    data = np.load('data/quickdraw/teapot.npy')
    # Each row is reshaped to a 28x28 image.
    data = data.reshape((data.shape[0], 28, 28))
    # NOTE(review): dividing by 255 assumes 8-bit pixel values — confirm.
    data = (data / 255.).astype(np.float32)
    images = torch.from_numpy(data)
    # Input and target are the same tensor (e.g. for autoencoder-style use).
    return TensorDataset(images, images)


def load_dataset(dataset_name, split='full'):
    """Return the torch ``Dataset`` registered under ``dataset_name``.

    Args:
        dataset_name: one of 'mnist', 'coco', 'quickdraw', 'shoes',
            'footwear', 'celeba', 'birds', 'sketchy', 'fonts'.
        split: sub-directory name under the dataset root. Only used by
            'birds', 'sketchy' and 'fonts'; ignored by the others.

    Returns:
        A ``torch.utils.data.Dataset``. Image-folder datasets yield tensors
        produced by ``transforms.ToTensor()`` plus the listed extra transforms.

    Raises:
        ValueError: if ``dataset_name`` is not recognised.
    """
    if dataset_name == 'mnist':
        # download=True fetches MNIST into data/mnist if it is not present.
        return dset.MNIST(
            root='data/mnist',
            download=True,
            transform=transforms.Compose([transforms.ToTensor()]),
        )
    if dataset_name == 'coco':
        return dset.ImageFolder(root='data/coco',
                                transform=_resize_crop_to_tensor(64))
    if dataset_name == 'quickdraw':
        return _load_quickdraw()
    if dataset_name == 'shoes':
        return dset.ImageFolder(root='data/shoes/ut-zap50k-images/Shoes',
                                transform=_resize_crop_to_tensor(64))
    if dataset_name == 'footwear':
        return dset.ImageFolder(root='data/shoes/ut-zap50k-images',
                                transform=_resize_crop_to_tensor(64))
    if dataset_name == 'celeba':
        return dset.ImageFolder(root='data/celeba',
                                transform=_resize_crop_to_tensor(32))
    if dataset_name == 'birds':
        return dset.ImageFolder(root='data/birds/' + split,
                                transform=_resize_crop_to_tensor(32))
    if dataset_name == 'sketchy':
        # Sketches are reduced to a single channel after tensor conversion.
        return dset.ImageFolder(root='data/sketchy/' + split,
                                transform=_resize_crop_to_tensor(64, Gray()))
    if dataset_name == 'fonts':
        # No resizing: glyph images are used at their stored size, inverted,
        # then reduced to a single channel.
        return dset.ImageFolder(
            root='data/fonts/' + split,
            transform=transforms.Compose([
                transforms.ToTensor(),
                Invert(),
                Gray(),
            ]),
        )
    raise ValueError('Error : unknown dataset')