|
import os
|
|
import random
|
|
import imageio
|
|
import numpy as np
|
|
import torch.utils.data as data
|
|
|
|
from data import common
|
|
|
|
from utils import interact
|
|
|
|
class Dataset(data.Dataset):
    """Basic dataloader class for paired blur/sharp image folders.

    The data directory is ``<args.data_root>/<dataset>/<mode>``. Here
    ``dataset`` is ``args.data_train`` / ``args.data_val`` /
    ``args.data_test`` depending on ``mode``. In demo mode the directory is
    ``args.demo_input_dir`` instead.

    Every leaf directory under it whose path contains the blur key (or sharp
    key) contributes its files to ``blur_list`` (or ``sharp_list``). Both
    lists are sorted and paired by index.

    Subclasses customise the accepted modes and folder keys by overriding
    :meth:`set_modes` and :meth:`set_keys`. The base ``__init__`` calls both.
    """

    def __init__(self, args, mode='train'):
        """Validate ``mode``, resolve the data directory and scan it.

        Args:
            args: configuration namespace. Reads ``data_root``,
                ``data_train``/``data_val``/``data_test`` (per mode) or
                ``demo_input_dir`` (demo mode). ``__getitem__`` reads further
                fields.
            mode: one of ``self.modes``. The defaults are 'train', 'val',
                'test' and 'demo'.

        Raises:
            NotImplementedError: if ``mode`` is not supported.
        """
        super().__init__()

        self.args = args
        self.mode = mode

        self.modes = ()
        self.set_modes()
        self._check_mode()

        self.set_keys()

        if self.mode == 'train':
            dataset = args.data_train
        elif self.mode == 'val':
            dataset = args.data_val
        elif self.mode == 'test':
            dataset = args.data_test
        elif self.mode == 'demo':
            dataset = None  # demo mode reads args.demo_input_dir below
        else:
            # Only reachable when a subclass adds a mode in set_modes()
            # without telling __init__ where that mode's data lives.
            raise NotImplementedError('not implemented for this mode: {}!'.format(self.mode))

        if self.mode == 'demo':
            self.subset_root = args.demo_input_dir
        else:
            self.subset_root = os.path.join(args.data_root, dataset, self.mode)

        self.blur_list = []
        self.sharp_list = []

        self._scan()

    def set_modes(self):
        """Set ``self.modes``, the tuple of accepted ``mode`` values."""
        self.modes = ('train', 'val', 'test', 'demo')

    def _check_mode(self):
        """Raise if ``self.mode`` is not one of ``self.modes``.

        Called from ``__init__`` right after :meth:`set_modes`.

        Raises:
            NotImplementedError: for an unsupported mode.
        """
        if self.mode not in self.modes:
            raise NotImplementedError('mode error: not for {}'.format(self.mode))

        return

    def set_keys(self):
        """Set the directory-name keys used by :meth:`_scan`.

        A leaf directory is a blur (sharp) folder when its path contains
        ``blur_key`` (``sharp_key``) and none of ``non_blur_keys``
        (``non_sharp_keys``).
        """
        self.blur_key = 'blur'
        self.sharp_key = 'sharp'

        self.non_blur_keys = []
        self.non_sharp_keys = []

        return

    def _scan(self, root=None):
        """Populate ``self.blur_list`` and ``self.sharp_list`` from ``root``.

        Called from ``__init__``. It may be called again (e.g. after a
        subclass changes its keys). Key normalisation is idempotent.

        Side effect: every key attribute is rewritten to end with a path
        separator, and each exclusion list is replaced by a new list.

        Args:
            root: directory to walk. Defaults to ``self.subset_root``.

        Raises:
            ValueError: if sharp images were found but their count differs
                from the number of blur images. The lists are paired by
                index, so a mismatch would misalign every later pair.
        """
        if root is None:
            root = self.subset_root

        def _as_dir_key(key):
            # 'blur' -> 'blur/', so the key matches a whole directory name:
            # '.../blur/' matches, '.../blur_gamma/' does not.
            return os.path.join(key, '')

        # Normalise first, then drop the true key from its own exclusion list.
        # Comparing after normalisation, and removing every occurrence,
        # keeps 'blur' vs 'blur/' or duplicate entries from excluding
        # every directory.
        self.blur_key = _as_dir_key(self.blur_key)
        self.sharp_key = _as_dir_key(self.sharp_key)
        self.non_blur_keys = [
            key for key in map(_as_dir_key, self.non_blur_keys) if key != self.blur_key
        ]
        self.non_sharp_keys = [
            key for key in map(_as_dir_key, self.non_sharp_keys) if key != self.sharp_key
        ]

        def _key_check(path, true_key, false_keys):
            # Separator-terminate the path so directory keys can match its
            # last component as well as intermediate ones.
            path = os.path.join(path, '')
            if true_key not in path:
                return False
            return not any(false_key in path for false_key in false_keys)

        def _get_list_by_key(true_key, false_keys):
            data_list = []
            for sub, dirs, files in os.walk(root):
                # Only leaf directories (no sub-directories) are collected.
                if not dirs and _key_check(sub, true_key, false_keys):
                    data_list.extend(os.path.join(sub, f) for f in files)
            data_list.sort()
            return data_list

        self.blur_list = _get_list_by_key(self.blur_key, self.non_blur_keys)
        self.sharp_list = _get_list_by_key(self.sharp_key, self.non_sharp_keys)

        # An explicit check rather than `assert`, so it still runs under `python -O`.
        if self.sharp_list and len(self.blur_list) != len(self.sharp_list):
            raise ValueError(
                'blur/sharp image count mismatch under {}: {} blur vs {} sharp'.format(
                    root, len(self.blur_list), len(self.sharp_list)))

        return

    def __getitem__(self, idx):
        """Load, preprocess and convert the sample at position ``idx``.

        Returns:
            tuple ``(blur, sharp, pad_width, idx, relpath)``, where:

            - ``sharp`` is ``False`` when no sharp images were found.
            - ``pad_width`` is 0 outside demo mode. In demo mode it is the
              value returned by ``common.pad``.
            - ``relpath`` is the blur file's path relative to
              ``self.subset_root``.
        """
        blur = imageio.imread(self.blur_list[idx], pilmode='RGB')
        if self.sharp_list:
            sharp = imageio.imread(self.sharp_list[idx], pilmode='RGB')
            imgs = [blur, sharp]
        else:
            imgs = [blur]

        pad_width = 0
        if self.mode == 'train':
            imgs = common.crop(*imgs, ps=self.args.patch_size)
            if self.args.augment:
                imgs = common.augment(*imgs, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=self.args.rgb_range)
                # Noise is applied to imgs[0] only (the network input).
                imgs[0] = common.add_noise(imgs[0], sigma_sigma=2, rgb_range=self.args.rgb_range)
        elif self.mode == 'demo':
            # Pad so the image size suits the n_scales-level pyramid;
            # pad_width is returned so the caller can undo it.
            imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1))

        if self.args.gaussian_pyramid:
            imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)

        imgs = common.np2tensor(*imgs, rgb_range=self.args.rgb_range)
        relpath = os.path.relpath(self.blur_list[idx], self.subset_root)

        blur = imgs[0]
        sharp = imgs[1] if len(imgs) > 1 else False

        return blur, sharp, pad_width, idx, relpath

    def __len__(self):
        """Return the number of blur images (one sample per blur image)."""
        return len(self.blur_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|