"""Evaluate an ONNX super-resolution model (default: RCAN_int8_NHWC.onnx).

The model runs through ONNX Runtime, either on the VitisAI execution
provider (``--ipu``) or on CUDA/CPU. PSNR and SSIM are reported for each
requested test dataset.
"""
import argparse
import os
import pathlib
import sys

import torch

CURRENT_DIR = pathlib.Path(__file__).parent
# Make the sibling ``data`` and ``metric`` modules importable regardless of
# the working directory. This must run before the local imports below,
# which is why they are not at the top of the file.
sys.path.append(str(CURRENT_DIR))

from tqdm import tqdm
import data
import metric
import onnxruntime
import cv2
from data.data_tiling import tiling_inference

# Only x2 is supported (see the --scale help text). test_model evaluates
# exactly this scale, so Configs.parse rejects anything else up front.
SUPPORTED_SCALE = 2


class Configs:
    """Command-line options for the SR evaluation script."""

    def __init__(self):
        parser = argparse.ArgumentParser(description='SR')
        # IPU or CPU/GPU test; either way the ONNX model path is required.
        parser.add_argument('--ipu', action='store_true', help='use ipu')
        parser.add_argument('--onnx_path', type=str,
                            default='RCAN_int8_NHWC.onnx', help='onnx path')
        parser.add_argument('--provider_config', type=str, default=None,
                            help='provider config path')
        # Data specifications; the defaults are usually fine.
        parser.add_argument('--dir_data', type=str, default='dataset/',
                            help='dataset directory')
        parser.add_argument('--data_test', type=str, default='Set5',
                            help='test dataset name')
        parser.add_argument('--n_threads', type=int, default=6,
                            help='number of threads for data loading')
        parser.add_argument('--scale', type=str, default='2',
                            help='super resolution scale, now only support x2')
        self.parser = parser

    def parse(self):
        """Parse ``sys.argv`` and return the resulting ``argparse.Namespace``.

        ``--scale`` and ``--data_test`` accept ``+``-separated lists. They are
        returned as a list of ints and a list of strings respectively.
        Exits with a usage error if any scale other than x2 is requested.
        Otherwise test_model would run the x2 model against data prepared
        for a different scale.
        """
        args = self.parser.parse_args()
        args.scale = [int(s) for s in args.scale.split('+')]
        args.data_test = args.data_test.split('+')
        if args.scale != [SUPPORTED_SCALE]:
            self.parser.error(
                f'--scale: only x{SUPPORTED_SCALE} is supported, got {args.scale}')
        print(args)
        return args


def quantize(img, rgb_range):
    """Snap a tensor to the 8-bit pixel grid of the value range ``[0, rgb_range]``.

    Values are rescaled to 0..255, clamped, and rounded to integers. They are
    then rescaled back, so the output only holds values an 8-bit image can
    represent.
    """
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


def _evaluate_dataset(session, d, device):
    """Return the mean ``(psnr, ssim)`` of *session* over one test loader *d*.

    Raises:
        ValueError: if *d* yields no batches (the means would divide by zero).
    """
    n_batches = len(d)
    if n_batches == 0:
        raise ValueError('test dataset yielded no batches')
    # Index 0 into the dataset's scale list. This script evaluates only
    # SUPPORTED_SCALE (presumably set_scale picks args.scale[0]; confirm in data).
    d.dataset.set_scale(0)
    total_psnr = 0
    total_ssim = 0
    for lr, hr, _filename in tqdm(d, ncols=80):
        # 8 and (56, 56) are tiling parameters forwarded to tiling_inference.
        # NOTE(review): presumably overlap and tile size; confirm in data_tiling.
        sr = tiling_inference(session, lr.cpu().numpy(), 8, (56, 56))
        sr = quantize(torch.from_numpy(sr).to(device), 255)
        total_psnr += metric.calc_psnr(
            sr, hr, SUPPORTED_SCALE, 255, benchmark=d)
        total_ssim += metric.calc_ssim(
            sr, hr, SUPPORTED_SCALE, 255, dataset=d)
    return total_psnr / n_batches, total_ssim / n_batches


def test_model(session, loader, device):
    """Run *session* over every test dataset in *loader* and report metrics.

    Args:
        session: the ONNX Runtime session passed to ``tiling_inference``.
        loader: object whose ``loader_test`` is an iterable of data loaders
            yielding ``(lr, hr, filename)`` batches.
        device: torch device the SR output is moved to before scoring.

    Returns:
        ``(mean_psnr, mean_ssim)`` for the last dataset evaluated. Every
        dataset's scores are also printed as soon as that dataset finishes.

    Raises:
        ValueError: if ``loader.loader_test`` is empty, or if a dataset
            yields no batches.
    """
    mean_psnr = mean_ssim = None
    # Inference-only: disable autograd for this call without leaking the
    # setting into the rest of the process.
    with torch.no_grad():
        for idx_data, d in enumerate(loader.loader_test):
            mean_psnr, mean_ssim = _evaluate_dataset(session, d, device)
            print(f"[dataset {idx_data}] psnr: {mean_psnr}, ssim: {mean_ssim}")
    if mean_psnr is None:
        raise ValueError('loader.loader_test contains no test datasets')
    return mean_psnr, mean_ssim


def main(args):
    """Build the test data loader and ONNX Runtime session, then evaluate."""
    loader = data.Data(args)
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        # Only pass config_file when a path was given. ONNX Runtime
        # stringifies provider option values, so None would be passed
        # through as the bogus path "None".
        vitis_options = ({"config_file": args.provider_config}
                         if args.provider_config else {})
        provider_options = [vitis_options]
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    ort_session = onnxruntime.InferenceSession(
        args.onnx_path, providers=providers, provider_options=provider_options)
    test_model(ort_session, loader, device="cpu")


if __name__ == '__main__':
    main(Configs().parse())