Deploy_Restoration / Lowlight.py
AlexZou's picture
Upload 4 files
c4d8d8b
raw
history blame
No virus
1.52 kB
import torch
import onnxruntime
import onnx
import cv2
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
import os
# Command-line options (all plain strings):
#   --test_path  input image path, read with plt.imread. NOTE(review): the
#                default is an absolute, machine-specific path under /home,
#                so pass --test_path explicitly on any other machine.
#   --pk_path    path of the .onnx model file to load.
#   --save_path  directory the result image is written into.
parser = argparse.ArgumentParser()
for flag, default in (
    ('--test_path', '/home/arye-stark/zwb/Illumination-Adaptive-Transformer/IAT_enhance/demo_imgs/low_demo.jpg'),
    ('--pk_path', 'model_zoo/Low.onnx'),
    ('--save_path', 'Results/'),
):
    parser.add_argument(flag, type=str, default=default)
config = parser.parse_args()
# Ensure the output directory exists, creating missing parent directories too.
# exist_ok=True makes this a no-op when the directory is already there and
# avoids the check-then-create race of an isdir()/mkdir() pair. A path that
# exists as a regular file still raises FileExistsError.
os.makedirs(config.save_path, exist_ok=True)
# Load the test image and convert it into a float32 NCHW batch of one, in [0, 1].
img = plt.imread(config.test_path)
input_image = np.asarray(img)
if np.issubdtype(input_image.dtype, np.integer):
    # Integer pixels (e.g. an 8-bit JPEG decoded via Pillow): scale by the
    # dtype's maximum, which is 255 for uint8. Float pixels (matplotlib's
    # PNG reader) are already in [0, 1]; dividing those by 255 again would
    # feed the model an almost black image.
    input_image = input_image / np.iinfo(input_image.dtype).max
if input_image.ndim == 2:
    # Grayscale image: add a channel axis so the HWC handling below applies.
    input_image = input_image[..., np.newaxis]
if input_image.shape[-1] < 3:
    # Grayscale, optionally with alpha: replicate the first channel into RGB.
    input_image = np.repeat(input_image[..., :1], 3, axis=-1)
# Keep three channels only (drops an alpha channel if present).
# NOTE(review): assumes the ONNX model takes 3-channel RGB input; confirm
# against the exported graph.
input_image = input_image[..., :3]
# HWC -> CHW, then prepend the batch axis. The result is a contiguous
# float32 buffer, ready to pass to onnxruntime.
input_image = np.ascontiguousarray(
    input_image.transpose(2, 0, 1)[np.newaxis], dtype=np.float32
)
# Run the ONNX model on the preprocessed image and save the enhanced result.
# Providers are listed in priority order: CUDA when available, otherwise CPU.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
model_name = 'IAT'
print('-' * 50)
try:
    onnx_session = onnxruntime.InferenceSession(config.pk_path, providers=providers)
    # Tensor names 'input'/'output' must match the exported graph.
    onnx_output = onnx_session.run(['output'], {'input': input_image})
    # Drop the batch axis (the input batch has N=1), then CHW -> HWC for saving.
    output_image = np.transpose(np.squeeze(onnx_output[0], 0), (1, 2, 0))
    # Clip before scaling. Values outside [0, 1] would otherwise overflow or
    # wrap when cast to uint8, which shows up as speckled pixels. Round rather
    # than truncate.
    output_image = np.round(np.clip(output_image, 0.0, 1.0) * 255.0).astype(np.uint8)
    # os.path.join adds the separator, so --save_path works with or without a
    # trailing '/' (plain concatenation would write e.g. 'Resultsoutput.png').
    plt.imsave(os.path.join(config.save_path, 'output.png'), output_image)
except Exception as e:
    # Top-level boundary: report the failure, then exit non-zero so shell
    # callers and CI can detect it instead of seeing status 0.
    print(f'Input on model:{model_name} failed')
    print(e)
    raise SystemExit(1)
else:
    print(f'Input on model:{model_name} succeed')