|
from .base_dataset import BaseDataset, get_transform |
|
from PIL import Image |
|
from pathlib import Path |
|
|
|
|
|
class OneDataset(BaseDataset):
    """Dataset serving a single image, or every image in a folder.

    ``img`` may be:

    * a path (``str`` or ``pathlib.Path``) to an image file -> one sample;
    * a path to a directory -> one sample per image file found directly inside
      it (subdirectories are not searched), in sorted path order; only files
      whose suffix is in ``IMG_EXTENSIONS`` (case-insensitive) are used;
    * anything else -> stored as-is as the single sample's image (presumably
      an already-loaded PIL image -- TODO confirm against callers), with its
      path recorded as ``None``.

    File-based images are opened with PIL, converted to RGB and kept in memory.
    Each item is a dict ``{"A": <transformed image>, "A_paths": <path or None>}``.
    """

    # Lower-case file suffixes treated as images when scanning a directory.
    IMG_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".ppm", ".bmp",
        ".tif", ".tiff", ".gif", ".webp",
    )

    def __init__(self, img, opt):
        """Resolve ``img`` into image paths and eagerly load the images.

        Args:
            img: image file path, directory path, or an in-memory image.
            opt: options object forwarded to ``BaseDataset.__init__`` and later
                used by ``transform`` in ``__getitem__``.

        Raises:
            FileNotFoundError: if ``img`` is a path that is neither an existing
                file nor an existing directory, or a directory containing no
                files with a suffix in ``IMG_EXTENSIONS``.
        """
        super().__init__(opt)

        if isinstance(img, (str, Path)):
            root = Path(img)
            if root.is_file():
                self.A_path = [str(root)]
            elif root.is_dir():
                self.A_path = self._list_images(root)
                if not self.A_path:
                    raise FileNotFoundError(f"no image files found in directory: {root}")
            else:
                # Previously this left self.A_path unset and failed later with AttributeError.
                raise FileNotFoundError(f"image path does not exist: {root}")
            self.A_img = [self._load_rgb(path) for path in self.A_path]
        else:
            # Non-path input: used directly as the only image, with no source path.
            self.A_path = [None]
            self.A_img = [img]

    @classmethod
    def _list_images(cls, folder):
        """Return sorted string paths of the image files directly inside ``folder``."""
        return sorted(
            str(entry)
            for entry in folder.iterdir()
            if entry.is_file() and entry.suffix.lower() in cls.IMG_EXTENSIONS
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL and return an RGB copy, closing the file handle."""
        # convert() yields a new, fully loaded image, so it stays valid after the file closes.
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, idx: int):
        """Return the transformed image at ``idx`` and its source path (or ``None``).

        ``self.opt`` is presumably stored by ``BaseDataset.__init__`` -- confirm.
        """
        return {"A": transform(self.A_img[idx], self.opt), "A_paths": self.A_path[idx]}

    def __len__(self):
        """Return the number of loaded images (1 for a file or in-memory image)."""
        return len(self.A_img)
|
|
|
|
|
def transform(img, opt):
    """Run ``img`` through the preprocessing pipeline configured by ``opt``.

    The pipeline is rebuilt with ``get_transform`` on every call. A grayscale
    pipeline is requested exactly when ``opt.input_nc`` equals 1.

    Args:
        img: the image to preprocess (whatever ``get_transform``'s pipeline accepts).
        opt: options object; must expose ``input_nc``.

    Returns:
        Whatever the pipeline returns for ``img``.
    """
    single_channel = opt.input_nc == 1
    pipeline = get_transform(opt, grayscale=single_channel)
    return pipeline(img)
|
|