# ae_gen / data.py
# Commit fa128ec (mehdidc): "add app and generation / model code"
import torch
import torchvision.transforms as transforms
import torchvision.datasets as dset
class Invert:
    """Callable transform that maps an input ``x`` to ``1 - x``."""

    def __call__(self, x):
        """Return ``1 - x`` for the given input."""
        inverted = 1 - x
        return inverted
class Gray:
    """Callable transform that keeps only the first entry along axis 0."""

    def __call__(self, x):
        """Return the length-one leading slice of ``x`` (axis 0 is kept)."""
        # Slicing (rather than indexing with 0) preserves the leading axis.
        first_only = x[:1]
        return first_only
def _resize(size):
    """Return a torchvision transform that resizes an image to ``size``."""
    # ``Scale`` is the old name of ``Resize`` and is missing from newer
    # torchvision releases; prefer ``Resize`` and only fall back to ``Scale``
    # when running against a torchvision that predates the rename.
    resize_cls = getattr(transforms, 'Resize', None)
    if resize_cls is None:
        resize_cls = transforms.Scale
    return resize_cls(size)


def _image_folder(root, size=None, extra=()):
    """Build an ``ImageFolder`` rooted at ``root`` with the shared pipeline.

    The pipeline is: optional resize + center crop to ``size``, then
    ``ToTensor``, then any transforms given in ``extra`` (in order).
    """
    steps = []
    if size is not None:
        steps.append(_resize(size))
        steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    steps.extend(extra)
    return dset.ImageFolder(root=root, transform=transforms.Compose(steps))


def load_dataset(dataset_name, split='full'):
    """Return the dataset object registered under ``dataset_name``.

    Supported names and what is built for each:

    - ``'mnist'``: ``dset.MNIST`` under ``data/mnist`` (downloaded if needed),
      converted with ``ToTensor``.
    - ``'coco'``, ``'shoes'``, ``'footwear'``: ``ImageFolder`` resized and
      center-cropped to 64.
    - ``'celeba'``, ``'birds'``: ``ImageFolder`` resized and center-cropped
      to 32.
    - ``'sketchy'``: like the 64-pixel folders, followed by ``Gray``.
    - ``'fonts'``: ``ImageFolder`` with ``ToTensor``, ``Invert``, ``Gray``.
    - ``'quickdraw'``: ``data/quickdraw/teapot.npy`` reshaped to
      ``(N, 28, 28)``, divided by 255, cast to float32 and wrapped in a
      ``TensorDataset`` where each sample is paired with itself.

    Args:
        dataset_name: One of the names listed above.
        split: Sub-directory appended to the root for ``'birds'``,
            ``'sketchy'`` and ``'fonts'``; ignored by the other datasets.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == 'mnist':
        return dset.MNIST(
            root='data/mnist',
            download=True,
            transform=transforms.Compose([transforms.ToTensor()]),
        )
    if dataset_name == 'quickdraw':
        # Imported locally: the module header imports neither name, so the
        # original code raised NameError as soon as this branch was taken.
        import numpy as np
        from torch.utils.data import TensorDataset

        X = np.load('data/quickdraw/teapot.npy')
        X = X.reshape((X.shape[0], 28, 28))
        X = (X / 255.).astype(np.float32)
        X = torch.from_numpy(X)
        # Input and target are the same tensor.
        return TensorDataset(X, X)
    if dataset_name == 'coco':
        return _image_folder('data/coco', size=64)
    if dataset_name == 'shoes':
        return _image_folder('data/shoes/ut-zap50k-images/Shoes', size=64)
    if dataset_name == 'footwear':
        return _image_folder('data/shoes/ut-zap50k-images', size=64)
    if dataset_name == 'celeba':
        return _image_folder('data/celeba', size=32)
    if dataset_name == 'birds':
        return _image_folder('data/birds/' + split, size=32)
    if dataset_name == 'sketchy':
        return _image_folder('data/sketchy/' + split, size=64, extra=[Gray()])
    if dataset_name == 'fonts':
        # No resize/crop for fonts: images are used at their stored size.
        return _image_folder('data/fonts/' + split, extra=[Invert(), Gray()])
    raise ValueError('Error : unknown dataset')