Spaces:
Running
Running
Nguyễn Bá Thiêm
commited on
Commit
•
12f4dcf
1
Parent(s):
bff701c
Add test methods to RCAN and SRGAN models
Browse files- models/RCAN/rcan.py +9 -0
- models/SRFlow/srflow.py +35 -1
- models/SRGAN/srgan.py +12 -1
models/RCAN/rcan.py
CHANGED
@@ -118,6 +118,15 @@ class RCAN(nn.Module):
|
|
118 |
x = self.forward(x)
|
119 |
x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
|
120 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
if __name__ == '__main__':
|
123 |
current_dir = os.path.dirname(os.path.realpath(__file__))
|
|
|
118 |
x = self.forward(x)
|
119 |
x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
|
120 |
return x
|
121 |
+
|
122 |
+
def test(self, x):
|
123 |
+
"""
|
124 |
+
x is a tensor
|
125 |
+
"""
|
126 |
+
self.eval()
|
127 |
+
with torch.no_grad():
|
128 |
+
x = self.forward(x)
|
129 |
+
return x
|
130 |
|
131 |
if __name__ == '__main__':
|
132 |
current_dir = os.path.dirname(os.path.realpath(__file__))
|
models/SRFlow/srflow.py
CHANGED
@@ -39,7 +39,41 @@ def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.
|
|
39 |
sr = Image.fromarray((sr).astype('uint8'))
|
40 |
return sr
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
if __name__ == '__main__':
|
43 |
ip = Image.open('images/demo.png')
|
44 |
-
|
|
|
|
|
45 |
print(sr.size)
|
|
|
39 |
sr = Image.fromarray((sr).astype('uint8'))
|
40 |
return sr
|
41 |
|
42 |
+
def test(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=0.7):
|
43 |
+
"""
|
44 |
+
Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
- lr: tensor
|
48 |
+
- conf_path (str): Configuration file path for the SRFlow model. Default is SRFlow_DF2K_4X.yml.
|
49 |
+
- heat (float): Heat parameter for the SRFlow model. Default is 0.6.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
- sr: tensor
|
53 |
+
"""
|
54 |
+
model, opt = load_model(conf_path)
|
55 |
+
|
56 |
+
scale = opt['scale']
|
57 |
+
pad_factor = 2
|
58 |
+
|
59 |
+
h, w, c = lr.shape
|
60 |
+
lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
|
61 |
+
right=int(np.ceil(w / pad_factor) * pad_factor - w))
|
62 |
+
|
63 |
+
lr_t = t(lr)
|
64 |
+
heat = opt['heat']
|
65 |
+
|
66 |
+
sr_t = model.get_sr(lq=lr_t, heat=heat)
|
67 |
+
|
68 |
+
sr = rgb(torch.clamp(sr_t, 0, 1))
|
69 |
+
sr = sr[:h * scale, :w * scale]
|
70 |
+
|
71 |
+
sr = Image.fromarray((sr).astype('uint8'))
|
72 |
+
return sr
|
73 |
+
|
74 |
if __name__ == '__main__':
|
75 |
ip = Image.open('images/demo.png')
|
76 |
+
lr = PILToTensor()(ip).permute(1, 2, 0).numpy()
|
77 |
+
sr = test(lr)
|
78 |
+
# sr = return_SRFlow_result(ip)
|
79 |
print(sr.size)
|
models/SRGAN/srgan.py
CHANGED
@@ -64,6 +64,15 @@ class GeneratorResnet(nn.Module):
|
|
64 |
x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
|
65 |
return TF.adjust_brightness(x, 1.1)
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
if __name__ == '__main__':
|
68 |
current_dir = os.path.dirname(os.path.realpath(__file__))
|
69 |
|
@@ -72,4 +81,6 @@ if __name__ == '__main__':
|
|
72 |
model.eval()
|
73 |
with torch.no_grad():
|
74 |
input_image = Image.open('images/demo.png')
|
75 |
-
|
|
|
|
|
|
64 |
x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
|
65 |
return TF.adjust_brightness(x, 1.1)
|
66 |
|
67 |
+
def test(self, x):
|
68 |
+
"""
|
69 |
+
x is a tensor
|
70 |
+
"""
|
71 |
+
self.eval()
|
72 |
+
with torch.no_grad():
|
73 |
+
x = self.forward(x)
|
74 |
+
return x
|
75 |
+
|
76 |
if __name__ == '__main__':
|
77 |
current_dir = os.path.dirname(os.path.realpath(__file__))
|
78 |
|
|
|
81 |
model.eval()
|
82 |
with torch.no_grad():
|
83 |
input_image = Image.open('images/demo.png')
|
84 |
+
input_image = ToTensor()(input_image).unsqueeze(0)
|
85 |
+
output_image = model.test(input_image)
|
86 |
+
print(output_image.max())
|