zca-whitening / app.py
mischeiwiller's picture
fix: fix link to zca docs
4a862ec verified
raw
history blame
No virus
2.35 kB
import gradio as gr
import torch
import kornia as K
from kornia.geometry.transform import resize
import cv2
import numpy as np
from torchvision import transforms
from torchvision.utils import make_grid
from PIL import Image
# Compute device for the ZCA fit/transform: the first CUDA GPU when PyTorch
# can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def read_image(img, size=(50, 50)):
    """Load one input image and return it as a resized RGB tensor.

    Args:
        img: the image in any of these forms:
            - a NumPy array;
            - a ``PIL.Image.Image``;
            - a file path (``str`` or ``os.PathLike``), which is what the
              examples list and newer Gradio ``gr.File`` components provide;
            - a file wrapper exposing its path via ``.name``, like the
              tempfile objects older Gradio versions pass for ``gr.File``.
        size: output ``(height, width)`` passed to kornia's ``resize``.
            Defaults to ``(50, 50)``, the size previously hard-coded here.

    Returns:
        torch.Tensor: tensor of shape ``(3, *size)`` produced by
        ``transforms.ToTensor`` (scaled to [0, 1] for 8-bit images).

    Raises:
        TypeError: if ``img`` is none of the supported forms.
    """
    import os

    if isinstance(img, np.ndarray):
        pil_img = Image.fromarray(img).convert("RGB")
    elif isinstance(img, Image.Image):
        pil_img = img.convert("RGB")
    else:
        # Check str/PathLike before ``.name``: pathlib.Path also has a
        # ``.name`` attribute, but it holds only the basename.
        if isinstance(img, (str, os.PathLike)):
            path = os.fspath(img)
        elif hasattr(img, "name"):
            path = img.name
        else:
            raise TypeError(f"Unsupported image input type: {type(img).__name__}")
        # convert() fully loads the pixels, so the image stays usable after
        # the context manager closes the underlying file.
        with Image.open(path) as opened:
            pil_img = opened.convert("RGB")

    # Forcing RGB gives every image 3 channels, so grayscale/RGBA uploads can
    # still be stacked into one batch by the caller.
    img_tensor = transforms.ToTensor()(pil_img)
    return resize(img_tensor.unsqueeze(0), size).squeeze(0)
def predict(images, eps):
    """Fit ZCA whitening on the uploaded batch and return a grid of results.

    Args:
        images: sequence of uploaded images; each item is passed to
            ``read_image``. May be ``None`` when nothing was uploaded.
        eps: value forwarded (as ``float``) to ``K.enhance.ZCAWhitening`` as
            its ``eps`` numerical-stability term.

    Returns:
        numpy.ndarray: the whitened images tiled by ``make_grid`` (3 per row,
        min-max normalised to [0, 1]) and transposed from CHW to HWC for
        display.

    Raises:
        gr.Error: if fewer than 2 images were provided. The whitening
            transform is estimated from the covariance across the batch, which
            is degenerate for a single sample. The app's description already
            states this requirement.
    """
    if images is None or len(images) < 2:
        raise gr.Error("Please upload at least 2 images to apply ZCA whitening.")

    eps = float(eps)
    batch = torch.stack([read_image(img) for img in images], dim=0).to(device)

    # Only the forward transform is used, so the inverse transform is not
    # computed. compute_inv defaults to False, which avoids an extra large
    # matrix product over the flattened features.
    zca = K.enhance.ZCAWhitening(eps=eps)
    zca.fit(batch)
    whitened = zca(batch)

    grid = make_grid(whitened, nrow=3, normalize=True).cpu().numpy()
    return np.transpose(grid, (1, 2, 0))
# Heading text; also passed to gr.Blocks(title=...) below.
title = "ZCA Whitening with Kornia!"

# Markdown rendered under the heading: what ZCA whitening is, the upload
# requirements, and a link to kornia's implementation.
description = "\n".join(
    [
        "[ZCA Whitening](https://paperswithcode.com/method/zca-whitening) is an image preprocessing method that leads to a transformation of data such that the covariance matrix is the identity matrix, leading to decorrelated features:",
        "*Note that you can upload only image files, e.g. jpg, png etc and there should be at least 2 images!*",
        "Learn more about [ZCA Whitening and Kornia](https://kornia.readthedocs.io/en/v0.6.4/_modules/kornia/enhance/zca.html)",
    ]
)
# Gradio UI: an uploader and epsilon slider feed predict(); its returned
# grid is displayed in a single output image.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Row():
        # Multiple files are accepted because predict() fits ZCA on the
        # whole uploaded batch.
        input_images = gr.File(file_count="multiple", label="Input Images")
        # Forwarded to predict() as the eps argument of ZCAWhitening.
        eps_slider = gr.Slider(minimum=0.01, maximum=1, value=0.01, label="Epsilon")
    output_image = gr.Image(label="ZCA Whitened Images")
    submit_button = gr.Button("Apply ZCA Whitening")
    submit_button.click(fn=predict, inputs=[input_images, eps_slider], outputs=output_image)
    # A single example row: ten flower image files plus eps=0.01. The file
    # names are presumably resolved relative to the app's working directory;
    # confirm they ship alongside app.py.
    gr.Examples(
        examples=[
            [
                ['irises.jpg', 'roses.jpg', 'sunflower.jpg', 'violets.jpg', 'chamomile.jpg',
                 'tulips.jpg', 'Alstroemeria.jpg', 'Carnation.jpg', 'Orchid.jpg', 'Peony.jpg'],
                0.01
            ]
        ],
        inputs=[input_images, eps_slider],
    )
if __name__ == "__main__":
    # show_error=True makes Gradio display exceptions raised in predict()
    # in the browser instead of only logging them server-side.
    demo.launch(show_error=True)