# denoising / app.py
# Author: Julián Tachella
# Last commit: "test" (71c2965)
import gradio as gr
import deepinv as dinv
import torch
import numpy as np
import PIL.Image
def pil_to_torch(image, ref_size=512, square=None):
    """Convert an image to a ``(1, 3, H, W)`` float tensor in [0, 1] and resize it.

    Args:
        image: A PIL image or an ``(H, W)`` / ``(H, W, C)`` uint8 array-like
            (anything ``np.array`` accepts). Inputs with 1 or 2 channels
            (grayscale, grayscale + alpha) have their first channel replicated
            to 3 channels. Channels beyond the third (e.g. alpha) are dropped.
        ref_size: Target size in pixels. In the default (non-square) mode the
            longest side becomes ``ref_size`` and the aspect ratio is kept.
        square: If true, resize to ``ref_size x ref_size`` regardless of the
            aspect ratio. ``None`` (default) keeps the historical rule of
            squaring only when ``ref_size == 256``.

    Returns:
        A float32 tensor of shape ``(1, 3, H, W)`` with values in [0, 1].
    """
    array = np.array(image)
    if array.ndim == 2:
        # Plain grayscale (H, W): add an explicit channel axis.
        array = array[:, :, None]
    channels = array.shape[2]
    if channels < 3:
        # Grayscale (+ alpha): replicate the first channel so the result
        # always has 3 channels, like an RGB input.
        array = np.repeat(array[:, :, :1], 3, axis=2)
    elif channels > 3:
        # Drop alpha / any extra channels.
        array = array[:, :, :3]

    # HWC uint8 -> 1CHW float in [0, 1].
    tensor = torch.tensor(array.transpose((2, 0, 1))).float() / 255
    tensor = tensor.unsqueeze(0)

    height, width = tensor.shape[2], tensor.shape[3]
    if square is None:
        square = ref_size == 256
    if square:
        size = (ref_size, ref_size)
    elif height > width:
        # max(1, ...) keeps very elongated images from collapsing to 0 pixels.
        size = (ref_size, max(1, ref_size * width // height))
    else:
        size = (max(1, ref_size * height // width), ref_size)
    # antialias=True avoids aliasing artifacts when large photos are downsampled.
    return torch.nn.functional.interpolate(
        tensor, size=size, mode='bilinear', align_corners=False, antialias=True
    )
def torch_to_pil(image):
    """Convert a ``(1, C, H, W)`` tensor with values in [0, 1] to a PIL image.

    Values outside [0, 1] are clipped. A single-channel tensor becomes a 2-D
    (grayscale) image, because PIL cannot build an image from an
    ``(H, W, 1)`` array.

    Args:
        image: Tensor of shape ``(1, C, H, W)``; may require grad or live on
            any device.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 version of ``image``.
    """
    array = image.detach().cpu().squeeze(0).numpy()
    array = array.transpose((1, 2, 0))  # CHW -> HWC
    if array.shape[2] == 1:
        array = array[:, :, 0]
    # Round rather than truncate: truncation biases every pixel downward and
    # can turn a value k/255 (as produced by pil_to_torch) back into k - 1.
    array = np.round(np.clip(array, 0, 1) * 255).astype(np.uint8)
    return PIL.Image.fromarray(array)
def image_mod(image, noise_level, denoiser):
    """Corrupt an image with Gaussian noise and denoise it with the chosen model.

    Args:
        image: Input image as delivered by the ``gr.Image`` component; it is
            converted with ``pil_to_torch``. ``None`` (no upload) is rejected.
        noise_level: Standard deviation of the additive Gaussian noise, on the
            same [0, 1] scale as the image tensor. Coerced with ``float`` so a
            string value coming from the dropdown is also accepted.
        denoiser: Name of the denoiser: 'DnCNN', 'MedianFilter', 'BM3D', 'TV',
            'TGV', 'Wavelets', 'DiffUNet' or 'DRUNet'.

    Returns:
        A ``(noisy, estimated)`` pair of PIL images.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If ``denoiser`` is not one of the names above.
    """
    def _dncnn():
        # Rescale the input so that noise of std ``sigma`` reaches the network
        # as noise of std ``sigma0``, then undo the scaling on the output.
        # NOTE(review): presumably sigma0 = 2/255 is the noise level the
        # pretrained DnCNN weights target -- confirm against deepinv.
        den = dinv.models.DnCNN()
        sigma0 = 2 / 255
        return lambda x, sigma: den(x * sigma0 / sigma) * sigma / sigma0

    # Lazy constructors: only the selected model's attribute on dinv.models is
    # looked up and only that model is instantiated.
    constructors = {
        'DnCNN': _dncnn,
        'MedianFilter': lambda: dinv.models.MedianFilter(kernel_size=5),
        'BM3D': lambda: dinv.models.BM3D(),
        'TV': lambda: dinv.models.TVDenoiser(),
        'TGV': lambda: dinv.models.TGVDenoiser(),
        'Wavelets': lambda: dinv.models.WaveletPrior(),
        'DiffUNet': lambda: dinv.models.DiffUNet(),
        'DRUNet': lambda: dinv.models.DRUNet(),
    }

    if image is None:
        # Report a missing upload to the user instead of crashing in numpy.
        raise gr.Error('Please upload an image.')
    if denoiser not in constructors:
        raise ValueError("Invalid denoiser")
    noise_level = float(noise_level)

    # DiffUNet gets a 256x256 input; every other model an image whose longest
    # side is 512 pixels.
    image = pil_to_torch(image, ref_size=256 if denoiser == 'DiffUNet' else 512)
    model = constructors[denoiser]()

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        noisy = image + torch.randn_like(image) * noise_level
        estimated = model(noisy, noise_level)
    return torch_to_pil(noisy), torch_to_pil(estimated)
# Gradio UI: upload an image, pick a noise level and a denoiser, and show the
# noisy image next to the denoised estimate produced by image_mod.
input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')
noise_levels = gr.Dropdown(choices=[0.05, 0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level')
denoiser = gr.Dropdown(
    choices=['DnCNN', 'DRUNet', 'DiffUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'],
    value='DRUNet',
    label='Denoiser',
)
demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://upload.wikimedia.org/wikipedia/commons/b/b4/Lionel-Messi-Argentina-2022-FIFA-World-Cup_%28cropped%29.jpg', 0.1, 'DRUNet']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    # Keep this text in sync with image_mod / pil_to_torch (model list, resize rule).
    description=(
        "Denoise an image using a variety of denoisers and noise levels using the "
        "deepinverse library (https://deepinv.github.io/). The available methods range "
        "from classical denoisers (MedianFilter, TV, TGV, Wavelets, BM3D) to learned ones "
        "(DnCNN, DRUNet, DiffUNet). As this example is intended to be run on a CPU, the "
        "input image is automatically resized to 512 pixels on its longest side "
        "(256x256 for DiffUNet) to reduce the computation time. For more advanced "
        "models, please run the code locally."
    ),
)

if __name__ == "__main__":
    demo.launch()