TextureScraping / libs /test_base.py
sunshineatnoon's picture
Update libs/test_base.py
8fb0c3a
raw history blame
No virus
5.34 kB
"""
Testing base class.
"""
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch
import numpy as np
import math
from . import flow_transforms
class TesterBase:
    """Base class for test-time drivers.

    Subclasses implement ``define_model``, ``forward`` and ``display`` (and may
    override ``init_constant``). ``init_testing`` builds constants, the data
    loader and the model; ``test`` then iterates over ``self.train_loader``,
    calling ``forward`` and ``display`` for each batch.
    """

    # Per-channel mean subtracted from inputs after they are divided by 255.
    # Shared by the BSD500 input transforms and ``self.mean_values``.
    _CHANNEL_MEAN = [0.411, 0.432, 0.45]

    def __init__(self, args):
        """Store the run configuration and precompute the channel-mean tensor.

        Args:
            args: configuration namespace. ``device`` is read here; the dataset
                and loader fields are read later by ``init_dataset`` / ``test``.
        """
        # Let cuDNN benchmark and cache the fastest convolution algorithms.
        cudnn.benchmark = True
        # Shape (1, 3, 1, 1) so it broadcasts against (N, 3, H, W) tensors.
        self.mean_values = torch.tensor(self._CHANNEL_MEAN).view(1, 3, 1, 1).to(args.device)
        self.args = args

    def init_dataset(self):
        """Create ``self.train_loader`` for the dataset named by ``args.dataset``.

        ``'BSD500'`` and ``'texture'`` use this project's datasets; any other
        value builds a basicsr ``PairedImageDataset`` from ``args.LR_dir`` /
        ``args.HR_dir``.
        """
        if self.args.dataset == 'BSD500':
            self._init_bsd500_loader()
        elif self.args.dataset == 'texture':
            self._init_texture_loader()
        else:
            self._init_paired_sr_loader()

    @classmethod
    def _make_input_transform(cls):
        """Return a new pipeline: ``flow_transforms.ArrayToTensor``, divide by 255, subtract the channel mean."""
        return transforms.Compose([
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=list(cls._CHANNEL_MEAN), std=[1, 1, 1]),
        ])

    def _init_bsd500_loader(self):
        """Build the BSD500 training loader (center-cropped when ``args.crop_img`` != 0)."""
        from ..data import BSD500
        args = self.args
        # Train and val get separate (identical) pipeline instances.
        dataset_kwargs = {
            'transform': self._make_input_transform(),
            'val_transform': self._make_input_transform(),
            'target_transform': transforms.Compose([
                flow_transforms.ArrayToTensor(),
            ]),
        }
        print("=> loading img pairs from '{}'".format(args.data))
        if args.crop_img != 0:
            dataset_kwargs['co_transform'] = flow_transforms.Compose([
                flow_transforms.CenterCrop((args.train_img_height, args.train_img_width)),
            ])
        train_set, val_set = BSD500(args.data, **dataset_kwargs)
        print('{} samples found, {} train samples and {} val samples '.format(
            len(val_set) + len(train_set), len(train_set), len(val_set)))
        # Only the training split is wrapped in a loader; ``test`` iterates it.
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size,
            num_workers=args.workers, pin_memory=True, shuffle=False, drop_last=True)

    def _init_texture_loader(self):
        """Build the loader over ``data.texture_v3.Dataset`` constructed with ``test=True``."""
        from ..data.texture_v3 import Dataset
        dataset = Dataset(self.args.data_path, crop_size=self.args.train_img_height, test=True)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.args.batch_size,
            num_workers=self.args.workers,
            shuffle=False,
            drop_last=True)

    def _init_paired_sr_loader(self):
        """Build a basicsr ``PairedImageDataset`` loader from LR/HR directories.

        NOTE(review): ``phase='train'`` with ``use_flip``/``use_rot``/``use_shuffle``
        enabled presumably makes basicsr augment and shuffle batches even at
        test time -- confirm this is intended.
        """
        from basicsr.data import create_dataloader, create_dataset
        args = self.args
        opt = {
            'dist': False,
            'phase': 'train',
            'name': 'DIV2K',
            'type': 'PairedImageDataset',
            'dataroot_gt': args.HR_dir,
            'dataroot_lq': args.LR_dir,
            'filename_tmpl': '{}',
            'io_backend': dict(type='disk'),
            'gt_size': args.train_img_height,
            'use_flip': True,
            'use_rot': True,
            'use_shuffle': True,
            'num_worker_per_gpu': args.workers,
            'batch_size_per_gpu': args.batch_size,
            'scale': int(args.ratio),
            'dataset_enlarge_ratio': 1,
        }
        dataset = create_dataset(opt)
        self.train_loader = create_dataloader(
            dataset, opt, num_gpu=1, dist=opt['dist'], sampler=None)

    def init_testing(self):
        """Run setup in order: constants, data loader, model."""
        self.init_constant()
        self.init_dataset()
        self.define_model()

    def init_constant(self):
        """Hook for subclass-specific constants; no-op by default."""
        return

    def define_model(self):
        """Build the model(s) used by ``forward``; must be overridden."""
        raise NotImplementedError

    def display(self, iteration=None):
        """Visualise/save results for the current batch; must be overridden.

        Args:
            iteration: index of the current batch, as passed by ``test``.
        """
        raise NotImplementedError

    def forward(self, iteration=None):
        """Process ``self.image`` for the current batch; must be overridden.

        ``test`` calls this with no arguments, so ``iteration`` is optional.
        """
        raise NotImplementedError

    def test(self):
        """Run ``forward`` then ``display`` on batches of ``self.train_loader``.

        Every tensor is moved to ``args.device`` (the same device as
        ``self.mean_values``). The primary input is stored in ``self.image``;
        extras go to ``self.label``/``self.gt`` (BSD500), ``self.image2``
        (texture) or ``self.gt`` (paired SR, read from the ``'lq'``/``'gt'`` keys).

        NOTE(review): the stop check runs after a batch is processed and uses
        ``>``, so up to ``args.niteration + 2`` batches are processed
        (iterations 0 .. niteration + 1). Kept as-is to preserve existing
        outputs; confirm whether that is intended.
        """
        args = self.args
        device = args.device
        for iteration, data in enumerate(self.train_loader):
            print("Iteration: {}.".format(iteration))
            if args.dataset == 'BSD500':
                image = data[0].to(device)
                self.label = data[1].to(device)
                self.gt = None
            elif args.dataset == 'texture':
                image = data[0].to(device)
                self.image2 = data[1].to(device)
            else:
                image = data['lq'].to(device)
                self.gt = data['gt'].to(device)
            self.image = image
            self.forward()
            self.display(iteration)
            if iteration > args.niteration:
                break