"""Run an ONNX model on a single image and save the model output as a PNG.

Command-line flags:
    --test_path  image file fed to the model
    --pk_path    ONNX model file loaded by ONNX Runtime
    --save_path  directory that receives ``output.png``
"""
import argparse
import os
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np
import onnx
import onnxruntime
import torch


def _load_input_tensor(image_path: str) -> np.ndarray:
    """Read an image and return it as a contiguous float32 (1, C, H, W) array.

    Integer-typed pixel data is divided by 255.0; floating-point data is
    passed through unscaled.

    Raises:
        ValueError: if the decoded image is not a 3-D (H, W, C) array.
    """
    pixels = np.asarray(plt.imread(image_path))
    if pixels.ndim != 3:
        raise ValueError(
            f'Expected an (H, W, C) image, got array of shape {pixels.shape}'
        )
    # matplotlib hands back PNG data as floats already scaled to [0, 1];
    # other formats arrive as integers. Only integers need the division.
    if pixels.dtype.kind != 'f':
        pixels = pixels / 255.0
    # HWC -> CHW, then prepend the batch axis.
    batched = np.transpose(pixels, (2, 0, 1))[np.newaxis, ...]
    return np.ascontiguousarray(batched, dtype=np.float32)


def _to_uint8_image(model_output: np.ndarray) -> np.ndarray:
    """Convert a (1, C, H, W) float array into an (H, W, C) uint8 image.

    Values are multiplied by 255 and clipped to [0, 255] before the cast so
    that out-of-range predictions saturate instead of wrapping around.
    """
    chw = np.squeeze(model_output, 0)
    scaled = np.clip(chw * 255, 0, 255)
    return np.transpose(scaled, [1, 2, 0]).astype(np.uint8)


parser = argparse.ArgumentParser()
parser.add_argument('--test_path', type=str, default='/home/arye-stark/zwb/Illumination-Adaptive-Transformer/IAT_enhance/demo_imgs/low_demo.jpg')
parser.add_argument('--pk_path', type=str, default='model_zoo/Low.onnx')
parser.add_argument('--save_path', type=str, default='Results/')
config = parser.parse_args()

# Creates missing parent directories too and tolerates an existing directory.
os.makedirs(config.save_path, exist_ok=True)

input_image = _load_input_tensor(config.test_path)

# Providers are listed in priority order for ONNX Runtime.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
model_name = 'IAT'
print('-' * 50)
try:
    onnx_session = onnxruntime.InferenceSession(config.pk_path, providers=providers)
    # The model is addressed through the tensor names 'input' and 'output'.
    onnx_input = {'input': input_image}
    onnx_output = onnx_session.run(['output'], onnx_input)
    torch_output = _to_uint8_image(onnx_output[0])
    # os.path.join keeps the file inside save_path even without a trailing slash.
    plt.imsave(os.path.join(config.save_path, 'output.png'), torch_output)
except Exception as e:
    # Best-effort demo: report the failure and finish without re-raising.
    print(f'Input on model:{model_name} failed')
    print(e)
else:
    print(f'Input on model:{model_name} succeed')