Update app.py
app.py CHANGED
@@ -127,13 +127,13 @@ def inference(image, upscale, large_input_flag, color_fix):
     print(f'input size: {img.shape}')
 
     # img2tensor
-
-
-
+    y = y.astype(np.float32) / 255.
+    y = torch.from_numpy(np.transpose(y[:, :, [2, 1, 0]], (2, 0, 1))).float()
+    y = y.unsqueeze(0).to(device)
 
     # inference
     if large_input_flag:
-        patches, idx, size = img2patch(
+        patches, idx, size = img2patch(y, scale=upscale)
         with torch.no_grad():
             n = len(patches)
             outs = []
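For context on the memory-efficient path: `img2patch` / `patch2img` are the Space's own tiling helpers and their implementation is not part of this diff. The sketch below is only a rough guess at their shape (non-overlapping tiles with reflection padding; the real helpers may overlap tiles or use `scale` differently), so that the call sites above line up with something runnable.

```python
# Hypothetical stand-ins for the repo's img2patch / patch2img helpers (not shown
# in this diff); the actual implementation may differ.
import torch
import torch.nn.functional as F

def img2patch(x, scale=2, patch_size=256):
    # Tile a (1, C, H, W) tensor into non-overlapping patch_size squares.
    # `scale` mirrors the call site in the diff; the tiling itself does not use it.
    _, _, h, w = x.shape
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    # Reflection padding assumes each spatial dim is larger than the padding it
    # needs, which holds for the large inputs this branch targets.
    x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
    _, _, H, W = x.shape
    patches, idx = [], []
    for top in range(0, H, patch_size):
        for left in range(0, W, patch_size):
            patches.append(x[:, :, top:top + patch_size, left:left + patch_size])
            idx.append((top, left))
    return patches, idx, (h, w, H, W)

def patch2img(patches, idx, size, scale=2):
    # Accepts a list of (1, C, p, p) patches or a stacked (N, C, p, p) tensor of
    # upscaled patches; pastes each tile back at scale * its original offset and
    # crops the padding away.
    h, w, H, W = size
    c, p = patches[0].shape[-3], patches[0].shape[-1]
    out = patches[0].new_zeros(1, c, H * scale, W * scale)
    for patch, (top, left) in zip(patches, idx):
        out[:, :, top * scale:top * scale + p, left * scale:left * scale + p] = patch
    return out[:, :, :h * scale, :w * scale]
```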
@@ -153,24 +153,26 @@ def inference(image, upscale, large_input_flag, color_fix):
         output = patch2img(output, idx, size, scale=upscale)
     else:
         with torch.no_grad():
-            output = model(
+            output = model(y)
 
     # color fix
     if color_fix:
-
-        output = wavelet_reconstruction(output,
+        y = F.interpolate(y, scale_factor=upscale, mode='bilinear')
+        output = wavelet_reconstruction(output, y)
     # tensor2img
     output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
     if output.ndim == 3:
         output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
     output = (output * 255.0).round().astype(np.uint8)
+
+    return (img, Image.fromarray(output))
 
-    # save restored img
-    save_path = f'results/out.png'
-    cv2.imwrite(save_path, output)
+    # # save restored img
+    # save_path = f'results/out.png'
+    # cv2.imwrite(save_path, output)
 
-    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
-    return output, save_path
+    # output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
+    # return output, save_path
 
 
 
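The unchanged tensor2img block plus the new `return (img, Image.fromarray(output))` turn the network output back into an RGB PIL image for the slider. Pulled out as a standalone helper (the name `tensor2pil` is mine), that post-processing looks like this:

```python
import numpy as np
import torch
from PIL import Image

def tensor2pil(output: torch.Tensor) -> Image.Image:
    # Mirrors the post-processing in inference(): drop the batch dim, clamp to
    # [0, 1], undo the channel flip applied in the img2tensor step while going
    # CHW -> HWC, then quantise to 8-bit and wrap in a PIL image.
    out = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if out.ndim == 3:
        out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))
    out = (out * 255.0).round().astype(np.uint8)
    return Image.fromarray(out)
```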
@@ -223,10 +225,11 @@ demo = gr.Interface(
     fn=inference,
     inputs=[
         gr.Image(value="real_testdata/004.png", type="pil", label="Input"),
-        gr.Number(minimum=2, maximum=4,
+        gr.Number(minimum=2, maximum=4, default_value=2, label="Upscaling factor (up to 4)"),
         gr.Checkbox(value=False, label="Memory-efficient inference"),
         gr.Checkbox(value=False, label="Color correction"),
     ],
+
     outputs=ImageSlider(label="Super-Resolved Image",
                         type="pil",
                         show_download_button=True,
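With `inference` returning the `(input, restored)` pair, the `ImageSlider` output renders a before/after comparison. A minimal self-contained sketch of that wiring, with a dummy `identity_sr` standing in for the real model (and the initial number passed as `value=`, which is what recent Gradio releases expect):

```python
import gradio as gr
from gradio_imageslider import ImageSlider  # assumes the gradio_imageslider custom component

def identity_sr(image, upscale, large_input_flag, color_fix):
    # Dummy stand-in for inference(): returns the pair shown by the slider;
    # the real app returns (input, super-resolved output).
    return image, image

demo = gr.Interface(
    fn=identity_sr,
    inputs=[
        gr.Image(type="pil", label="Input"),
        gr.Number(value=2, minimum=2, maximum=4, label="Upscaling factor (up to 4)"),
        gr.Checkbox(value=False, label="Memory-efficient inference"),
        gr.Checkbox(value=False, label="Color correction"),
    ],
    outputs=ImageSlider(label="Super-Resolved Image", type="pil"),
)

if __name__ == "__main__":
    demo.launch()
```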