Julián Tachella commited on
Commit
6cc096a
1 Parent(s): f05a019
Files changed (2) hide show
  1. app.py +6 -5
  2. requirements.txt +1 -1
app.py CHANGED
@@ -5,14 +5,15 @@ import numpy as np
5
  import PIL.Image
6
 
7
 
8
- def pil_to_torch(image):
9
  image = np.array(image)
10
  image = image.transpose((2, 0, 1))
11
  image = torch.tensor(image).float() / 255
12
  image = image.unsqueeze(0)
13
 
14
- ref_size = 256
15
- if image.shape[2] > image.shape[3]:
 
16
  size = (ref_size, ref_size * image.shape[3]//image.shape[2])
17
  else:
18
  size = (ref_size * image.shape[2]//image.shape[3], ref_size)
@@ -30,7 +31,7 @@ def torch_to_pil(image):
30
 
31
 
32
  def image_mod(image, noise_level, denoiser):
33
- image = pil_to_torch(image)
34
  if denoiser == 'DnCNN':
35
  den = dinv.models.DnCNN()
36
  sigma0 = 2/255
@@ -46,7 +47,7 @@ def image_mod(image, noise_level, denoiser):
46
  elif denoiser == 'Wavelets':
47
  denoiser = dinv.models.WaveletPrior()
48
  elif denoiser == 'SwinIR':
49
- denoiser = dinv.models.SwinIR(img_size=256)
50
  elif denoiser == 'DRUNet':
51
  denoiser = dinv.models.DRUNet()
52
  else:
 
5
  import PIL.Image
6
 
7
 
8
+ def pil_to_torch(image, ref_size=256):
9
  image = np.array(image)
10
  image = image.transpose((2, 0, 1))
11
  image = torch.tensor(image).float() / 255
12
  image = image.unsqueeze(0)
13
 
14
+ if ref_size == 128:
15
+ size = (ref_size, ref_size)
16
+ elif image.shape[2] > image.shape[3]:
17
  size = (ref_size, ref_size * image.shape[3]//image.shape[2])
18
  else:
19
  size = (ref_size * image.shape[2]//image.shape[3], ref_size)
 
31
 
32
 
33
  def image_mod(image, noise_level, denoiser):
34
+ image = pil_to_torch(image, ref_size=128 if denoiser == 'SwinIR' else 256)
35
  if denoiser == 'DnCNN':
36
  den = dinv.models.DnCNN()
37
  sigma0 = 2/255
 
47
  elif denoiser == 'Wavelets':
48
  denoiser = dinv.models.WaveletPrior()
49
  elif denoiser == 'SwinIR':
50
+ denoiser = dinv.models.SwinIR()
51
  elif denoiser == 'DRUNet':
52
  denoiser = dinv.models.DRUNet()
53
  else:
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  deepinv
2
  bm3d
3
  timm
4
- PyWavelets
 
1
  deepinv
2
  bm3d
3
  timm
4
+ ptwt