# goldpulpy's picture
# Upload model and app
# f9e4a6c
import torch
import os
from PIL import Image
import random
import numpy as np
import pickle
import torchvision.transforms as transforms
from .celeba import CelebADataset
def create_dataloader(opt):
    """Build a ``DataLoader`` wrapper and initialize it from ``opt``.

    Args:
        opt: Options object, forwarded unchanged to ``DataLoader.initialize``.

    Returns:
        The initialized ``DataLoader`` instance.
    """
    data_loader = DataLoader()
    data_loader.initialize(opt)
    return data_loader
class DataLoader:
    """Batch iterator over a dataset selected from ``opt.data_root``.

    ``initialize`` must be called before any other method: it stores the
    options, builds the dataset and wraps it in a
    ``torch.utils.data.DataLoader``. Iteration and ``len()`` are both capped
    by ``opt.max_dataset_size``.
    """

    def name(self):
        """Return the wrapped dataset's name with a ``_Loader`` suffix."""
        return self.dataset.name() + "_Loader"

    def create_datase(self):
        """Instantiate and initialize the dataset implied by ``opt.data_root``.

        The method name is misspelled ("datase"); it is kept as-is because
        ``initialize`` and possibly external callers refer to it by this name.

        Returns:
            The dataset object after its ``initialize(opt)`` has been called.
        """
        # specify which dataset to load here
        # The choice is keyed on the lower-cased last component of data_root;
        # slashes are stripped first so a trailing '/' does not yield ''.
        loaded_dataset = os.path.basename(self.opt.data_root.strip('/')).lower()
        if 'celeba' in loaded_dataset or 'emotion' in loaded_dataset:
            dataset = CelebADataset()
        else:
            # NOTE(review): BaseDataset is neither imported nor defined in the
            # visible part of this module -- confirm it is in scope, otherwise
            # this branch raises NameError.
            dataset = BaseDataset()
        dataset.initialize(self.opt)
        return dataset

    def initialize(self, opt):
        """Store ``opt``, build the dataset and create the torch data loader.

        Args:
            opt: Options object. This method reads ``batch_size``,
                ``serial_batches`` and ``n_threads``; ``create_datase`` reads
                ``data_root``; ``__len__`` and ``__iter__`` later read
                ``max_dataset_size``.
        """
        self.opt = opt
        self.dataset = self.create_datase()
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            # serial_batches disables shuffling.
            shuffle=not opt.serial_batches,
            num_workers=int(opt.n_threads),
        )

    def __len__(self):
        """Return the dataset size, capped at ``opt.max_dataset_size``."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches until ``opt.max_dataset_size`` samples are covered.

        The cap is checked once per batch, against the number of samples
        already yielded (``i * batch_size``), so the last yielded batch may
        run past the cap. Each batch is fetched from the underlying loader
        before the check is made.
        """
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data