File size: 981 Bytes
ed697ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils
class ImagesDataset(Dataset):
    """Paired image dataset that yields ``(source, target)`` samples.

    Source and target image paths are collected with
    ``data_utils.make_dataset`` and sorted, so the i-th source path is paired
    with the i-th target path by sorted order.

    Args:
        source_root: Root passed to ``data_utils.make_dataset`` to collect
            the source image paths.
        target_root: Root passed to ``data_utils.make_dataset`` to collect
            the target image paths.
        opts: Options object; within this class only ``opts.label_nc`` is
            read. When it equals ``0`` the source image is converted to
            ``'RGB'``, otherwise to single-channel ``'L'`` mode.
        target_transform: Optional callable applied to the target image.
        source_transform: Optional callable applied to the source image.
            When it is ``None``, the returned source is the (transformed)
            target image instead of the loaded source image.
    """

    def __init__(self, source_root, target_root, opts,
                 target_transform=None, source_transform=None):
        self.source_paths = sorted(data_utils.make_dataset(source_root))
        self.target_paths = sorted(data_utils.make_dataset(target_root))
        self.source_transform = source_transform
        self.target_transform = target_transform
        self.opts = opts
        # NOTE(review): samples are paired purely by sorted index and
        # __len__ counts source paths, so target_paths presumably must be at
        # least as long as source_paths (otherwise __getitem__ raises
        # IndexError) — confirm against callers.

    def __len__(self):
        """Return the number of source image paths."""
        return len(self.source_paths)

    def __getitem__(self, index):
        """Load the image pair at ``index``.

        Returns:
            A ``(from_im, to_im)`` tuple. ``to_im`` is the target image in
            RGB with ``target_transform`` applied (if set). ``from_im`` is the
            source image with ``source_transform`` applied, or ``to_im``
            itself when no source transform is configured.
        """
        from_path = self.source_paths[index]
        # Use context managers so the underlying file handle is released as
        # soon as convert() has produced an independent in-memory copy,
        # rather than lingering until garbage collection.
        with Image.open(from_path) as src_im:
            if self.opts.label_nc == 0:
                from_im = src_im.convert('RGB')
            else:
                from_im = src_im.convert('L')

        to_path = self.target_paths[index]
        with Image.open(to_path) as dst_im:
            to_im = dst_im.convert('RGB')

        if self.target_transform is not None:
            to_im = self.target_transform(to_im)

        if self.source_transform is not None:
            from_im = self.source_transform(from_im)
        else:
            # No source transform: the target doubles as the source and the
            # loaded source image is discarded.
            from_im = to_im

        return from_im, to_im
|