Spaces:
Running
Running
Nguyễn Bá Thiêm
commited on
Commit
·
d09509f
1
Parent(s):
70dafec
Refactor function to improve performance and readability
Browse files- models/SRFlow/srflow.py +6 -3
- models/SRGAN/m.py +9 -0
models/SRFlow/srflow.py
CHANGED
@@ -55,7 +55,8 @@ def test_srflow(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', hea
|
|
55 |
|
56 |
scale = opt['scale']
|
57 |
pad_factor = 2
|
58 |
-
|
|
|
59 |
h, w, c = lr.shape
|
60 |
lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
|
61 |
right=int(np.ceil(w / pad_factor) * pad_factor - w))
|
@@ -67,12 +68,14 @@ def test_srflow(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', hea
|
|
67 |
|
68 |
sr = rgb(torch.clamp(sr_t, 0, 1))
|
69 |
sr = sr[:h * scale, :w * scale]
|
70 |
-
|
|
|
71 |
return sr
|
72 |
|
73 |
if __name__ == '__main__':
|
74 |
ip = Image.open('images/demo.png')
|
75 |
lr = PILToTensor()(ip).permute(1, 2, 0).numpy()
|
76 |
-
|
|
|
77 |
# sr = return_SRFlow_result(ip)
|
78 |
print(sr.size)
|
|
|
55 |
|
56 |
scale = opt['scale']
|
57 |
pad_factor = 2
|
58 |
+
|
59 |
+
lr = lr.squeeze(0).permute(0, 1, 2).numpy()
|
60 |
h, w, c = lr.shape
|
61 |
lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
|
62 |
right=int(np.ceil(w / pad_factor) * pad_factor - w))
|
|
|
68 |
|
69 |
sr = rgb(torch.clamp(sr_t, 0, 1))
|
70 |
sr = sr[:h * scale, :w * scale]
|
71 |
+
sr = sr.unsqueeze(0).permute(0, 3, 1, 2)
|
72 |
+
|
73 |
return sr
|
74 |
|
75 |
if __name__ == '__main__':
|
76 |
ip = Image.open('images/demo.png')
|
77 |
lr = PILToTensor()(ip).permute(1, 2, 0).numpy()
|
78 |
+
print(lr.shape)
|
79 |
+
sr = test_srflow(lr)
|
80 |
# sr = return_SRFlow_result(ip)
|
81 |
print(sr.size)
|
models/SRGAN/m.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from torchvision.transforms import PILToTensor
|
3 |
+
ip = Image.open('images/demo.png')
|
4 |
+
lr = PILToTensor()(ip).unsqueeze(0).permute(0, 2, 3, 1)
|
5 |
+
print(lr.shape)
|
6 |
+
lr = lr.squeeze(0).permute(0, 1, 2).numpy()
|
7 |
+
# lr = PILToTensor()(ip).permute(1, 2, 0)
|
8 |
+
# lr = lr.unsqueeze(0).permute(0, 3, 1, 2).numpy()
|
9 |
+
print(lr.shape)
|