# Spaces: Running on Zero
import gradio as gr | |
import io | |
import json | |
import numpy | |
import os | |
import pandas as pd | |
import piexif | |
import spaces | |
import timeit | |
import torch | |
import torchvision | |
from diffusers import AutoencoderKL, AutoencoderTiny | |
from PIL import Image | |
from PIL.PngImagePlugin import PngInfo | |
from torchvision.io import decode_image | |
from torchvision.transforms import v2 | |
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained VAE weights and place the model on the selected device.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
# Encoding
def image_to_latent(image):
    """Encode an image into a latent sampled from the VAE posterior.

    The image is converted to a float tensor scaled to [0, 1], passed
    through ``v2.Resize(512)``, given a leading batch dimension and remapped
    to [-1, 1] before being handed to ``vae.encode``.

    Args:
        image: Input accepted by ``v2.ToImage`` (the Gradio UI passes a PIL
            image, or ``None`` when nothing was uploaded).

    Returns:
        A sample drawn from the ``latent_dist`` of the encoder output.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Gradio passes None when the user clicks "Compress" without an image;
    # without this guard the transforms pass None through and the call dies
    # on ``.unsqueeze`` with an opaque AttributeError.
    if image is None:
        raise gr.Error("Please provide an image to compress.")

    transforms = v2.Compose([
        v2.ToImage(),
        v2.Resize(512),
        v2.ToDtype(torch.float32, scale=True),
    ])
    # Remap [0, 1] pixel values to [-1, 1].
    tensor = transforms(image).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        encoded_image = vae.encode(tensor)
    return encoded_image.latent_dist.sample()
def latent_to_latcomp(latent):
    """Quantise a latent to 8 bits and save it as a lossless RGBA WebP.

    The latent is min/max normalised to [0, 255], stored with one latent
    channel per RGBA channel, and written to ``latcomp.webp``. The original
    value range is kept in the EXIF UserComment as UTF-16 encoded JSON with
    the keys ``"min_val"`` and ``"max_val"`` so the values can be rescaled
    when decoding.

    Args:
        latent: Tensor expected to be shaped (1, 4, H, W), as implied by the
            ``squeeze(0)``, ``permute(1, 2, 0)`` and RGBA conversion below.

    Returns:
        The path of the written file (``"latcomp.webp"``).
    """
    latent = latent.to(device)
    min_val, max_val = latent.min().item(), latent.max().item()
    value_range = max_val - min_val

    if value_range == 0:
        # Constant latent: avoid 0/0 (NaN). Every pixel becomes 0, which the
        # decoder maps back to min_val.
        normalised_latent = torch.zeros_like(latent)
    else:
        normalised_latent = (latent - min_val) / value_range * 255

    # Round before the uint8 cast: a bare ``.byte()`` truncates, which
    # doubles the worst-case quantisation error and biases values downward.
    quantised_latent = normalised_latent.round().clamp(0, 255).squeeze(0).byte()
    np_latent = quantised_latent.permute(1, 2, 0).cpu().numpy()
    latcomp = Image.fromarray(np_latent, mode="RGBA")

    # A freshly created image carries no EXIF, so build the block directly.
    json_comment = json.dumps({"min_val": min_val, "max_val": max_val})
    exif_dict = {
        "Exif": {piexif.ExifIFD.UserComment: json_comment.encode("utf-16")},
    }
    exif_bytes = piexif.dump(exif_dict)

    filepath = "latcomp.webp"
    # exact=True keeps the RGB bytes of fully transparent pixels. Here the
    # alpha channel is just the fourth latent channel, so letting the encoder
    # discard "invisible" RGB values would corrupt the stored latent.
    latcomp.save(
        filepath,
        format="WebP",
        exif=exif_bytes,
        lossless=True,
        exact=True,
    )
    return filepath
def image_to_latcomp(image):
    """Compress an image: encode it to a latent, then store that as a latcomp.

    Args:
        image: Input forwarded unchanged to ``image_to_latent``.

    Returns:
        Whatever ``latent_to_latcomp`` returns for the encoded latent.
    """
    return latent_to_latcomp(image_to_latent(image))
# Decoding
def latcomp_to_latent(latcomp):
    """Rebuild the latent tensor stored in a latcomp image.

    The value range is read from the EXIF UserComment, which must hold
    UTF-16 encoded JSON with the keys ``"min_val"`` and ``"max_val"``. The
    8-bit pixel values are then mapped from [0, 255] back to that range.

    Args:
        latcomp: PIL image carrying the EXIF metadata described above (the
            Gradio UI passes ``None`` when nothing was uploaded).

    Returns:
        Float tensor on ``device`` with a leading batch dimension, rescaled
        to [min_val, max_val].

    Raises:
        gr.Error: If no image was given, or the range metadata is missing
            or cannot be parsed.
    """
    if latcomp is None:
        raise gr.Error("Please provide a latcomp image to decompress.")

    exif_bytes = latcomp.info.get("exif")
    if not exif_bytes:
        raise gr.Error(
            "This image has no EXIF metadata, so it is not a valid latcomp."
        )

    try:
        exif_dict = piexif.load(exif_bytes)
        user_comment = exif_dict.get("Exif", {}).get(piexif.ExifIFD.UserComment)
        # user_comment is None when the tag is absent -> AttributeError below.
        metadata = json.loads(user_comment.decode("utf-16"))
        min_val = metadata["min_val"]
        max_val = metadata["max_val"]
    except (AttributeError, KeyError, TypeError, ValueError) as err:
        # ValueError also covers JSON and UTF-16 decode failures.
        raise gr.Error(
            "Could not read the latcomp range metadata from this image."
        ) from err

    latent = v2.PILToTensor()(latcomp).unsqueeze(0).float().to(device)
    # Undo the [0, 255] normalisation applied when the latcomp was written.
    return (latent / 255) * (max_val - min_val) + min_val
def latent_to_image(latent):
    """Decode a latent with the VAE and write the result as a lossless WebP.

    Args:
        latent: Tensor passed straight to ``vae.decode``.

    Returns:
        The path of the written file (``"image.webp"``).
    """
    with torch.no_grad():
        decoded = vae.decode(latent).sample

    # Shift from [-1, 1] to [0, 1], drop the batch dimension and clip.
    unit_range = ((decoded + 1) / 2).squeeze(0).clamp(0, 1)

    to_uint8 = v2.ToDtype(torch.uint8, scale=True)
    pixels = to_uint8(unit_range.to(device))

    # Channels-first -> channels-last for PIL.
    hwc_array = pixels.permute(1, 2, 0).cpu().numpy()

    out_path = "image.webp"
    Image.fromarray(hwc_array).save(out_path, format="WebP", lossless=True)
    return out_path
def latcomp_to_image(latcomp):
    """Decompress a latcomp: recover its latent, then decode it to an image.

    Args:
        latcomp: Input forwarded unchanged to ``latcomp_to_latent``.

    Returns:
        Whatever ``latent_to_image`` returns for the recovered latent.
    """
    return latent_to_image(latcomp_to_latent(latcomp))
# Gradio
comparison_data = { | |
"Method": ["Size (KB)"], | |
"No Compression": [338], | |
"LatComp": [11], | |
"WebP": [35], | |
"JPEG": [66], | |
"TinyPNG": [92], | |
"PNG": [107], | |
"WebP (Lossless)": [214], | |
"PNG (Lossless)": [271], | |
"ZIP (Lossless)": [338] | |
} | |
df = pd.DataFrame(comparison_data) | |
styled_df = df.style.background_gradient(subset=['LatComp'], cmap='YlOrRd') | |
# Markdown copy shown by the UI, kept apart from the layout code below.
_SUMMARY_MD = """
## LatComp compression uses an AI model (VAE) and some custom code & math to compress images into a small, reversible format.
"""

_BACKSTORY_MD = """
This work was inspired by **Jeremy Howard** and **Jonathan Whitaker** of [fast.ai](https://www.fast.ai/) and [answer.ai](https://www.answer.ai/).<br>
While taking the fast.ai course, I was learning about **Variational Autoencoders (VAE)** and began to wonder:<br>
*Is it possible to represent the latent space as an image, and then reconstruct the original image from that representation?*
"""

_USE_CASES_MD = """
## **Use Cases:**
- Save storage space
- Faster file transfers
- Backups & archives
"""

_IMPROVEMENTS_MD = """
## **Potential Improvements:**
- Better/Faster AI model (VAE)
- Replace custom code & math with an AI model
- All-in-one AI Model
"""

_COMPRESSION_HELP_MD = """
## Compress your image into a small and reversible format.
Images bigger than 512x512 will be resized to reduce GPU memory usage.
"""

# Example files offered under each tab.
_COMPRESSION_EXAMPLES = [["macaw.png"], ["flowers.jpg"], ["newyork.jpg"]]
_DECOMPRESSION_EXAMPLES = [
    ["macaw_latcomp.webp"],
    ["flowers_latcomp.webp"],
    ["newyork_latcomp.webp"],
]

# NOTE(review): the nesting below is reconstructed from component order;
# confirm the layout against the running app.
with gr.Blocks() as app:
    # Title and background.
    gr.Markdown("# LatComp (Latent Compression)")
    gr.Markdown()
    gr.Markdown(_SUMMARY_MD)
    gr.Markdown(_BACKSTORY_MD)
    gr.Markdown()

    # Size comparison table.
    gr.Markdown("### **Compression Comparison:** A 338 KB image compressed using various methods.")
    gr.Dataframe(styled_df)
    gr.Markdown("**Note:** *Lossless compression means the original image can be perfectly reconstructed.*")
    gr.Markdown()

    with gr.Row():
        gr.Markdown(_USE_CASES_MD)
        gr.Markdown(_IMPROVEMENTS_MD)
    gr.Markdown()

    # Image -> latcomp.
    with gr.Tab("Compression"):
        gr.Markdown(_COMPRESSION_HELP_MD)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Image", type="pil")
                with gr.Row():
                    clear_compress_button = gr.ClearButton()
                    compress_button = gr.Button("Compress", variant="primary")
            output_latcomp = gr.Image(label="Latcomp")
        gr.Examples(
            examples=_COMPRESSION_EXAMPLES,
            inputs=input_image,
            outputs=output_latcomp,
            fn=image_to_latcomp,
            cache_examples=True,
            cache_mode="eager",
        )

    # Latcomp -> image.
    with gr.Tab("Decompression"):
        gr.Markdown("## Get your original image back from a latcomp.")
        with gr.Row():
            with gr.Column():
                input_latcomp = gr.Image(
                    label="Latcomp",
                    type="pil",
                    image_mode="RGBA",
                    sources=["upload", "clipboard"],
                )
                with gr.Row():
                    clear_decompress_button = gr.ClearButton()
                    decompress_button = gr.Button("Decompress", variant="primary")
            output_image = gr.Image(label="Image")
        gr.Examples(
            examples=_DECOMPRESSION_EXAMPLES,
            inputs=input_latcomp,
            outputs=output_image,
            fn=latcomp_to_image,
            cache_examples=True,
            cache_mode="eager",
        )

    # Event wiring: each clear button resets its tab's input and output, and
    # each primary button runs the matching conversion.
    clear_compress_button.add([input_image, output_latcomp])
    compress_button.click(
        fn=image_to_latcomp,
        inputs=input_image,
        outputs=output_latcomp,
    )
    clear_decompress_button.add([input_latcomp, output_image])
    decompress_button.click(
        fn=latcomp_to_image,
        inputs=input_latcomp,
        outputs=output_image,
    )

app.launch()