|
|
|
|
|
|
|
import os |
|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
|
|
from utils.data_utils import make_dataset |
|
|
|
|
|
class ImagesDataset(Dataset):
    """Dataset yielding ``(fname, image)`` pairs for the entries under a root.

    Entries are obtained from ``make_dataset(source_root)`` and stored in
    sorted order. Each image is opened with PIL, converted to RGB, and
    optionally passed through ``source_transform``.
    """

    def __init__(self, source_root, source_transform=None):
        """Enumerate the image entries and remember the optional transform.

        Args:
            source_root: Value forwarded to ``make_dataset`` to enumerate
                the images.
            source_transform: Optional callable applied to each loaded RGB
                image before it is returned; ``None`` disables it.
        """
        # NOTE(review): make_dataset is defined elsewhere; __getitem__
        # unpacks each entry as a (fname, path) pair -- confirm against
        # utils.data_utils.
        self.source_paths = sorted(make_dataset(source_root))
        self.source_transform = source_transform

    def __len__(self):
        """Return the number of image entries."""
        return len(self.source_paths)

    def __getitem__(self, index):
        """Load the entry at ``index``.

        Args:
            index: Position in the sorted entry list.

        Returns:
            A ``(fname, image)`` tuple, where ``image`` is the RGB image
            after ``source_transform`` (if one was given) has been applied.
        """
        fname, from_path = self.source_paths[index]

        # Open inside a context manager so the underlying file is released
        # even when decoding fails; convert() hands back a separate image,
        # so the result stays usable after the source handle is closed.
        with Image.open(from_path) as raw_im:
            from_im = raw_im.convert('RGB')

        # Compare against None explicitly: a callable transform that happens
        # to be falsy (e.g. defines __len__ returning 0) must still run.
        if self.source_transform is not None:
            from_im = self.source_transform(from_im)

        return fname, from_im
|
|
|
|