|
from importlib import import_module
|
|
|
|
from torch.utils.data import dataloader
|
|
from torch.utils.data import ConcatDataset
|
|
|
|
|
|
class MyConcatDataset(ConcatDataset):
    """ConcatDataset that exposes a ``train`` flag and forwards ``set_scale``.

    Args:
        datasets: iterable of datasets to concatenate. Each member is expected
            to have a ``train`` attribute; the first member's value becomes
            ``self.train``.
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        # ConcatDataset materializes ``datasets`` into the list
        # ``self.datasets``; index that list rather than the raw argument so
        # one-shot iterables (e.g. generators) are accepted too.
        self.train = self.datasets[0].train

    def set_scale(self, idx_scale):
        """Call ``set_scale(idx_scale)`` on every member dataset that defines it.

        Members without a ``set_scale`` method are skipped silently.
        """
        for d in self.datasets:
            if hasattr(d, 'set_scale'):
                d.set_scale(idx_scale)
|
|
|
|
class Data:
|
|
def __init__(self, args):
|
|
self.loader_train = None
|
|
if not args.test_only:
|
|
datasets = []
|
|
for d in args.data_train:
|
|
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
|
|
m = import_module('data.' + module_name.lower())
|
|
datasets.append(getattr(m, module_name)(args, name=d))
|
|
|
|
self.loader_train = dataloader.DataLoader(
|
|
MyConcatDataset(datasets),
|
|
batch_size=args.batch_size,
|
|
shuffle=True,
|
|
pin_memory=not args.cpu,
|
|
num_workers=args.n_threads,
|
|
)
|
|
|
|
self.loader_test = []
|
|
for d in args.data_test:
|
|
if d in ['CBSD68','Kodak24','McM','Set5', 'Set14', 'B100', 'Urban100']:
|
|
m = import_module('data.benchmark')
|
|
testset = getattr(m, 'Benchmark')(args, train=False, name=d)
|
|
else:
|
|
module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
|
|
m = import_module('data.' + module_name.lower())
|
|
testset = getattr(m, module_name)(args, train=False, name=d)
|
|
|
|
self.loader_test.append(
|
|
dataloader.DataLoader(
|
|
testset,
|
|
batch_size=1,
|
|
shuffle=False,
|
|
pin_memory=not args.cpu,
|
|
num_workers=args.n_threads,
|
|
)
|
|
)
|
|
|