File size: 3,060 Bytes
af9162a
52893ae
 
501b516
a983c72
af9162a
ba74db2
71c2965
501b516
 
a983c72
da5fdaa
ece0ce5
71c2965
6cc096a
 
fd548c6
ece0ce5
fd548c6
ece0ce5
 
da5fdaa
501b516
a983c72
501b516
a983c72
501b516
da5fdaa
a983c72
501b516
af9162a
a983c72
ba74db2
71c2965
ba74db2
69590ad
 
866446d
ba74db2
ece0ce5
8a7fe4e
 
da5fdaa
 
 
 
866446d
 
76c7ca0
 
8a7fe4e
 
ba74db2
 
a983c72
a4dc15b
501b516
 
 
dbb94b0
7f74cd7
 
 
501b516
71c2965
a983c72
71c2965
744ad2f
 
 
ba74db2
71c2965
501b516
 
71c2965
744ad2f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import functools

import gradio as gr
import deepinv as dinv
import torch
import numpy as np
import PIL.Image


def pil_to_torch(image, ref_size=512):
    """Convert an image to a resized ``(1, 3, h, w)`` float tensor in [0, 1].

    Args:
        image: PIL image or array-like of shape ``(H, W)`` or ``(H, W, C)``.
            Values are divided by 255, so 8-bit data is assumed. Grayscale
            input is replicated to 3 channels; an alpha channel is dropped.
        ref_size: target length of the longer image side; the aspect ratio is
            preserved. Special case: ``ref_size == 256`` forces a square
            ``256 x 256`` output.

    Returns:
        Float tensor of shape ``(1, 3, h, w)``.

    Raises:
        ValueError: if ``image`` is not 2-D or 3-D once converted to an array.
    """
    array = np.asarray(image)
    if array.ndim == 2:
        array = array[:, :, None]
    if array.ndim != 3:
        raise ValueError(
            f"Expected an image of shape (H, W) or (H, W, C), got shape {array.shape}"
        )
    if array.shape[2] in (2, 4):
        # LA / RGBA: discard the trailing alpha channel.
        array = array[:, :, :-1]
    if array.shape[2] == 1:
        # Grayscale: the denoisers are fed 3-channel input.
        array = np.repeat(array, 3, axis=2)

    # HWC -> CHW, scale to [0, 1], add a batch dimension.
    tensor = (torch.tensor(array.transpose((2, 0, 1))).float() / 255).unsqueeze(0)
    _, _, height, width = tensor.shape

    if ref_size == 256:
        size = (ref_size, ref_size)
    elif height > width:
        # max(1, ...) keeps extremely thin images from collapsing to 0 pixels.
        size = (ref_size, max(1, ref_size * width // height))
    else:
        size = (max(1, ref_size * height // width), ref_size)

    # align_corners=False is PyTorch's default; stating it avoids the legacy warning.
    return torch.nn.functional.interpolate(
        tensor, size=size, mode='bilinear', align_corners=False
    )


def torch_to_pil(image):
    """Convert a ``(1, C, H, W)`` tensor with values in [0, 1] to a PIL image.

    Values outside [0, 1] are clipped. The result is quantized to 8 bits with
    rounding rather than truncation. A single-channel tensor yields a grayscale
    image.

    Args:
        image: torch tensor; a leading batch dimension of size 1 is squeezed.

    Returns:
        ``PIL.Image.Image`` built from the uint8 array.
    """
    # Detach first so no autograd bookkeeping is involved in the device copy.
    array = image.detach().squeeze(0).cpu().numpy()
    array = array.transpose((1, 2, 0))  # CHW -> HWC
    if array.shape[2] == 1:
        # PIL.Image.fromarray may reject (H, W, 1) uint8 arrays; pass (H, W).
        array = array[:, :, 0]
    array = np.rint(np.clip(array, 0, 1) * 255).astype(np.uint8)
    return PIL.Image.fromarray(array)


@functools.lru_cache(maxsize=None)
def _load_denoiser(name):
    """Build the denoiser called ``name`` once and reuse it on later calls.

    The cache key space is the fixed set of names below, so the unbounded
    ``lru_cache`` cannot grow without limit.

    Returns:
        A callable ``f(x, sigma)`` mapping a noisy batch and a noise level to
        an estimate.

    Raises:
        ValueError: if ``name`` is not a supported denoiser.
    """
    # An if/elif chain (rather than a dict of class references) means a model
    # missing from the installed deepinv version only breaks that one choice.
    if name == 'DnCNN':
        den = dinv.models.DnCNN()
        # The input is rescaled so its noise std becomes sigma0 before calling
        # DnCNN, then scaled back. Presumably sigma0 is the level the default
        # weights were trained for — NOTE(review): confirm against deepinv docs.
        sigma0 = 2 / 255
        return lambda x, sigma: den(x * sigma0 / sigma) * sigma / sigma0
    if name == 'MedianFilter':
        return dinv.models.MedianFilter(kernel_size=5)
    if name == 'BM3D':
        return dinv.models.BM3D()
    if name == 'TV':
        return dinv.models.TVDenoiser()
    if name == 'TGV':
        return dinv.models.TGVDenoiser()
    if name == 'Wavelets':
        return dinv.models.WaveletPrior()
    if name == 'DiffUNet':
        return dinv.models.DiffUNet()
    if name == 'DRUNet':
        return dinv.models.DRUNet()
    raise ValueError("Invalid denoiser")


def image_mod(image, noise_level, denoiser):
    """Add Gaussian noise to ``image`` and denoise it with the chosen model.

    Args:
        image: input image from the Gradio component (may be None when the
            user submits without uploading anything).
        noise_level: standard deviation of the additive Gaussian noise, on the
            [0, 1] intensity scale produced by ``pil_to_torch``.
        denoiser: name of the denoiser (see ``_load_denoiser``).

    Returns:
        Tuple ``(noisy, estimated)`` of PIL images.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
        ValueError: if ``denoiser`` is not a supported name.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    noise_level = float(noise_level)
    model = _load_denoiser(denoiser)
    # DiffUNet gets a square 256x256 input (see pil_to_torch); others use 512.
    clean = pil_to_torch(image, ref_size=256 if denoiser == 'DiffUNet' else 512)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        noisy = clean + torch.randn_like(clean) * noise_level
        estimated = model(noisy, noise_level)
    return torch_to_pil(noisy), torch_to_pil(estimated)


# Gradio components wired into the interface below.
input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')

# Std of the added Gaussian noise, on the [0, 1] pixel scale used by image_mod.
noise_levels = gr.Dropdown(choices=[0.05, 0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level')

# Names must match those accepted by image_mod.
denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'DiffUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'], value='DRUNet', label='Denoiser')

demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://upload.wikimedia.org/wikipedia/commons/b/b4/Lionel-Messi-Argentina-2022-FIFA-World-Cup_%28cropped%29.jpg', 0.1, 'DRUNet']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    # Kept in sync with pil_to_torch: longer side resized to 512, DiffUNet to 256x256.
    description=(
        "Denoise an image using a variety of denoisers and noise levels using the "
        "deepinverse library (https://deepinv.github.io/). This example is intended "
        "to be run on a CPU, so the input image is automatically resized so that its "
        "longer side is 512 pixels (256x256 for DiffUNet) to reduce the computation "
        "time. For more advanced models, please run the code locally."
    ),
)

# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()