AlexZou commited on
Commit
31970e5
1 Parent(s): d03bb00

Update Lowlight.py

Browse files
Files changed (1) hide show
  1. Lowlight.py +37 -36
Lowlight.py CHANGED
@@ -1,46 +1,47 @@
 
1
  import torch
 
 
 
 
2
  import torchvision
3
- import onnxruntime
4
- import onnx
5
  import cv2
 
 
6
  import argparse
7
- import warnings
8
- import numpy as np
9
- import matplotlib.pyplot as plt
10
- import os
11
-
12
- parser = argparse.ArgumentParser()
13
- parser.add_argument('--test_path', type=str, default='/home/arye-stark/zwb/Illumination-Adaptive-Transformer/IAT_enhance/demo_imgs/low_demo.jpg')
14
- parser.add_argument('--pk_path', type=str, default='model_zoo/Low.onnx')
15
- parser.add_argument('--save_path', type=str, default='Results/')
16
- config = parser.parse_args()
17
 
18
- if not os.path.isdir(config.save_path):
19
- os.mkdir(config.save_path)
 
 
 
 
 
20
 
21
- img = plt.imread(config.test_path)
22
- input_image = np.asarray(img) / 255.0
23
- input_image = torch.from_numpy(input_image).float()
24
- input_image = input_image.permute(2, 0, 1).unsqueeze(0)
25
- input_image = input_image.numpy()
 
26
 
27
- providers = ['CPUExecutionProvider']
28
- model_name = 'IAT'
29
 
30
- print('-' * 50)
31
- try:
32
- onnx_session = onnxruntime.InferenceSession(config.pk_path, providers=providers)
33
- onnx_input = {'input': input_image}
34
- #onnx_output0, onnx_output1, onnx_output2 = onnx_session.run(['output0', 'output1', 'output2'], onnx_input)
35
- onnx_output = onnx_session.run(['output'], onnx_input)
36
- torchvision.utils.save_image(torch.from_numpy(onnx_output[0]), config.save_path+'output.png')
37
- #torch_output = np.squeeze(onnx_output[0], 0)
38
- #torch_output = np.transpose(torch_output * 255, [1, 2, 0]).astype(np.uint8)
39
- #plt.imsave(config.save_path+'output.png', torch_output)
40
- except Exception as e:
41
- print(f'Input on model:{model_name} failed')
42
- print(e)
43
- else:
44
- print(f'Input on model:{model_name} succeed')
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
 
1
+ import os
2
  import torch
3
+ import numpy as np
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import time
7
  import torchvision
 
 
8
  import cv2
9
+ import torchvision.utils as tvu
10
+ import torch.functional as F
11
  import argparse
12
+ from model.IAT_main import IAT
 
 
 
 
 
 
 
 
 
13
 
14
+ def inference_img(img_path,Net):
15
+
16
+ low_image = Image.open(img_path).convert('RGB')
17
+ enhance_transforms = transforms.Compose([
18
+ transforms.Resize((256,256)),
19
+ transforms.ToTensor()
20
+ ])
21
 
22
+ with torch.no_grad():
23
+ low_image = enhance_transforms(low_image)
24
+ low_image = low_image.unsqueeze(0)
25
+ start = time.time()
26
+ restored2 = Net(low_image)
27
+ end = time.time()
28
 
 
 
29
 
30
+ return restored2,end-start
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ if __name__ == '__main__':
33
+ parser=argparse.ArgumentParser()
34
+ parser.add_argument('--test_path',type=str,required=True,help='Path to test')
35
+ parser.add_argument('--save_path',type=str,required=True,help='Path to save')
36
+ parser.add_argument('--pk_path',type=str,default='model_zoo/underwater.pth',help='Path of the checkpoint')
37
+ opt = parser.parse_args()
38
+ if not os.path.isdir(opt.save_path):
39
+ os.mkdir(opt.save_path)
40
+ Net = IAT()
41
+ Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
42
+ Net = Net.eval()
43
+ image = opt.test_path
44
+ print(image)
45
+ restored2,time_num = inference_img(image,Net)
46
+ torchvision.utils.save_image(restored2,opt.save_path+'output.png')
47