decodingchris commited on
Commit
03b30bd
1 Parent(s): a3e9e30

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import io
3
+ import json
4
+ import numpy
5
+ import os
6
+ import pandas as pd
7
+ import piexif
8
+ import spaces
9
+ import timeit
10
+ import torch
11
+ import torchvision
12
+
13
+ from diffusers import AutoencoderKL, AutoencoderTiny
14
+ from PIL import Image
15
+ from PIL.PngImagePlugin import PngInfo
16
+ from torchvision.io import decode_image
17
+ from torchvision.transforms import v2
18
+
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
21
+ vae = vae.to(device)
22
+
23
+ # Encoding
24
+
25
+ def image_to_latent(image):
26
+ transforms = v2.Compose([
27
+ v2.ToImage(),
28
+ v2.Resize(512),
29
+ v2.ToDtype(torch.float32, scale=True)
30
+ ])
31
+ tensor = transforms(image).unsqueeze(0).to(device) * 2 - 1
32
+ with torch.no_grad():
33
+ encoded_image = vae.encode(tensor)
34
+ return encoded_image.latent_dist.sample()
35
+
36
+ def latent_to_latcomp(latent):
37
+ latent = latent.to(device)
38
+ min_val, max_val = latent.min(), latent.max()
39
+ normalised_latent = (latent - min_val) / (max_val - min_val) * 255
40
+ clamped_latent = normalised_latent.clamp(0, 255).squeeze(0).byte()
41
+ np_latent = clamped_latent.permute(1, 2, 0).cpu().numpy()
42
+ latcomp = Image.fromarray(np_latent, mode="RGBA")
43
+ range_data = { "min_val": min_val.item(), "max_val": max_val.item() }
44
+ json_comment = json.dumps(range_data)
45
+ exif_dict = piexif.load(latcomp.info["exif"]) if "exif" in latcomp.info else {}
46
+ if "Exif" not in exif_dict:
47
+ exif_dict["Exif"] = {}
48
+ exif_dict["Exif"][piexif.ExifIFD.UserComment] = json_comment.encode("utf-16")
49
+ exif_bytes = piexif.dump(exif_dict)
50
+ filepath = "latcomp.webp"
51
+ latcomp.save(filepath, format="WebP", exif=exif_bytes, lossless=True)
52
+ return filepath
53
+
54
+ @spaces.GPU
55
+ def image_to_latcomp(image):
56
+ latent = image_to_latent(image)
57
+ latcomp = latent_to_latcomp(latent)
58
+ return latcomp
59
+
60
+ # Decoding
61
+
62
+ def latcomp_to_latent(latcomp):
63
+ exif_dict = piexif.load(latcomp.info["exif"])
64
+ user_comment = exif_dict.get("Exif", {}).get(piexif.ExifIFD.UserComment)
65
+ user_comment = user_comment.decode("utf-16")
66
+ metadata = json.loads(user_comment)
67
+ min_val = metadata["min_val"]
68
+ max_val = metadata["max_val"]
69
+ latent = v2.PILToTensor()(latcomp).unsqueeze(0).float().to(device)
70
+ denormalised_latent = (latent / 255) * (max_val - min_val) + min_val
71
+ return denormalised_latent
72
+
73
+ def latent_to_image(latent):
74
+ with torch.no_grad():
75
+ decoded_image = vae.decode(latent).sample
76
+ tensor = ((decoded_image + 1) / 2).squeeze(0).clamp(0, 1)
77
+ transforms = v2.Compose([
78
+ v2.ToDtype(torch.uint8, scale=True),
79
+ ])
80
+ int_tensor = transforms(tensor.to(device))
81
+ np_image = int_tensor.permute(1, 2, 0).cpu().numpy()
82
+ image = Image.fromarray(np_image)
83
+ filepath = "image.webp"
84
+ image.save(filepath, format="WebP", lossless=True)
85
+ return filepath
86
+
87
+ @spaces.GPU
88
+ def latcomp_to_image(latcomp):
89
+ latent = latcomp_to_latent(latcomp)
90
+ image = latent_to_image(latent)
91
+ return image
92
+
93
+ # Gradio
94
+
95
+ comparison_data = {
96
+ "Method": ["Size (KB)"],
97
+ "No Compression": [338],
98
+ "LatComp": [11],
99
+ "WebP": [35],
100
+ "JPEG": [66],
101
+ "TinyPNG": [92],
102
+ "PNG": [107],
103
+ "WebP (Lossless)": [214],
104
+ "PNG (Lossless)": [271],
105
+ "ZIP (Lossless)": [338]
106
+ }
107
+
108
+ df = pd.DataFrame(comparison_data)
109
+ styled_df = df.style.background_gradient(subset=['LatComp'], cmap='YlOrRd')
110
+
111
+ with gr.Blocks() as app:
112
+ gr.Markdown("# LatComp (Latent Compression)")
113
+ gr.Markdown()
114
+ gr.Markdown(
115
+ """
116
+ ## LatComp compression uses an AI model (VAE) and some custom code & math to compress images into a small, reversible format.
117
+ """
118
+ )
119
+ gr.Markdown(
120
+ """
121
+ This work was inspired by **Jeremy Howard** and **Jonathan Whitaker** of [fast.ai](https://www.fast.ai/) and [answer.ai](https://www.answer.ai/).<br>
122
+ While taking the fast.ai course, I was learning about **Variational Autoencoders (VAE)** and began to wonder:<br>
123
+ *Is it possible to represent the latent space as an image, and then reconstruct the original image from that representation?*
124
+ """
125
+ )
126
+ gr.Markdown()
127
+ gr.Markdown("### **Compression Comparison:** A 338 KB image compressed using various methods.")
128
+ gr.Dataframe(styled_df)
129
+ gr.Markdown("**Note:** *Lossless compression means the original image can be perfectly reconstructed.*")
130
+ gr.Markdown()
131
+ with gr.Row():
132
+ gr.Markdown(
133
+ """
134
+ ## **Use Cases:**
135
+ - Save storage space
136
+ - Faster file transfers
137
+ - Backups & archives
138
+ """
139
+ )
140
+ gr.Markdown(
141
+ """
142
+ ## **Potential Improvements:**
143
+ - Better/Faster AI model (VAE)
144
+ - Replace custom code & math with an AI model
145
+ - All-in-one AI Model
146
+ """
147
+ )
148
+ gr.Markdown()
149
+ with gr.Tab("Compression"):
150
+ gr.Markdown(
151
+ """
152
+ ## Compress your image into a small and reversible format.
153
+ Images bigger than 512x512 will be resized to reduce GPU memory usage.
154
+ """
155
+ )
156
+ with gr.Row():
157
+ with gr.Column():
158
+ input_image = gr.Image(label="Image", type="pil")
159
+ with gr.Row():
160
+ clear_compress_button = gr.ClearButton()
161
+ compress_button = gr.Button("Compress", variant="primary")
162
+ output_latcomp = gr.Image(label="Latcomp")
163
+ gr.Examples(
164
+ examples=[["macaw.png"], ["flowers.jpg"], ["newyork.jpg"]],
165
+ inputs=input_image,
166
+ outputs=output_latcomp,
167
+ fn=image_to_latcomp,
168
+ run_on_click=True
169
+ )
170
+
171
+ with gr.Tab("Decompression"):
172
+ gr.Markdown("## Get your original image back from a latcomp.")
173
+ with gr.Row():
174
+ with gr.Column():
175
+ input_latcomp = gr.Image(label="Latcomp", type="pil", image_mode="RGBA", sources=["upload", "clipboard"])
176
+ with gr.Row():
177
+ clear_decompress_button = gr.ClearButton()
178
+ decompress_button = gr.Button("Decompress", variant="primary")
179
+ output_image = gr.Image(label="Image")
180
+ gr.Examples(
181
+ examples=[["macaw_latcomp.webp"], ["flowers_latcomp.webp"], ["newyork_latcomp.webp"]],
182
+ inputs=input_latcomp,
183
+ outputs=output_image,
184
+ fn=latcomp_to_image,
185
+ run_on_click=True
186
+ )
187
+
188
+ clear_compress_button.add([input_image, output_latcomp])
189
+ compress_button.click(fn=image_to_latcomp, inputs=input_image, outputs=output_latcomp)
190
+ clear_decompress_button.add([input_latcomp, output_image])
191
+ decompress_button.click(fn=latcomp_to_image, inputs=input_latcomp, outputs=output_image)
192
+
193
+ app.launch()