import torch
import torch.utils.data
import os
from PIL import Image
import random
import numpy as np
import pickle
import torchvision.transforms as transforms

from .celeba import CelebADataset
# BaseDataset is referenced below but was never imported in the original file;
# .base_dataset is an assumed module path for it.
from .base_dataset import BaseDataset


def create_dataloader(opt):
    """Build and initialize the dataset-aware DataLoader wrapper."""
    data_loader = DataLoader()
    data_loader.initialize(opt)
    return data_loader


class DataLoader:
    def name(self):
        return self.dataset.name() + "_Loader"

    def create_dataset(self):
        # Specify which dataset to load here, based on the name of the data_root directory.
        loaded_dataset = os.path.basename(self.opt.data_root.strip('/')).lower()
        if 'celeba' in loaded_dataset or 'emotion' in loaded_dataset:
            dataset = CelebADataset()
        else:
            dataset = BaseDataset()
        dataset.initialize(self.opt)
        return dataset

    def initialize(self, opt):
        self.opt = opt
        self.dataset = self.create_dataset()
        # serial_batches disables shuffling; n_threads sets the number of worker processes.
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.n_threads)
        )

    def __len__(self):
        # Report at most max_dataset_size samples.
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            # Stop once roughly max_dataset_size samples have been yielded.
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
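

# A minimal usage sketch, not part of the original module. It assumes an
# argparse-style options object exposing the attributes read above
# (data_root, batch_size, serial_batches, n_threads, max_dataset_size);
# the concrete values and the dataset path are illustrative only, and the
# dataset files must actually exist under data_root for this to run.
if __name__ == "__main__":
    from argparse import Namespace

    opt = Namespace(
        data_root="./datasets/celeba",   # hypothetical path; 'celeba' selects CelebADataset
        batch_size=16,
        serial_batches=False,            # False -> batches are shuffled
        n_threads=4,
        max_dataset_size=float("inf"),   # no cap on the number of samples
    )

    loader = create_dataloader(opt)
    print(loader.name(), len(loader))
    for batch in loader:
        # Each `batch` is whatever the underlying dataset's __getitem__ returns,
        # collated by torch.utils.data.DataLoader.
        break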