AlexZou commited on
Commit
af8dd52
1 Parent(s): 9eb95ee

Upload Underwater.py

Browse files
Files changed (1) hide show
  1. Underwater.py +47 -0
Underwater.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import time
7
+ import torchvision
8
+ import cv2
9
+ import torchvision.utils as tvu
10
+ import torch.functional as F
11
+ import argparse
12
+ from net.Ushape_Trans import *
13
+
14
+ def inference_img(img_path,Net,device):
15
+
16
+ low_image = Image.open(img_path).convert('RGB')
17
+ enhance_transforms = transforms.Compose([
18
+ transforms.Resize((256,256)),
19
+ transforms.ToTensor()
20
+ ])
21
+
22
+ with torch.no_grad():
23
+ low_image = enhance_transforms(low_image)
24
+ low_image = low_image.unsqueeze(0)
25
+ start = time.time()
26
+ restored2 = Net(low_image.to(device))
27
+ end = time.time()
28
+
29
+
30
+ return restored2,end-start
31
+
32
+ if __name__ == '__main__':
33
+ parser=argparse.ArgumentParser()
34
+ parser.add_argument('--test_path',type=str,required=True,help='Path to test')
35
+ parser.add_argument('--save_path',type=str,required=True,help='Path to save')
36
+ parser.add_argument('--pk_path',type=str,default='model_zoo/underwater.pth',help='Path of the checkpoint')
37
+ opt = parser.parse_args()
38
+ if not os.path.isdir(opt.save_path):
39
+ os.mkdir(opt.save_path)
40
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
+ Net = Generator().eval()
42
+ Net.load_state_dict(torch.load(opt.pk_path))
43
+ Net = Net.to(device)
44
+ image = opt.test_path
45
+ print(image)
46
+ restored2,time_num = inference_img(image,Net,device)
47
+ torchvision.utils.save_image(restored2,opt.save_path+os.path.split(image)[-1])