|
from data.dataset import Dataset |
|
|
|
from utils import interact |
|
|
|
class REDS(Dataset): |
|
"""REDS train, val, test subset class |
|
""" |
|
def __init__(self, args, mode='train'): |
|
super(REDS, self).__init__(args, mode) |
|
|
|
def set_modes(self): |
|
self.modes = ('train', 'val', 'test') |
|
|
|
def set_keys(self): |
|
super(REDS, self).set_keys() |
|
|
|
|
|
|
|
self.non_blur_keys = ['blur', 'blur_comp', 'blur_bicubic'] |
|
self.non_blur_keys.remove(self.blur_key) |
|
self.non_sharp_keys = ['sharp_bicubic', 'sharp'] |
|
self.non_sharp_keys.remove(self.sharp_key) |
|
|
|
def __getitem__(self, idx): |
|
blur, sharp, pad_width, idx, relpath = super(REDS, self).__getitem__(idx) |
|
relpath = relpath.replace('{}/{}/'.format(self.mode, self.blur_key), '') |
|
|
|
return blur, sharp, pad_width, idx, relpath |
|
|