""" | |
Testing base class. | |
""" | |
import torchvision.transforms as transforms | |
import torch.backends.cudnn as cudnn | |
import torch.nn.functional as F | |
import torch | |
import numpy as np | |
import math | |
from . import flow_transforms | |
class TesterBase:
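    """Base class for testing loops.

    Subclasses implement define_model(), forward(), and display(); this class
    provides dataloader setup for BSD500, texture, and BasicSR paired-image
    data (e.g. DIV2K), plus the shared test loop.
    """
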
    def __init__(self, args):
        cudnn.benchmark = True
        # Per-channel RGB means, matching the normalization applied in init_dataset().
        self.mean_values = torch.tensor([0.411, 0.432, 0.45]).view(1, 3, 1, 1).to(args.device)
        self.args = args

    def init_dataset(self):
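        """Create self.train_loader for the dataset selected by self.args.dataset
        ('BSD500', 'texture', or a BasicSR PairedImageDataset such as DIV2K)."""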
        if self.args.dataset == 'BSD500':
            from ..data import BSD500

            # ========== Data loading code ==============
            input_transform = transforms.Compose([
                flow_transforms.ArrayToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
                transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
            ])
            val_input_transform = transforms.Compose([
                flow_transforms.ArrayToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
                transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
            ])
            target_transform = transforms.Compose([
                flow_transforms.ArrayToTensor(),
            ])
            co_transform = flow_transforms.Compose([
                flow_transforms.CenterCrop((self.args.train_img_height, self.args.train_img_width)),
            ])

            print("=> loading img pairs from '{}'".format(self.args.data))
            if self.args.crop_img == 0:
                train_set, val_set = BSD500(self.args.data,
                                            transform=input_transform,
                                            val_transform=val_input_transform,
                                            target_transform=target_transform)
            else:
                train_set, val_set = BSD500(self.args.data,
                                            transform=input_transform,
                                            val_transform=val_input_transform,
                                            target_transform=target_transform,
                                            co_transform=co_transform)
            print('{} samples found, {} train samples and {} val samples '.format(
                len(val_set) + len(train_set), len(train_set), len(val_set)))
            self.train_loader = torch.utils.data.DataLoader(
                train_set, batch_size=self.args.batch_size,
                num_workers=self.args.workers, pin_memory=True, shuffle=False, drop_last=True)
        elif self.args.dataset == 'texture':
            from ..data.texture_v3 import Dataset

            dataset = Dataset(self.args.data_path, crop_size=self.args.train_img_height, test=True)
            self.train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                            batch_size=self.args.batch_size,
                                                            num_workers=self.args.workers,
                                                            shuffle=False,
                                                            drop_last=True)
        else:
            from basicsr.data import create_dataloader, create_dataset

            opt = {
                'dist': False,
                'phase': 'train',
                'name': 'DIV2K',
                'type': 'PairedImageDataset',
                'dataroot_gt': self.args.HR_dir,
                'dataroot_lq': self.args.LR_dir,
                'filename_tmpl': '{}',
                'io_backend': dict(type='disk'),
                'gt_size': self.args.train_img_height,
                'use_flip': True,
                'use_rot': True,
                'use_shuffle': True,
                'num_worker_per_gpu': self.args.workers,
                'batch_size_per_gpu': self.args.batch_size,
                'scale': int(self.args.ratio),
                'dataset_enlarge_ratio': 1,
            }
            dataset = create_dataset(opt)
            self.train_loader = create_dataloader(
                dataset, opt, num_gpu=1, dist=opt['dist'], sampler=None)

    def init_testing(self):
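        """Run the full setup: constants, dataloader, then the model."""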
        self.init_constant()
        self.init_dataset()
        self.define_model()

    def init_constant(self):
        # Hook for subclasses that need extra constants; no-op by default.
        return

    def define_model(self):
        # Subclasses must build (and load) the model under test here.
        raise NotImplementedError

    def display(self, iteration):
        # Subclasses must save/visualize the results of the given iteration.
        raise NotImplementedError

    def forward(self):
        # Subclasses must run the model on the inputs set by test() (self.image, ...).
        raise NotImplementedError

    def test(self):
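        """Iterate over self.train_loader, running forward() and display(iteration)
        on each batch, and stop once the iteration index exceeds args.niteration."""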
        args = self.args
        for iteration, data in enumerate(self.train_loader):
            print("Iteration: {}.".format(iteration))
            if args.dataset == 'BSD500':
                # BSD500 batches are (image, label) tuples.
                image = data[0].cuda()
                self.label = data[1].cuda()
                self.gt = None
            elif args.dataset == 'texture':
                # Texture batches are (image, image2) pairs.
                image = data[0].cuda()
                self.image2 = data[1].cuda()
            else:
                # BasicSR paired-image batches are dicts with 'lq'/'gt' tensors.
                image = data['lq'].cuda()
                self.gt = data['gt'].cuda()
            self.image = image
            self.forward()
            self.display(iteration)
            if iteration > args.niteration:
                break
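

# Usage sketch (illustrative only): TesterBase is meant to be subclassed.
# "MyTester" and the model construction below are assumptions, not part of
# this module; `args` is expected to carry the fields read above
# (dataset, data/data_path, batch_size, workers, niteration, device, ...).
#
#     class MyTester(TesterBase):
#         def define_model(self):
#             self.model = build_my_network().to(self.args.device)  # hypothetical helper
#
#         def forward(self):
#             with torch.no_grad():
#                 self.output = self.model(self.image)
#
#         def display(self, iteration):
#             print(iteration, self.output.shape)  # or save visualizations
#
#     tester = MyTester(args)
#     tester.init_testing()   # init_constant() -> init_dataset() -> define_model()
#     tester.test()           # loops the dataloader, calling forward() and display()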