"""Testing base class.

``TesterBase`` builds a data loader for one of the supported datasets and
drives a subclass-provided model over it, batch by batch.
"""
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch
import numpy as np
import math

from . import flow_transforms


class TesterBase:
    """Base class for dataset-driven testers.

    Subclasses implement :meth:`define_model`, :meth:`forward` and
    :meth:`display` (and optionally :meth:`init_constant`), then call
    :meth:`init_testing` followed by :meth:`test`.
    """

    def __init__(self, args):
        """Store the run configuration and set up shared constants.

        Args:
            args: Configuration namespace. Attributes read by this class
                include ``device``, ``dataset``, ``batch_size``, ``workers``,
                ``niteration`` and the dataset-specific options used in
                :meth:`init_dataset`.
        """
        cudnn.benchmark = True
        # Same per-channel mean that the BSD500 input transforms subtract,
        # shaped (1, 3, 1, 1) so it broadcasts over a 4-D NCHW tensor.
        # NOTE(review): presumably used by subclasses to undo normalization.
        self.mean_values = (
            torch.tensor([0.411, 0.432, 0.45]).view(1, 3, 1, 1).to(args.device)
        )
        self.args = args

    def init_dataset(self):
        """Create ``self.train_loader`` for the dataset named by ``args.dataset``.

        ``'BSD500'`` and ``'texture'`` select those datasets; any other value
        falls through to a basicsr paired (LR/HR) dataset.
        """
        if self.args.dataset == 'BSD500':
            self._init_bsd500_loader()
        elif self.args.dataset == 'texture':
            self._init_texture_loader()
        else:
            self._init_paired_sr_loader()

    def _init_bsd500_loader(self):
        """Build ``self.train_loader`` over the BSD500 training split."""
        from ..data import BSD500

        # ========== Data loading code ==============
        def make_input_transform():
            # Scale pixels from [0, 255] to [0, 1], then subtract the
            # per-channel mean (std of 1 leaves the scale unchanged).
            return transforms.Compose([
                flow_transforms.ArrayToTensor(),
                transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
                transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]),
            ])

        dataset_kwargs = {
            'transform': make_input_transform(),
            'val_transform': make_input_transform(),
            'target_transform': transforms.Compose([
                flow_transforms.ArrayToTensor(),
            ]),
        }
        if self.args.crop_img != 0:
            # Joint (input + target) centre crop to the configured size.
            dataset_kwargs['co_transform'] = flow_transforms.Compose([
                flow_transforms.CenterCrop(
                    (self.args.train_img_height, self.args.train_img_width)),
            ])

        print("=> loading img pairs from '{}'".format(self.args.data))
        train_set, val_set = BSD500(self.args.data, **dataset_kwargs)
        print('{} samples found, {} train samples and {} val samples '.format(
            len(val_set) + len(train_set), len(train_set), len(val_set)))

        self.train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=self.args.batch_size,
            num_workers=self.args.workers,
            pin_memory=True,
            shuffle=False,
            drop_last=True,
        )

    def _init_texture_loader(self):
        """Build ``self.train_loader`` over the texture dataset in test mode."""
        from ..data.texture_v3 import Dataset

        dataset = Dataset(self.args.data_path,
                          crop_size=self.args.train_img_height,
                          test=True)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.workers,
            shuffle=False,
            drop_last=True,
        )

    def _init_paired_sr_loader(self):
        """Build ``self.train_loader`` from a basicsr ``PairedImageDataset``."""
        from basicsr.data import create_dataloader, create_dataset

        # NOTE(review): phase 'train' together with use_flip/use_rot/gt_size
        # presumably makes basicsr apply random crops and augmentation even
        # though this is a tester -- confirm this is intended.
        opt = {
            'dist': False,
            'phase': 'train',
            'name': 'DIV2K',
            'type': 'PairedImageDataset',
            'dataroot_gt': self.args.HR_dir,
            'dataroot_lq': self.args.LR_dir,
            'filename_tmpl': '{}',
            'io_backend': dict(type='disk'),
            'gt_size': self.args.train_img_height,
            'use_flip': True,
            'use_rot': True,
            'use_shuffle': True,
            'num_worker_per_gpu': self.args.workers,
            'batch_size_per_gpu': self.args.batch_size,
            'scale': int(self.args.ratio),
            'dataset_enlarge_ratio': 1,
        }
        dataset = create_dataset(opt)
        self.train_loader = create_dataloader(
            dataset, opt, num_gpu=1, dist=opt['dist'], sampler=None)

    def init_testing(self):
        """Prepare constants, the data loader and the model, in that order."""
        self.init_constant()
        self.init_dataset()
        self.define_model()

    def init_constant(self):
        """Hook for subclasses to set up constants; no-op by default."""
        return

    def define_model(self):
        """Build the model under test; must be implemented by subclasses."""
        raise NotImplementedError

    def display(self, iteration=None):
        """Report results for batch ``iteration``; called as ``display(iteration)``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def forward(self, iteration=None):
        """Process the batch that :meth:`test` stored on ``self``.

        :meth:`test` calls this with no arguments; ``iteration`` is optional
        only so the previous stub signature keeps working. Must be
        implemented by subclasses.
        """
        raise NotImplementedError

    def test(self):
        """Iterate ``self.train_loader``, calling :meth:`forward` and :meth:`display`.

        For each batch, inputs are moved to the GPU with ``.cuda()`` and
        stored on ``self``:

        * BSD500: ``image``, ``label``, and ``gt`` set to ``None``.
        * texture: ``image`` and ``image2``.
        * otherwise: ``image`` (``'lq'``) and ``gt`` (``'gt'``).
        """
        args = self.args
        for iteration, data in enumerate(self.train_loader):
            print("Iteration: {}.".format(iteration))
            if args.dataset == 'BSD500':
                image = data[0].cuda()
                self.label = data[1].cuda()
                self.gt = None
            elif args.dataset == 'texture':
                image = data[0].cuda()
                self.image2 = data[1].cuda()
            else:
                image = data['lq'].cuda()
                self.gt = data['gt'].cuda()
            self.image = image
            self.forward()
            self.display(iteration)
            # NOTE(review): the stop check runs after processing, so batch
            # indices 0 .. niteration + 1 are all handled (niteration + 2
            # batches) -- confirm whether that is intended.
            if iteration > args.niteration:
                break