import torch
import os
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
from tqdm import tqdm
import data
import metric
import onnxruntime
import cv2
from data.data_tiling import tiling_inference
import argparse


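# Command-line options for the super-resolution evaluation run.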
class Configs():
    def __init__(self):
        parser = argparse.ArgumentParser(description='SR')

        parser.add_argument('--ipu', action='store_true',
                            help='run inference on the IPU via the Vitis AI execution provider')
        parser.add_argument('--onnx_path', type=str, default='RCAN_int8_NHWC.onnx',
                            help='path to the ONNX model')
        parser.add_argument('--provider_config', type=str, default=None,
                            help='path to the Vitis AI provider config file')

        parser.add_argument('--dir_data', type=str, default='dataset/',
                            help='dataset directory')
        parser.add_argument('--data_test', type=str, default='Set5',
                            help='test dataset name')

        parser.add_argument('--n_threads', type=int, default=6,
                            help='number of threads for data loading')
        parser.add_argument('--scale', type=str, default='2',
                            help='super resolution scale (currently only x2 is supported)')
        self.parser = parser

    def parse(self):
        args = self.parser.parse_args()
        args.scale = list(map(lambda x: int(x), args.scale.split('+')))
        args.data_test = args.data_test.split('+')
        print(args)
        return args


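# Clamp and round super-resolved outputs to valid 8-bit pixel values before computing metrics.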
def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


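# Run the ONNX session over the test loader and report mean PSNR/SSIM for each dataset.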
def test_model(session, loader, device):
    torch.set_grad_enabled(False)
    self_scale = [2]
    for idx_data, d in enumerate(loader.loader_test):
        eval_ssim = 0
        eval_psnr = 0
        for idx_scale, scale in enumerate(self_scale):
            d.dataset.set_scale(idx_scale)
            for lr, hr, filename in tqdm(d, ncols=80):
                # Run the ONNX session tile by tile (56x56 tiles, 8-pixel overlap)
                # to keep memory usage bounded on large inputs.
                sr = tiling_inference(session, lr.cpu().numpy(), 8, (56, 56))
                sr = torch.from_numpy(sr).to(device)
                sr = quantize(sr, 255)
                eval_psnr += metric.calc_psnr(
                    sr, hr, scale, 255, benchmark=d)
                eval_ssim += metric.calc_ssim(
                    sr, hr, scale, 255, dataset=d)
            mean_ssim = eval_ssim / len(d)
            mean_psnr = eval_psnr / len(d)
            print("psnr: %s, ssim: %s" % (mean_psnr, mean_ssim))
    return mean_psnr, mean_ssim


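# Build the data loader, create the ONNX Runtime inference session, and evaluate the model.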
def main(args):
    loader = data.Data(args)
    if args.ipu:
        # Run on the IPU through the Vitis AI execution provider.
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        # Fall back to GPU/CPU execution providers with default options.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    onnx_file_name = args.onnx_path
    ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers,
                                               provider_options=provider_options)
    test_model(ort_session, loader, device="cpu")


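# Example invocations (illustrative; the script name and config file path are assumptions):
#   python test_onnx.py --onnx_path RCAN_int8_NHWC.onnx --dir_data dataset/ --data_test Set5
#   python test_onnx.py --ipu --provider_config vaip_config.json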
if __name__ == '__main__':
    args = Configs().parse()
    main(args)