Nguyễn Bá Thiêm commited on
Commit
d09509f
·
1 Parent(s): 70dafec

Refactor function to improve performance and readability

Browse files
Files changed (2) hide show
  1. models/SRFlow/srflow.py +6 -3
  2. models/SRGAN/m.py +9 -0
models/SRFlow/srflow.py CHANGED
@@ -55,7 +55,8 @@ def test_srflow(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', hea
55
 
56
  scale = opt['scale']
57
  pad_factor = 2
58
-
 
59
  h, w, c = lr.shape
60
  lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
61
  right=int(np.ceil(w / pad_factor) * pad_factor - w))
@@ -67,12 +68,14 @@ def test_srflow(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', hea
67
 
68
  sr = rgb(torch.clamp(sr_t, 0, 1))
69
  sr = sr[:h * scale, :w * scale]
70
-
 
71
  return sr
72
 
73
  if __name__ == '__main__':
74
  ip = Image.open('images/demo.png')
75
  lr = PILToTensor()(ip).permute(1, 2, 0).numpy()
76
- sr = test(lr)
 
77
  # sr = return_SRFlow_result(ip)
78
  print(sr.size)
 
55
 
56
  scale = opt['scale']
57
  pad_factor = 2
58
+
59
+ lr = lr.squeeze(0).permute(0, 1, 2).numpy()
60
  h, w, c = lr.shape
61
  lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
62
  right=int(np.ceil(w / pad_factor) * pad_factor - w))
 
68
 
69
  sr = rgb(torch.clamp(sr_t, 0, 1))
70
  sr = sr[:h * scale, :w * scale]
71
+ sr = sr.unsqueeze(0).permute(0, 3, 1, 2)
72
+
73
  return sr
74
 
75
  if __name__ == '__main__':
76
  ip = Image.open('images/demo.png')
77
  lr = PILToTensor()(ip).permute(1, 2, 0).numpy()
78
+ print(lr.shape)
79
+ sr = test_srflow(lr)
80
  # sr = return_SRFlow_result(ip)
81
  print(sr.size)
models/SRGAN/m.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from torchvision.transforms import PILToTensor
3
+ ip = Image.open('images/demo.png')
4
+ lr = PILToTensor()(ip).unsqueeze(0).permute(0, 2, 3, 1)
5
+ print(lr.shape)
6
+ lr = lr.squeeze(0).permute(0, 1, 2).numpy()
7
+ # lr = PILToTensor()(ip).permute(1, 2, 0)
8
+ # lr = lr.unsqueeze(0).permute(0, 3, 1, 2).numpy()
9
+ print(lr.shape)