import os | |
from data import common | |
import numpy as np | |
import imageio | |
import torch | |
import torch.utils.data as data | |
class Demo(data.Dataset):
    """Dataset over the image files found directly inside ``args.dir_demo``.

    Each item is produced from a single image file; no paired target image
    is loaded, and the constant ``-1`` is returned in its place.

    Attributes set by the constructor:
        args: The options object passed in; ``dir_demo``, ``scale``,
            ``n_colors`` and ``rgb_range`` are read from it.
        name: Label for this dataset (``'Demo'`` by default).
        scale: Copied from ``args.scale``.
        idx_scale: Currently selected scale index; starts at 0 and is
            changed through :meth:`set_scale`.
        train: Always ``False`` (see the note in ``__init__``).
        benchmark: Stored as given.
        filelist: Sorted list of matching file paths.
    """

    def __init__(self, args, name='Demo', train=False, benchmark=False):
        """Collect the image paths in ``args.dir_demo``.

        Args:
            args: Options object providing ``dir_demo`` and ``scale`` here,
                plus ``n_colors`` and ``rgb_range`` for ``__getitem__``.
            name: Label stored on the instance.
            train: Accepted for signature compatibility but not consulted;
                ``self.train`` is unconditionally ``False``.
            benchmark: Stored on the instance unchanged.

        Raises:
            OSError: Propagated from ``os.listdir`` when ``args.dir_demo``
                cannot be listed.
        """
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        # NOTE(review): the ``train`` argument is ignored on purpose or by
        # oversight -- confirm; kept as-is so callers see the same state.
        self.train = False
        self.benchmark = benchmark

        demo_dir = args.dir_demo
        # Substring (not suffix) and case-sensitive match, exactly as the
        # original filter: '.jp' covers '.jpg'/'.jpeg', but a name such as
        # 'x.png.bak' is also accepted and 'X.PNG' is not.
        self.filelist = sorted(
            os.path.join(demo_dir, entry)
            for entry in os.listdir(demo_dir)
            if '.png' in entry or '.jp' in entry
        )

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return ``(tensor, -1, stem)``.

        Args:
            idx: Index into ``self.filelist``.

        Returns:
            A 3-tuple of the converted image, the constant ``-1``, and the
            file's base name without its extension.
        """
        path = self.filelist[idx]
        stem = os.path.splitext(os.path.basename(path))[0]

        image = imageio.imread(path)
        # Both helpers return a one-element sequence for a single input,
        # hence the single-target unpacking.
        # NOTE(review): presumably set_channel adjusts the channel count and
        # np2Tensor converts to a torch tensor scaled to rgb_range -- verify
        # against data/common.
        (image,) = common.set_channel(image, n_channels=self.args.n_colors)
        (tensor,) = common.np2Tensor(image, rgb_range=self.args.rgb_range)

        return tensor, -1, stem

    def __len__(self):
        """Return the number of image files collected at construction."""
        return len(self.filelist)

    def set_scale(self, idx_scale):
        """Record ``idx_scale`` as the currently selected scale index."""
        self.idx_scale = idx_scale