Julián Tachella commited on
Commit
da5fdaa
1 Parent(s): f3af5c1
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -9,13 +9,15 @@ def pil_to_torch(image):
9
  image = np.array(image)
10
  image = image.transpose((2, 0, 1))
11
  image = torch.tensor(image).float() / 255
12
- return image.unsqueeze(0)
 
 
13
 
14
 
15
  def torch_to_pil(image):
16
  image = image.squeeze(0).cpu().detach().numpy()
17
  image = image.transpose((1, 2, 0))
18
- image = (image * 255).astype(np.uint8)
19
  image = PIL.Image.fromarray(image)
20
  return image
21
 
@@ -28,6 +30,10 @@ def image_mod(image, noise_level, denoiser):
28
  denoiser = dinv.models.MedianFilter()
29
  elif denoiser == 'BM3D':
30
  denoiser = dinv.models.BM3D()
 
 
 
 
31
  elif denoiser == 'DRUNet':
32
  denoiser = dinv.models.DRUNet()
33
  else:
@@ -44,7 +50,7 @@ input_image_output = gr.Image(label='Input Image')
44
 
45
  noise_levels = gr.Dropdown(choices=[0.1, 0.2, 0.3, 0.4, 0.5], value=0.1, label='Noise Level')
46
 
47
- denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'BM3D', 'MedianFilter'], value='DnCNN', label='Denoiser')
48
 
49
  demo = gr.Interface(
50
  image_mod,
 
9
  image = np.array(image)
10
  image = image.transpose((2, 0, 1))
11
  image = torch.tensor(image).float() / 255
12
+ image = image.unsqueeze(0)
13
+ image = torch.nn.functional.interpolate(image.unsqueeze(0), size=(128, 128*image.shape[3]//image.shape[2]))
14
+ return image
15
 
16
 
17
  def torch_to_pil(image):
18
  image = image.squeeze(0).cpu().detach().numpy()
19
  image = image.transpose((1, 2, 0))
20
+ image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
21
  image = PIL.Image.fromarray(image)
22
  return image
23
 
 
30
  denoiser = dinv.models.MedianFilter()
31
  elif denoiser == 'BM3D':
32
  denoiser = dinv.models.BM3D()
33
+ elif denoiser == 'TV':
34
+ denoiser = dinv.models.TVDenoiser()
35
+ elif denoiser == 'TGV':
36
+ denoiser = dinv.models.TGVDenoiser()
37
  elif denoiser == 'DRUNet':
38
  denoiser = dinv.models.DRUNet()
39
  else:
 
50
 
51
  noise_levels = gr.Dropdown(choices=[0.1, 0.2, 0.3, 0.4, 0.5], value=0.1, label='Noise Level')
52
 
53
+ denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV'], value='DnCNN', label='Denoiser')
54
 
55
  demo = gr.Interface(
56
  image_mod,