|
import os
|
|
from data import srdata
|
|
|
|
class DIV2K(srdata.SRData):
    """DIV2K dataset restricted to a 1-based, inclusive image index range.

    The range is taken from ``args.data_range``, a string of the form
    ``"<begin>-<end>/<begin>-<end>"`` (e.g. ``"1-800/801-810"``): the first
    range is used for the training split, the second for evaluation. When
    ``args.test_only`` is set and only one range is given, that single range
    is used for evaluation as well.
    """

    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        """Select this split's index range, then defer to ``SRData``.

        Args:
            args: option namespace. This class reads ``args.data_range`` and,
                for non-training splits, ``args.test_only``; the whole object
                is forwarded to ``SRData``.
            name: dataset name forwarded to ``SRData``.
            train: True for the training split (uses the first range).
            benchmark: forwarded unchanged to ``SRData``.

        Raises:
            ValueError: if the selected range entry is malformed, if an
                evaluation split needs a second range that is missing, or if
                the range does not satisfy ``1 <= begin <= end``.
        """
        specs = args.data_range.split('/')
        # Training always uses the first range; evaluation uses the second,
        # except in test-only runs that supplied a single range.
        if train or (args.test_only and len(specs) == 1):
            index = 0
        else:
            index = 1
        if index >= len(specs):
            raise ValueError(
                'data_range {!r} has no evaluation range; expected '
                '"<begin>-<end>/<begin>-<end>"'.format(args.data_range)
            )

        # Must be set before the base constructor runs: _scan() reads them,
        # and SRData.__init__ presumably calls _scan() (TODO confirm).
        self.begin, self.end = self._parse_range(specs[index])

        super(DIV2K, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    @staticmethod
    def _parse_range(spec):
        """Parse ``"<begin>-<end>"`` into a 1-based inclusive ``(begin, end)`` pair.

        Raises:
            ValueError: if ``spec`` is not two '-'-separated integers, or if
                ``begin < 1`` or ``end < begin``.
        """
        bounds = spec.split('-')
        if len(bounds) != 2:
            raise ValueError(
                'invalid data_range entry {!r}: expected "<begin>-<end>"'.format(spec)
            )
        begin, end = (int(bound) for bound in bounds)
        # begin == 0 would make _scan slice from -1 (the end of the list),
        # silently selecting the wrong images; reject it explicitly.
        if begin < 1 or end < begin:
            raise ValueError(
                'invalid data_range entry {!r}: need 1 <= begin <= end'.format(spec)
            )
        return begin, end

    def _scan(self):
        """Restrict the base-class file lists to images ``begin``..``end``.

        Returns:
            ``(names_hr, names_lr)``: the HR name list sliced to the selected
            range, and ``names_lr`` with each inner list sliced the same way
            (presumably one list per scale -- confirm against SRData._scan).
        """
        names_hr, names_lr = super(DIV2K, self)._scan()

        # begin/end are 1-based and inclusive, hence the begin - 1 start.
        window = slice(self.begin - 1, self.end)
        names_hr = names_hr[window]
        names_lr = [n[window] for n in names_lr]

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        """Point the HR/LR directories at the DIV2K training folders.

        After the base implementation runs (it presumably sets
        ``self.apath``), ``dir_hr`` and ``dir_lr`` are overridden. The choice
        does not depend on the split, so evaluation images are read from the
        same ``DIV2K_train_*`` folders -- NOTE(review): confirm that the
        on-disk layout contains the evaluation indices there.
        """
        super(DIV2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
        if self.input_large:
            # input_large selects the sibling folder with an 'L' suffix
            # (DIV2K_train_LR_bicubicL).
            self.dir_lr += 'L'
|
|
|
|
|