AlexZou commited on
Commit
7970501
1 Parent(s): 234f5e7

Upload 4 files

Browse files
Files changed (2) hide show
  1. SuperResolution.py +7 -8
  2. Underwater.py +6 -7
SuperResolution.py CHANGED
@@ -8,7 +8,7 @@ import torchvision
8
  import argparse
9
  from models.SCET import SCET
10
 
11
- def inference_img(img_path,Net,device):
12
 
13
  low_image = Image.open(img_path).convert('RGB')
14
  enhance_transforms = transforms.Compose([
@@ -19,7 +19,7 @@ def inference_img(img_path,Net,device):
19
  low_image = enhance_transforms(low_image)
20
  low_image = low_image.unsqueeze(0)
21
  start = time.time()
22
- restored2 = Net(low_image.to(device))
23
  end = time.time()
24
 
25
 
@@ -31,17 +31,16 @@ if __name__ == '__main__':
31
  parser.add_argument('--save_path',type=str,required=True,help='Path to save')
32
  parser.add_argument('--pk_path',type=str,default='model_zoo/SRx4.pth',help='Path of the checkpoint')
33
  parser.add_argument('--scale',type=int,default=4,help='scale factor')
34
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
  opt = parser.parse_args()
36
  if not os.path.isdir(opt.save_path):
37
  os.mkdir(opt.save_path)
38
  if opt.scale == 3:
39
- Net = SCET(63, 128, opt.scale).eval()
40
  else:
41
- Net = SCET(64, 128, opt.scale).eval()
42
- Net.load_state_dict(torch.load(opt.pk_path))
43
- Net=Net.to(device)
44
  image=opt.test_path
45
  print(image)
46
- restored2,time_num=inference_img(image,Net,device)
47
  torchvision.utils.save_image(restored2,opt.save_path+'output.png')
 
8
  import argparse
9
  from models.SCET import SCET
10
 
11
+ def inference_img(img_path,Net):
12
 
13
  low_image = Image.open(img_path).convert('RGB')
14
  enhance_transforms = transforms.Compose([
 
19
  low_image = enhance_transforms(low_image)
20
  low_image = low_image.unsqueeze(0)
21
  start = time.time()
22
+ restored2 = Net(low_image)
23
  end = time.time()
24
 
25
 
 
31
  parser.add_argument('--save_path',type=str,required=True,help='Path to save')
32
  parser.add_argument('--pk_path',type=str,default='model_zoo/SRx4.pth',help='Path of the checkpoint')
33
  parser.add_argument('--scale',type=int,default=4,help='scale factor')
 
34
  opt = parser.parse_args()
35
  if not os.path.isdir(opt.save_path):
36
  os.mkdir(opt.save_path)
37
  if opt.scale == 3:
38
+ Net = SCET(63, 128, opt.scale)
39
  else:
40
+ Net = SCET(64, 128, opt.scale)
41
+ Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
42
+ Net=Net.eval()
43
  image=opt.test_path
44
  print(image)
45
+ restored2,time_num=inference_img(image,Net)
46
  torchvision.utils.save_image(restored2,opt.save_path+'output.png')
Underwater.py CHANGED
@@ -11,7 +11,7 @@ import torch.functional as F
11
  import argparse
12
  from net.Ushape_Trans import *
13
 
14
- def inference_img(img_path,Net,device):
15
 
16
  low_image = Image.open(img_path).convert('RGB')
17
  enhance_transforms = transforms.Compose([
@@ -23,7 +23,7 @@ def inference_img(img_path,Net,device):
23
  low_image = enhance_transforms(low_image)
24
  low_image = low_image.unsqueeze(0)
25
  start = time.time()
26
- restored2 = Net(low_image.to(device))
27
  end = time.time()
28
 
29
 
@@ -37,11 +37,10 @@ if __name__ == '__main__':
37
  opt = parser.parse_args()
38
  if not os.path.isdir(opt.save_path):
39
  os.mkdir(opt.save_path)
40
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
- Net = Generator().eval()
42
- Net.load_state_dict(torch.load(opt.pk_path))
43
- Net = Net.to(device)
44
  image = opt.test_path
45
  print(image)
46
- restored2,time_num = inference_img(image,Net,device)
47
  torchvision.utils.save_image(restored2,opt.save_path+'output.png')
 
11
  import argparse
12
  from net.Ushape_Trans import *
13
 
14
+ def inference_img(img_path,Net):
15
 
16
  low_image = Image.open(img_path).convert('RGB')
17
  enhance_transforms = transforms.Compose([
 
23
  low_image = enhance_transforms(low_image)
24
  low_image = low_image.unsqueeze(0)
25
  start = time.time()
26
+ restored2 = Net(low_image)
27
  end = time.time()
28
 
29
 
 
37
  opt = parser.parse_args()
38
  if not os.path.isdir(opt.save_path):
39
  os.mkdir(opt.save_path)
40
+ Net = Generator()
41
+ Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
42
+ Net = Net.eval()
 
43
  image = opt.test_path
44
  print(image)
45
+ restored2,time_num = inference_img(image,Net)
46
  torchvision.utils.save_image(restored2,opt.save_path+'output.png')