Spaces:
Runtime error
Runtime error
import random | |
from swapae.data.base_dataset import BaseDataset, get_transform | |
from swapae.data.image_folder import make_dataset | |
from PIL import Image | |
class ImageFolderDataset(BaseDataset):
    """Dataset that serves the images listed by ``make_dataset(opt.dataroot)``.

    Every item is a dict with two keys:

    * ``'real_A'`` -- the result of applying ``transform_A`` to the image
      after it has been converted to RGB.
    * ``'path_A'`` -- the path the image was actually loaded from.
    """

    # Upper bound on how many files are tried for a single item before the
    # read error is propagated instead of being replaced by another image.
    max_load_attempts = 100

    def __init__(self, opt):
        """Collect the sorted image paths and build the image transform.

        Args:
            opt: options object; ``opt.dataroot`` is the location handed to
                ``make_dataset``. The transform is built from ``self.opt``
                (presumably stored by ``BaseDataset.__init__`` -- confirm).
        """
        BaseDataset.__init__(self, opt)
        self.dir_A = opt.dataroot
        self.A_paths = sorted(make_dataset(self.dir_A))
        self.A_size = len(self.A_paths)
        self.transform_A = get_transform(self.opt, grayscale=False)

    def __getitem__(self, index):
        """Return the item for ``index``; indices wrap modulo the dataset size."""
        A_path = self.A_paths[index % self.A_size]
        return self.getitem_by_path(A_path)

    def getitem_by_path(self, A_path):
        """Load, convert and transform the image stored at ``A_path``.

        If the file cannot be read, the error is printed and a randomly
        chosen image from this dataset is tried instead, so one bad file
        does not abort iteration. The substitution is visible to the caller
        through the returned ``'path_A'``.

        Args:
            A_path: path of the image to load.

        Returns:
            dict: ``{'real_A': transformed image, 'path_A': path loaded}``.

        Raises:
            OSError: if ``max_load_attempts`` consecutive files could not be
                read; chained to the last underlying error.
        """
        last_err = None
        # Iterate rather than recurse: a dataset with many (or only)
        # unreadable files must fail with the real I/O error, not with a
        # RecursionError from an unbounded chain of fallback calls.
        for _ in range(self.max_load_attempts):
            try:
                # convert() returns a new, fully loaded image, so the file
                # handle can be released as soon as the conversion is done.
                with Image.open(A_path) as raw_img:
                    A_img = raw_img.convert('RGB')
            except OSError as err:
                print(err)
                last_err = err
                fallback_index = random.randint(0, len(self) - 1)
                A_path = self.A_paths[fallback_index % self.A_size]
                continue
            # apply image transformation
            A = self.transform_A(A_img)
            return {'real_A': A, 'path_A': A_path}
        raise OSError(
            'could not read an image after %d attempts (last tried: %s)'
            % (self.max_load_attempts, A_path)
        ) from last_err

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return self.A_size