from .base_dataset import BaseDataset, get_transform
from PIL import Image
from pathlib import Path


class OneDataset(BaseDataset):
    """Dataset built from a single image, a single image file, or a folder.

    ``img`` may be:

    * a path (``str`` or ``pathlib.Path``) to one image file,
    * a path to a directory, in which case every regular file directly
      inside it is opened as an image, or
    * any other object, which is stored as-is and treated as an
      already-loaded image (the original code expects a PIL image here).

    Each item is a dict ``{"A": <transformed image>, "A_paths": <path or None>}``.
    """

    def __init__(self, img, opt):
        """Collect the image(s) described by ``img``.

        Args:
            img: Path to a file or directory, or an already-loaded image.
            opt: Options object forwarded to ``BaseDataset.__init__`` and
                later to ``get_transform`` (``opt.input_nc`` is read there).

        Raises:
            FileNotFoundError: If ``img`` is a path that is neither an
                existing file nor an existing directory.
        """
        BaseDataset.__init__(self, opt)
        # NOTE(review): __getitem__ reads self.opt, which is presumably set by
        # BaseDataset.__init__ -- confirm against base_dataset.

        if isinstance(img, (str, Path)):
            dataroot = Path(img)
            if dataroot.is_file():
                self.A_path = [str(dataroot)]
            elif dataroot.is_dir():
                # iterdir() yields entries in arbitrary order and includes
                # sub-directories; keep regular files only and sort so that
                # indices are stable between runs.
                self.A_path = sorted(
                    str(entry) for entry in dataroot.iterdir() if entry.is_file()
                )
            else:
                raise FileNotFoundError(
                    f"{dataroot} is neither an existing file nor a directory"
                )
            self.A_img = [self._load_rgb(path) for path in self.A_path]
        else:
            # img is already an in-memory image, so there is no path for it.
            self.A_path = [None]
            self.A_img = [img]

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL and return an RGB copy, closing the file."""
        with Image.open(path) as handle:
            # convert() forces the pixel data to load and returns a new image,
            # so the result stays valid after the file handle is closed.
            return handle.convert("RGB")

    def __getitem__(self, idx: int):
        """Return the transformed image at ``idx`` together with its path."""
        A_path = self.A_path[idx]
        A_img = self.A_img[idx]
        A = transform(A_img, self.opt)
        return {"A": A, "A_paths": A_path}

    def __len__(self):
        """Return the number of images held by the dataset."""
        # Previously hard-coded to 1, which hid all but the first image when a
        # directory was given. Single-file / single-image inputs still give 1.
        return len(self.A_img)


def transform(img, opt):
    """Build the transform described by ``opt`` and apply it to ``img``.

    The ``grayscale`` flag passed to ``get_transform`` is true when
    ``opt.input_nc == 1``.
    """
    fn_transform = get_transform(opt, grayscale=(opt.input_nc == 1))
    return fn_transform(img)