# CycleGAN / data/one_dataset.py
# Author: Yanguan (revision 58da73e)
from .base_dataset import BaseDataset, get_transform
from PIL import Image
from pathlib import Path
class OneDataset(BaseDataset):
    """Dataset over a single image, one image file, or a folder of images.

    Loads either every image in a folder or one explicitly given input.
    ``img`` may be:

    * a ``str``/``Path`` naming an image file -> that one image is loaded;
    * a ``str``/``Path`` naming a directory -> every regular file directly
      inside it is loaded (sorted by name, non-recursive);
    * anything else -> treated as an already-loaded PIL image and used as-is.

    Each item is a dict ``{"A": transformed image, "A_paths": source path or
    None}``.
    """

    def __init__(self, img, opt):
        """Load the image(s) described by ``img``.

        Parameters:
            img -- image file path, directory path, or an in-memory PIL image
            opt -- option object forwarded to BaseDataset and to the transform

        Raises:
            FileNotFoundError -- if ``img`` is a path that is neither an
                existing file nor an existing directory.
        """
        BaseDataset.__init__(self, opt)
        if isinstance(img, (str, Path)):
            root = Path(img)
            if root.is_file():
                paths = [root]
            elif root.is_dir():
                # Sort for a deterministic order (iterdir order is arbitrary)
                # and skip sub-directories, which Image.open cannot read.
                paths = sorted(p for p in root.iterdir() if p.is_file())
            else:
                # Previously this left A_path/A_img unset and failed later
                # with an unrelated AttributeError.
                raise FileNotFoundError(f"No such image file or directory: {root}")
            self.A_path = [str(p) for p in paths]
            self.A_img = [self._load_rgb(p) for p in self.A_path]
        else:  # ``img`` is already an in-memory (PIL) image
            self.A_path = [None]
            self.A_img = [img]

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` fully into memory as an RGB image and close the file."""
        # convert() loads the pixel data and returns a new image, so the
        # source file handle can be released immediately instead of at GC.
        with Image.open(path) as im:
            return im.convert("RGB")

    def __getitem__(self, idx: int):
        """Return ``{"A": transformed image, "A_paths": its path or None}``."""
        A_path = self.A_path[idx]
        A_img = self.A_img[idx]
        A = transform(A_img, self.opt)
        return {"A": A, "A_paths": A_path}

    def __len__(self):
        """Return the number of loaded images (previously hard-coded to 1)."""
        return len(self.A_img)
def transform(img, opt):
    """Run ``img`` through the preprocessing pipeline from ``get_transform``.

    The pipeline is requested in grayscale mode exactly when
    ``opt.input_nc == 1``.

    Parameters:
        img -- the image to preprocess (as stored by OneDataset)
        opt -- option object passed to get_transform; ``input_nc`` selects
               grayscale mode

    Returns:
        Whatever the pipeline returns for ``img``.
    """
    wants_gray = opt.input_nc == 1
    pipeline = get_transform(opt, grayscale=wants_gray)
    return pipeline(img)