# denoising / app.py
# Julián Tachella
# test
# 71c2965
# raw
# history blame contribute delete
# No virus
# 3.06 kB
import gradio as gr
import deepinv as dinv
import torch
import numpy as np
import PIL.Image
def pil_to_torch(image, ref_size=512):
    """Convert an image to a resized float tensor of shape (1, C, H, W).

    The input is converted with ``np.array``, moved to channels-first
    layout, divided by 255 (so 8-bit pixel values map to [0, 1]) and given
    a leading batch axis. It is then bilinearly resized so that its longer
    side equals ``ref_size`` while the shorter side follows the aspect
    ratio. When ``ref_size`` is exactly 256 the output is forced to a
    256x256 square instead, which does not preserve the aspect ratio.

    Args:
        image: Anything ``np.array`` accepts that yields an (H, W, C) array,
            or an (H, W) array, which is treated as a single channel.
        ref_size: Target length in pixels of the longer output side.

    Returns:
        A ``torch.Tensor`` of shape (1, C, H_out, W_out).
    """
    pixels = np.array(image)
    if pixels.ndim == 2:
        # A 2-D array has no channel axis; add one so the transpose below
        # is valid instead of raising "axes don't match array".
        pixels = pixels[:, :, None]
    pixels = pixels.transpose((2, 0, 1))
    tensor = torch.tensor(pixels).float() / 255
    tensor = tensor.unsqueeze(0)

    height, width = tensor.shape[2], tensor.shape[3]
    if ref_size == 256:
        size = (ref_size, ref_size)
    elif height > width:
        # Clamp to 1: for very elongated images the integer division
        # floors to 0, which interpolate() rejects.
        size = (ref_size, max(1, ref_size * width // height))
    else:
        size = (max(1, ref_size * height // width), ref_size)

    return torch.nn.functional.interpolate(tensor, size=size, mode='bilinear')
def torch_to_pil(image):
    """Convert a (1, C, H, W) tensor into an 8-bit PIL image.

    The leading axis is squeezed away, the data is moved to the CPU and
    detached from autograd, and the layout is changed to channels-last.
    Values are clipped to [0, 1], scaled by 255 and rounded to the nearest
    integer before the cast to ``uint8``.

    Args:
        image: Tensor whose first axis has size 1, followed by
            (channels, height, width).

    Returns:
        A ``PIL.Image.Image`` built from the resulting ``uint8`` array.
    """
    pixels = image.squeeze(0).cpu().detach().numpy()
    pixels = pixels.transpose((1, 2, 0))
    # Round rather than truncate: a bare astype(uint8) floors toward zero,
    # which darkens every pixel by up to one intensity level.
    pixels = np.rint(np.clip(pixels, 0, 1) * 255).astype(np.uint8)
    if pixels.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the
        # channel axis so a single-channel tensor becomes a grayscale image.
        pixels = pixels[:, :, 0]
    return PIL.Image.fromarray(pixels)
def image_mod(image, noise_level, denoiser):
    """Add Gaussian noise to an image and denoise it with the chosen model.

    The image is converted with ``pil_to_torch`` (reference size 256 for
    'DiffUNet', 512 otherwise), white Gaussian noise with standard
    deviation ``noise_level`` is added, and the selected deepinv denoiser
    is called as ``denoiser(noisy, noise_level)``.

    Args:
        image: Input image as accepted by ``pil_to_torch``.
        noise_level: Standard deviation of the added noise, on the scale
            of the tensor returned by ``pil_to_torch``.
        denoiser: One of 'DnCNN', 'MedianFilter', 'BM3D', 'TV', 'TGV',
            'Wavelets', 'DiffUNet' or 'DRUNet'.

    Returns:
        A ``(noisy, estimated)`` tuple of PIL images.

    Raises:
        ValueError: If no image is given or ``denoiser`` is not a known name.
    """
    if image is None:
        # Fail with a readable message instead of an obscure transpose error.
        raise ValueError("No input image provided")

    clean = pil_to_torch(image, ref_size=256 if denoiser == 'DiffUNet' else 512)

    # Zero-argument factories so that only the selected model is looked up
    # and instantiated.
    factories = {
        'MedianFilter': lambda: dinv.models.MedianFilter(kernel_size=5),
        'BM3D': lambda: dinv.models.BM3D(),
        'TV': lambda: dinv.models.TVDenoiser(),
        'TGV': lambda: dinv.models.TGVDenoiser(),
        'Wavelets': lambda: dinv.models.WaveletPrior(),
        'DiffUNet': lambda: dinv.models.DiffUNet(),
        'DRUNet': lambda: dinv.models.DRUNet(),
    }

    if denoiser == 'DnCNN':
        network = dinv.models.DnCNN()
        # The input is rescaled so its noise level becomes sigma0 and the
        # output is scaled back afterwards.
        # NOTE(review): presumably sigma0 is the noise level the default
        # DnCNN weights were trained for -- confirm against deepinv docs.
        sigma0 = 2/255

        def denoise(x, sigma):
            return network(x*sigma0/sigma)*sigma/sigma0
    elif denoiser in factories:
        denoise = factories[denoiser]()
    else:
        raise ValueError("Invalid denoiser")

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        noisy = clean + torch.randn_like(clean) * noise_level
        estimated = denoise(noisy, noise_level)

    return torch_to_pil(noisy), torch_to_pil(estimated)
# Gradio components for the demo UI.
input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')
# Not wired into the interface below; kept so the module's names are unchanged.
input_image_output = gr.Image(label='Input Image')
noise_levels = gr.Dropdown(
    choices=[0.05, 0.1, 0.2, 0.3, 0.5, 1],
    value=0.1,
    label='Noise Level',
)
denoiser = gr.Dropdown(
    choices=['DnCNN', 'DRUNet', 'DiffUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'],
    value='DRUNet',
    label='Denoiser',
)

# Interface: (image, noise level, denoiser name) -> (noisy image, denoised image).
demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://upload.wikimedia.org/wikipedia/commons/b/b4/Lionel-Messi-Argentina-2022-FIFA-World-Cup_%28cropped%29.jpg', 0.1, 'DRUNet']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    description="Denoise an image using a variety of denoisers and noise levels using the deepinverse library (https://deepinv.github.io/). We only include lightweight models like DnCNN and MedianFilter as this example is intended to be run on a CPU. We also automatically resize the input image to 512 pixels to reduce the computation time. For more advanced models, please run the code locally.",
)

# Start the server only when run as a script, so importing this module
# (for example to reuse `demo` or `image_mod`) has no side effects.
if __name__ == "__main__":
    demo.launch()