AlexZou committed on
Commit 1951449
1 Parent(s): af8dd52

Upload 3 files

Files changed (3)
  1. Dehazing.py +45 -0
  2. Lowlight.py +44 -0
  3. SuperResolution.py +47 -0
Dehazing.py ADDED
@@ -0,0 +1,45 @@
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import cv2
import torchvision.utils as tvu
import torch.nn.functional as F
import argparse


def inference_img(haze_path, Net):
    # Load the hazy image, resize it to the 400x400 network input size,
    # and run a timed forward pass with gradients disabled.
    haze_image = Image.open(haze_path).convert('RGB')
    enhance_transforms = transforms.Compose([
        transforms.Resize((400, 400)),
        transforms.ToTensor()
    ])

    print(haze_image.size)
    with torch.no_grad():
        haze_image = enhance_transforms(haze_image)
        haze_image = haze_image.unsqueeze(0)
        start = time.time()
        restored2 = Net(haze_image)
        end = time.time()

    return restored2, end - start


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to the test image')
    parser.add_argument('--save_path', type=str, required=True, help='Directory to save the result')
    parser.add_argument('--pk_path', type=str, default='model_zoo/Haze4k.tjm', help='Path of the checkpoint')
    opt = parser.parse_args()
    if not os.path.isdir(opt.save_path):
        os.mkdir(opt.save_path)
    # Load the TorchScript dehazing model on CPU.
    Net = torch.jit.load(opt.pk_path, map_location=torch.device('cpu')).eval()
    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net)
    torchvision.utils.save_image(restored2, opt.save_path + os.path.split(image)[-1])
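
A minimal example invocation (the input image path is illustrative; the checkpoint path is the script's default):

    python Dehazing.py --test_path demo/haze.png --save_path Results/ --pk_path model_zoo/Haze4k.tjm

Note that --save_path is concatenated directly with the input file name, so it should end with a trailing slash.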
Lowlight.py ADDED
@@ -0,0 +1,44 @@
import torch
import onnxruntime
import onnx
import cv2
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
import os

parser = argparse.ArgumentParser()
parser.add_argument('--test_path', type=str, default='/home/arye-stark/zwb/Illumination-Adaptive-Transformer/IAT_enhance/demo_imgs/low_demo.jpg')
parser.add_argument('--pk_path', type=str, default='model_zoo/Low.onnx')
parser.add_argument('--save_path', type=str, default='Results/')
config = parser.parse_args()

if not os.path.isdir(config.save_path):
    os.mkdir(config.save_path)

# Read the low-light image and convert it to a 1x3xHxW float array in [0, 1].
img = plt.imread(config.test_path)
input_image = np.asarray(img) / 255.0
input_image = torch.from_numpy(input_image).float()
input_image = input_image.permute(2, 0, 1).unsqueeze(0)
input_image = input_image.numpy()

providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
model_name = 'IAT'

print('-' * 50)
try:
    # Run the enhancement network with ONNX Runtime; the model is expected to
    # take an input named 'input' and return a single output named 'output'.
    onnx_session = onnxruntime.InferenceSession(config.pk_path, providers=providers)
    onnx_input = {'input': input_image}
    onnx_output = onnx_session.run(['output'], onnx_input)
    # Drop the batch dimension, rescale to [0, 255], and save as an HxWx3 uint8 image.
    output_image = np.squeeze(onnx_output[0], 0)
    output_image = np.transpose(output_image * 255, [1, 2, 0]).astype(np.uint8)
    plt.imsave(config.save_path + os.path.split(config.test_path)[-1], output_image)
except Exception as e:
    print(f'Inference with model {model_name} failed')
    print(e)
else:
    print(f'Inference with model {model_name} succeeded')
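
An example invocation using the script's default ONNX checkpoint (the input image path is illustrative):

    python Lowlight.py --test_path demo_imgs/low_demo.jpg --pk_path model_zoo/Low.onnx --save_path Results/

If CUDAExecutionProvider is not available, onnxruntime falls back to CPUExecutionProvider from the providers list.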
SuperResolution.py ADDED
@@ -0,0 +1,47 @@
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import argparse
from models.SCET import SCET


def inference_img(img_path, Net, device):
    # Load the low-resolution image, convert it to a batched tensor,
    # and run a timed forward pass with gradients disabled.
    low_image = Image.open(img_path).convert('RGB')
    enhance_transforms = transforms.Compose([
        transforms.ToTensor()
    ])

    with torch.no_grad():
        low_image = enhance_transforms(low_image)
        low_image = low_image.unsqueeze(0)
        start = time.time()
        restored2 = Net(low_image.to(device))
        end = time.time()

    return restored2, end - start


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to the test image')
    parser.add_argument('--save_path', type=str, required=True, help='Directory to save the result')
    parser.add_argument('--pk_path', type=str, default='model_zoo/SRx4.pth', help='Path of the checkpoint')
    parser.add_argument('--scale', type=int, default=4, help='Scale factor')
    opt = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if not os.path.isdir(opt.save_path):
        os.mkdir(opt.save_path)
    # The x3 model uses 63 feature channels; the other scales use 64.
    if opt.scale == 3:
        Net = SCET(63, 128, opt.scale).eval()
    else:
        Net = SCET(64, 128, opt.scale).eval()
    Net.load_state_dict(torch.load(opt.pk_path, map_location=device))
    Net = Net.to(device)
    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net, device)
    torchvision.utils.save_image(restored2, opt.save_path + os.path.split(image)[-1])
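
An example invocation for 4x super-resolution (the input image path is illustrative; the checkpoint path is the script's default):

    python SuperResolution.py --test_path demo/lr.png --save_path Results/ --pk_path model_zoo/SRx4.pth --scale 4

--scale must match the checkpoint: the script builds SCET with 63 feature channels for scale 3 and 64 otherwise.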