mischeiwiller's picture
fix: handle both numpy arrays and file paths in inference function
bdf18f8 verified
raw
history blame
2.57 kB
import gradio as gr
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import kornia as K
class _TVDenoise(torch.nn.Module):
    """Learnable image balancing fidelity to a noisy input against total variation.

    The optimised parameter ``clean_image`` starts as a copy of ``noisy_image``;
    ``forward`` returns ``MSE(clean, noisy) + tv_weight * TV(clean)``.
    """

    def __init__(self, noisy_image, tv_weight=1e-4):
        super().__init__()
        self.l2_term = torch.nn.MSELoss(reduction='mean')
        self.regularization_term = K.losses.TotalVariation()
        self.tv_weight = tv_weight
        # The variable that is optimised to produce the noise-free image.
        self.clean_image = torch.nn.Parameter(data=noisy_image.clone(), requires_grad=True)
        self.noisy_image = noisy_image

    def forward(self):
        # NOTE(review): kornia's TotalVariation maps (*, H, W) -> (*,), so for a
        # (C, H, W) image this sum is presumably per-channel; the caller reduces it
        # to a scalar with torch.mean before backward().
        return (self.l2_term(self.clean_image, self.noisy_image)
                + self.tv_weight * self.regularization_term(self.clean_image))

    def get_clean_image(self):
        """Return the parameter holding the current denoised estimate."""
        return self.clean_image


def _load_rgb_image(file1):
    """Return *file1* as an RGB image array, reading it with OpenCV if it is a path.

    Raises:
        ValueError: if *file1* is None or the file cannot be decoded as an image.
    """
    if file1 is None:
        raise ValueError("No input image was provided.")
    if isinstance(file1, np.ndarray):
        # Already decoded (e.g. by gr.Image(type='numpy')); assumed to be RGB.
        return file1
    # cv2.imread returns None instead of raising when the file is missing/unreadable.
    bgr = cv2.imread(str(file1))
    if bgr is None:
        raise ValueError(f"Could not read image file: {file1!r}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def inference(file1, num_iters, noise_std=0.1, tv_weight=1e-4):
    """Corrupt an image with Gaussian noise, then denoise it by TV-regularised SGD.

    Args:
        file1: RGB image as a numpy array (as produced by ``gr.Image(type='numpy')``)
            or a path to an image file readable by OpenCV. Pixel values are assumed
            to be 8-bit (0-255); they are divided by 255.
        num_iters: number of SGD steps; converted with ``int``.
        noise_std: standard deviation of the additive Gaussian noise, in [0, 1]
            intensity units.
        tv_weight: weight of the total-variation term relative to the MSE term.

    Returns:
        ``(noisy, clean)``: float32 images clipped to [0, 1]. ``noisy`` keeps the
        input's H x W (x C) layout; ``clean`` is the optimised estimate converted
        back with ``kornia.utils.tensor_to_image``.

    Raises:
        ValueError: if no image is given or the image file cannot be read.
    """
    img = _load_rgb_image(file1)
    img = img.astype(np.float32) / 255.0
    # Cast the noise to float32 so the tensors (and the optimisation) stay float32
    # instead of being silently promoted to float64.
    noise = np.random.normal(loc=0.0, scale=noise_std, size=img.shape).astype(np.float32)
    img = np.clip(img + noise, 0.0, 1.0)

    # Convert to a torch tensor; squeeze drops a singleton channel for grayscale input.
    noisy_image = K.utils.image_to_tensor(img).squeeze()

    tv_denoiser = _TVDenoise(noisy_image, tv_weight=tv_weight)
    # The only parameter being optimised is tv_denoiser.clean_image.
    optimizer = torch.optim.SGD(tv_denoiser.parameters(), lr=0.1, momentum=0.9)

    for i in range(int(num_iters)):
        optimizer.zero_grad()
        loss = torch.mean(tv_denoiser())
        if i % 50 == 0:
            print("Loss in iteration {} of {}: {:.3f}".format(i, num_iters, loss.item()))
        loss.backward()
        optimizer.step()

    img_clean: np.ndarray = K.utils.tensor_to_image(tv_denoiser.get_clean_image())
    # Momentum SGD may leave values slightly outside [0, 1]; clamp so the result
    # is a valid float image for display.
    img_clean = np.clip(img_clean, 0.0, 1.0)
    return img, img_clean
# Example row shown under the demo: (input image, number of SGD iterations).
examples = [["doraemon.png", 2000]]

# Input widgets, created in the order the inference function receives them.
image_in = gr.Image(type='numpy', label='Input Image')
iters_in = gr.Slider(minimum=50, maximum=10000, step=50, value=500, label="num_iters")
inputs = [image_in, iters_in]

# Output widgets: the corrupted image followed by the optimised result.
noisy_out = gr.Image(type='numpy', label='Noised Image')
clean_out = gr.Image(type='numpy', label='Denoised Image')
outputs = [noisy_out, clean_out]

title = "Denoise image using total variation"

# NOTE(review): newer Gradio releases do not ship a built-in 'huggingface' theme
# and fall back to the default after trying to fetch it from the Hub — confirm
# against the pinned Gradio version.
demo_app = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    theme='huggingface',
)

# debug=True keeps the process attached to the server and surfaces errors in the console.
demo_app.launch(debug=True)