Meloo commited on
Commit
f65651d
1 Parent(s): 6265765

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -127,13 +127,13 @@ def inference(image, upscale, large_input_flag, color_fix):
127
  print(f'input size: {img.shape}')
128
 
129
  # img2tensor
130
- img = img.astype(np.float32) / 255.
131
- img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
132
- img = img.unsqueeze(0).to(device)
133
 
134
  # inference
135
  if large_input_flag:
136
- patches, idx, size = img2patch(img, scale=upscale)
137
  with torch.no_grad():
138
  n = len(patches)
139
  outs = []
@@ -153,24 +153,26 @@ def inference(image, upscale, large_input_flag, color_fix):
153
  output = patch2img(output, idx, size, scale=upscale)
154
  else:
155
  with torch.no_grad():
156
- output = model(img)
157
 
158
  # color fix
159
  if color_fix:
160
- img = F.interpolate(img, scale_factor=upscale, mode='bilinear')
161
- output = wavelet_reconstruction(output, img)
162
  # tensor2img
163
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
164
  if output.ndim == 3:
165
  output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
166
  output = (output * 255.0).round().astype(np.uint8)
 
 
167
 
168
- # save restored img
169
- save_path = f'results/out.png'
170
- cv2.imwrite(save_path, output)
171
 
172
- output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
173
- return output, save_path
174
 
175
 
176
 
@@ -223,10 +225,11 @@ demo = gr.Interface(
223
  fn=inference,
224
  inputs=[
225
  gr.Image(value="real_testdata/004.png", type="pil", label="Input"),
226
- gr.Number(minimum=2, maximum=4, default=2, label="Upscaling factor (up to 4)"),
227
  gr.Checkbox(value=False, label="Memory-efficient inference"),
228
  gr.Checkbox(value=False, label="Color correction"),
229
  ],
 
230
  outputs=ImageSlider(label="Super-Resolved Image",
231
  type="pil",
232
  show_download_button=True,
 
127
  print(f'input size: {img.shape}')
128
 
129
  # img2tensor
130
+ y = y.astype(np.float32) / 255.
131
+ y = torch.from_numpy(np.transpose(y[:, :, [2, 1, 0]], (2, 0, 1))).float()
132
+ y = y.unsqueeze(0).to(device)
133
 
134
  # inference
135
  if large_input_flag:
136
+ patches, idx, size = img2patch(y, scale=upscale)
137
  with torch.no_grad():
138
  n = len(patches)
139
  outs = []
 
153
  output = patch2img(output, idx, size, scale=upscale)
154
  else:
155
  with torch.no_grad():
156
+ output = model(y)
157
 
158
  # color fix
159
  if color_fix:
160
+ y = F.interpolate(y, scale_factor=upscale, mode='bilinear')
161
+ output = wavelet_reconstruction(output, y)
162
  # tensor2img
163
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
164
  if output.ndim == 3:
165
  output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
166
  output = (output * 255.0).round().astype(np.uint8)
167
+
168
+ return (img, Image.fromarray(output))
169
 
170
+ # # save restored img
171
+ # save_path = f'results/out.png'
172
+ # cv2.imwrite(save_path, output)
173
 
174
+ # output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
175
+ # return output, save_path
176
 
177
 
178
 
 
225
  fn=inference,
226
  inputs=[
227
  gr.Image(value="real_testdata/004.png", type="pil", label="Input"),
228
+ gr.Number(minimum=2, maximum=4, default_value=2, label="Upscaling factor (up to 4)"),
229
  gr.Checkbox(value=False, label="Memory-efficient inference"),
230
  gr.Checkbox(value=False, label="Color correction"),
231
  ],
232
+
233
  outputs=ImageSlider(label="Super-Resolved Image",
234
  type="pil",
235
  show_download_button=True,