rcan / eval_onnx.py
zhengrongzhang's picture
change onnx to NHWC (#1)
30c4d88
raw
history blame
No virus
3.25 kB
import torch
import os
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
from tqdm import tqdm
import data
import metric
import onnxruntime
import cv2
from data.data_tiling import tiling_inference
import argparse
class Configs:
    """Command-line options for evaluating the RCAN ONNX model."""

    def __init__(self):
        parser = argparse.ArgumentParser(description='SR')
        # Execution target: IPU (VitisAI provider) or the default CUDA/CPU
        # providers. Either way an ONNX model path is needed.
        parser.add_argument('--ipu', action='store_true',
                            help='use ipu')
        parser.add_argument('--onnx_path', type=str, default='RCAN_int8_NHWC.onnx',
                            help='onnx path')
        parser.add_argument('--provider_config', type=str, default=None,
                            help='provider config path')
        # Data specifications; the defaults are usually fine.
        parser.add_argument('--dir_data', type=str, default='dataset/',
                            help='dataset directory')
        parser.add_argument('--data_test', type=str, default='Set5',
                            help='test dataset name')
        parser.add_argument('--n_threads', type=int, default=6,
                            help='number of threads for data loading')
        parser.add_argument('--scale', type=str, default='2',
                            help='super resolution scale, now only support x2')
        self.parser = parser

    def parse(self, argv=None):
        """Parse the options and split the '+'-separated list options.

        Args:
            argv: argument list to parse. ``None`` (the default) reads
                ``sys.argv`` exactly as before. Passing a list allows
                programmatic use and testing.

        Returns:
            argparse.Namespace in which ``scale`` is a list of ints and
            ``data_test`` is a list of dataset names. The namespace is also
            printed.

        Raises:
            SystemExit: via ``parser.error`` (exit status 2) when ``--scale``
                is not a '+'-joined list of integers.
        """
        args = self.parser.parse_args(argv)
        try:
            args.scale = [int(s) for s in args.scale.split('+')]
        except ValueError:
            # Report bad input the argparse way (usage + exit code 2)
            # instead of an unhandled traceback.
            self.parser.error(
                "argument --scale: expected integers joined by '+', got %r"
                % args.scale)
        args.data_test = args.data_test.split('+')
        print(args)
        return args
def quantize(img, rgb_range):
    """Snap a tensor onto the 256 integer levels of an 8-bit image.

    The input is mapped from ``[0, rgb_range]`` to ``[0, 255]``, clamped to
    that interval, rounded to whole levels, and mapped back. The result is
    therefore expressed in the caller's ``rgb_range`` units again.
    """
    levels_per_unit = 255 / rgb_range
    snapped = (img * levels_per_unit).clamp(0, 255).round()
    return snapped / levels_per_unit
def test_model(session, loader, device, scales=(2,)):
    """Evaluate an ONNX super-resolution session on every test dataset.

    For each DataLoader ``d`` in ``loader.loader_test`` and each entry of
    ``scales``, every low-resolution input is processed as follows:

    - upscaled with ``tiling_inference``;
    - turned into a tensor on ``device``;
    - quantized to 8-bit levels;
    - scored against its high-resolution target with ``metric.calc_psnr``
      and ``metric.calc_ssim``.

    The mean PSNR/SSIM (sums divided by ``len(d)``) are printed for every
    dataset/scale pair.

    Args:
        session: ONNX Runtime session, passed through to ``tiling_inference``.
        loader: object exposing ``loader_test``, an iterable of loaders whose
            ``dataset`` provides ``set_scale(index)``.
        device: torch device the upscaled output is moved to before scoring.
        scales: upscaling factors to evaluate. Position ``i`` is selected on
            the dataset with ``set_scale(i)``. Defaults to ``(2,)``, the value
            previously hard-coded here.

    Returns:
        ``(mean_psnr, mean_ssim)`` for the last dataset/scale evaluated.

    Raises:
        ValueError: if a dataset is empty, or if there are no datasets or no
            scales to evaluate (previously a ZeroDivisionError or
            UnboundLocalError).
    """
    mean_psnr = mean_ssim = None
    # Pure inference: disable autograd only for the duration of the evaluation
    # instead of flipping the process-wide flag.
    with torch.no_grad():
        for d in loader.loader_test:
            for idx_scale, scale in enumerate(scales):
                d.dataset.set_scale(idx_scale)
                n_batches = len(d)
                if n_batches == 0:
                    raise ValueError(
                        "test dataset is empty; cannot average PSNR/SSIM")
                # Fresh accumulators per scale, so each mean covers exactly one
                # pass over d.
                eval_psnr = 0
                eval_ssim = 0
                for lr, hr, _filename in tqdm(d, ncols=80):
                    # 8 and (56, 56) are tiling parameters forwarded unchanged
                    # to tiling_inference; presumably overlap and tile size.
                    # TODO: confirm in data.data_tiling.
                    sr = tiling_inference(session, lr.cpu().numpy(), 8, (56, 56))
                    sr = torch.from_numpy(sr).to(device)
                    sr = quantize(sr, 255)
                    eval_psnr += metric.calc_psnr(
                        sr, hr, scale, 255, benchmark=d)
                    eval_ssim += metric.calc_ssim(
                        sr, hr, scale, 255, dataset=d)
                mean_psnr = eval_psnr / n_batches
                mean_ssim = eval_ssim / n_batches
                print("psnr: %s, ssim: %s"%(mean_psnr, mean_ssim))
    if mean_psnr is None:
        raise ValueError("no test datasets or scales to evaluate")
    return mean_psnr, mean_ssim
def main(args):
    """Build an ONNX Runtime session from the parsed options and evaluate it.

    Args:
        args: namespace produced by ``Configs.parse``. This function reads
            ``ipu``, ``provider_config`` and ``onnx_path``, and passes the
            whole namespace to ``data.Data``.

    Returns:
        The ``(mean_psnr, mean_ssim)`` tuple returned by ``test_model``
        (previously discarded).
    """
    loader = data.Data(args)
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        # Only forward config_file when one was supplied. NOTE(review):
        # onnxruntime normalizes provider option values with str(), so a None
        # here would reach the provider as the literal path "None".
        ipu_options = {}
        if args.provider_config is not None:
            ipu_options["config_file"] = args.provider_config
        provider_options = [ipu_options]
    else:
        # Providers are listed in priority order: CUDA first, CPU as fallback.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    ort_session = onnxruntime.InferenceSession(
        args.onnx_path, providers=providers, provider_options=provider_options)
    return test_model(ort_session, loader, device="cpu")
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the evaluation.
    main(Configs().parse())