import gradio as gr
import io
import json
import numpy
import os
import pandas as pd
import piexif
import spaces
import timeit
import torch
import torchvision
from diffusers import AutoencoderKL, AutoencoderTiny
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from torchvision.io import decode_image
from torchvision.transforms import v2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae = vae.to(device)
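
# The SD VAE maps a 3x512x512 RGB image to a 4x64x64 latent (8x spatial downsampling,
# 4 channels), i.e. roughly 48x fewer values than the raw pixels. Storing that latent
# losslessly is what gives LatComp its compression.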

# Encoding
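
# Preprocess a PIL image (resize the shorter side to 512, scale to [0, 1], then map to
# [-1, 1] as the VAE expects) and return a sampled latent of shape [1, 4, H/8, W/8].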
def image_to_latent(image):
    transforms = v2.Compose([
        v2.ToImage(),
        v2.Resize(512),
        v2.ToDtype(torch.float32, scale=True)
    ])
    tensor = transforms(image).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        encoded_image = vae.encode(tensor)
    return encoded_image.latent_dist.sample()
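
# Pack a latent into an ordinary image file (a "latcomp"): min-max normalise the four
# latent channels to 0-255 bytes, store them as an RGBA image, and stash the original
# min/max as JSON in the EXIF UserComment so the scaling can be undone on decompression.
# Lossless WebP keeps the byte values intact.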
def latent_to_latcomp(latent):
    latent = latent.to(device)
    min_val, max_val = latent.min(), latent.max()
    normalised_latent = (latent - min_val) / (max_val - min_val) * 255
    clamped_latent = normalised_latent.clamp(0, 255).squeeze(0).byte()
    np_latent = clamped_latent.permute(1, 2, 0).cpu().numpy()
    latcomp = Image.fromarray(np_latent, mode="RGBA")
    range_data = {"min_val": min_val.item(), "max_val": max_val.item()}
    json_comment = json.dumps(range_data)
    exif_dict = piexif.load(latcomp.info["exif"]) if "exif" in latcomp.info else {}
    if "Exif" not in exif_dict:
        exif_dict["Exif"] = {}
    exif_dict["Exif"][piexif.ExifIFD.UserComment] = json_comment.encode("utf-16")
    exif_bytes = piexif.dump(exif_dict)
    filepath = "latcomp.webp"
    latcomp.save(filepath, format="WebP", exif=exif_bytes, lossless=True)
    return filepath
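
# Full compression pipeline: image -> latent -> latcomp file on disk.
# Assuming the Space targets ZeroGPU hardware, @spaces.GPU (added here) requests a GPU
# for the duration of the call; it is a no-op on regular GPU/CPU setups.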
@spaces.GPU
def image_to_latcomp(image):
    latent = image_to_latent(image)
    latcomp = latent_to_latcomp(latent)
    return latcomp

# Decoding
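
# Recover the latent from a latcomp: read min/max back out of the EXIF UserComment and
# invert the 0-255 normalisation, i.e. latent ~= pixel / 255 * (max - min) + min.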
def latcomp_to_latent(latcomp):
    if "exif" not in latcomp.info:
        raise gr.Error("No EXIF metadata found. Please upload an unmodified latcomp file.")
    exif_dict = piexif.load(latcomp.info["exif"])
    user_comment = exif_dict.get("Exif", {}).get(piexif.ExifIFD.UserComment)
    user_comment = user_comment.decode("utf-16")
    metadata = json.loads(user_comment)
    min_val = metadata["min_val"]
    max_val = metadata["max_val"]
    latent = v2.PILToTensor()(latcomp).unsqueeze(0).float().to(device)
    denormalised_latent = (latent / 255) * (max_val - min_val) + min_val
    return denormalised_latent
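
# Decode a latent back to pixels with the VAE, rescale from [-1, 1] to [0, 255] uint8,
# and save the reconstruction as a lossless WebP.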
def latent_to_image(latent):
    with torch.no_grad():
        decoded_image = vae.decode(latent).sample
    tensor = ((decoded_image + 1) / 2).squeeze(0).clamp(0, 1)
    transforms = v2.Compose([
        v2.ToDtype(torch.uint8, scale=True),
    ])
    int_tensor = transforms(tensor.to(device))
    np_image = int_tensor.permute(1, 2, 0).cpu().numpy()
    image = Image.fromarray(np_image)
    filepath = "image.webp"
    image.save(filepath, format="WebP", lossless=True)
    return filepath
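
# Full decompression pipeline: latcomp -> latent -> reconstructed image on disk.
# Decorated with @spaces.GPU for ZeroGPU, as above.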
@spaces.GPU
def latcomp_to_image(latcomp):
    latent = latcomp_to_latent(latcomp)
    image = latent_to_image(latent)
    return image

# Gradio
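
# Sizes (KB) for the same 338 KB example image under each method, shown as a one-row
# table whose columns are the methods; the LatComp column gets a colour highlight via
# the Styler below.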
comparison_data = {
    "Method": ["Size (KB)"],
    "No Compression": [338],
    "LatComp": [11],
    "WebP": [35],
    "JPEG": [66],
    "TinyPNG": [92],
    "PNG": [107],
    "WebP (Lossless)": [214],
    "PNG (Lossless)": [271],
    "ZIP (Lossless)": [338]
}
df = pd.DataFrame(comparison_data)
styled_df = df.style.background_gradient(subset=['LatComp'], cmap='YlOrRd')
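
# UI: intro text and the size-comparison table, then two tabs (Compression and
# Decompression), each with an image input, Clear/action buttons, an output image,
# and eagerly cached examples.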
with gr.Blocks() as app:
    gr.Markdown("# LatComp (Latent Compression)")
    gr.Markdown()
    gr.Markdown(
        """
        ## LatComp uses an AI model (a VAE) plus some custom code & math to compress images into a small, reversible format.
        """
    )
    gr.Markdown(
        """
        This work was inspired by **Jeremy Howard** and **Jonathan Whitaker** of [fast.ai](https://www.fast.ai/) and [answer.ai](https://www.answer.ai/).<br>
        While taking the fast.ai course, I was learning about **Variational Autoencoders (VAEs)** and began to wonder:<br>
        *Is it possible to represent the latent space as an image, and then reconstruct the original image from that representation?*
        """
    )
    gr.Markdown()
    gr.Markdown("### **Compression Comparison:** A 338 KB image compressed using various methods.")
    gr.Dataframe(styled_df)
    gr.Markdown("**Note:** *Lossless compression means the original image can be perfectly reconstructed.*")
    gr.Markdown()
    with gr.Row():
        gr.Markdown(
            """
            ## **Use Cases:**
            - Save storage space
            - Faster file transfers
            - Backups & archives
            """
        )
        gr.Markdown(
            """
            ## **Potential Improvements:**
            - Better/faster AI model (VAE)
            - Replace custom code & math with an AI model
            - All-in-one AI model
            """
        )
    gr.Markdown()
| with gr.Tab("Compression"): | |
| gr.Markdown( | |
| """ | |
| ## Compress your image into a small and reversible format. | |
| Images bigger than 512x512 will be resized to reduce GPU memory usage. | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_image = gr.Image(label="Image", type="pil") | |
| with gr.Row(): | |
| clear_compress_button = gr.ClearButton() | |
| compress_button = gr.Button("Compress", variant="primary") | |
| output_latcomp = gr.Image(label="Latcomp") | |
| gr.Examples( | |
| examples=[["macaw.png"], ["flowers.jpg"], ["newyork.jpg"]], | |
| inputs=input_image, | |
| outputs=output_latcomp, | |
| fn=image_to_latcomp, | |
| cache_examples=True, | |
| cache_mode="eager" | |
| ) | |
| with gr.Tab("Decompression"): | |
| gr.Markdown("## Get your original image back from a latcomp.") | |
| with gr.Row(): | |
| with gr.Column(): | |
| input_latcomp = gr.Image(label="Latcomp", type="pil", image_mode="RGBA", sources=["upload", "clipboard"]) | |
| with gr.Row(): | |
| clear_decompress_button = gr.ClearButton() | |
| decompress_button = gr.Button("Decompress", variant="primary") | |
| output_image = gr.Image(label="Image") | |
| gr.Examples( | |
| examples=[["macaw_latcomp.webp"], ["flowers_latcomp.webp"], ["newyork_latcomp.webp"]], | |
| inputs=input_latcomp, | |
| outputs=output_image, | |
| fn=latcomp_to_image, | |
| cache_examples=True, | |
| cache_mode="eager" | |
| ) | |
    clear_compress_button.add([input_image, output_latcomp])
    compress_button.click(fn=image_to_latcomp, inputs=input_image, outputs=output_latcomp)
    clear_decompress_button.add([input_latcomp, output_image])
    decompress_button.click(fn=latcomp_to_image, inputs=input_latcomp, outputs=output_image)

app.launch()