File size: 6,362 Bytes
03b30bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d6b90a
 
03b30bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d6b90a
 
03b30bd
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import gradio as gr
import io
import json
import numpy
import os
import pandas as pd
import piexif
import spaces
import timeit
import torch
import torchvision

from diffusers import AutoencoderKL, AutoencoderTiny
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from torchvision.io import decode_image
from torchvision.transforms import v2

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")

# Load the pretrained VAE once at import time and place it on `device`;
# every encode/decode function below shares this single instance.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)

# Encoding

def image_to_latent(image):
  """Encode an image with the module-level VAE and return a sampled latent.

  The image is converted to a tensor, resized with ``v2.Resize(512)``
  (torchvision scales the *shorter* edge to 512 and keeps the aspect ratio,
  whatever the input size), converted to float32 in [0, 1] and finally
  rescaled to [-1, 1] before being passed to ``vae.encode``.

  Args:
    image: The picture to encode; anything ``v2.ToImage`` accepts
      (the Gradio input hands over a PIL image).

  Returns:
    A latent drawn with ``latent_dist.sample()``. Sampling is random, so two
    calls on the same image are not guaranteed to return identical latents.
  """
  preprocess = v2.Compose([
    v2.ToImage(),
    v2.Resize(512),
    v2.ToDtype(torch.float32, scale=True)
  ])
  pixels = preprocess(image)
  # Add the batch axis, move to the model's device, then map [0, 1] -> [-1, 1].
  batch = pixels.unsqueeze(0).to(device)
  batch = batch * 2 - 1
  with torch.no_grad():
    posterior = vae.encode(batch).latent_dist
  return posterior.sample()

def latent_to_latcomp(latent):
  """Quantise a latent to 8 bits per value and store it as a lossless WebP.

  The latent is min-max normalised to [0, 255], rounded to the nearest
  integer and written as an RGBA image (one latent channel per image
  channel). The original ``min_val`` / ``max_val`` are stored as JSON in the
  EXIF UserComment tag (UTF-16 encoded) so that ``latcomp_to_latent`` can
  undo the normalisation.

  Args:
    latent: Latent tensor; assumed to have shape (1, 4, H, W), because the
      batch axis is squeezed away and the result is saved in "RGBA" mode.

  Returns:
    The path of the written file, always ``"latcomp.webp"`` in the current
    working directory (each call overwrites the previous file).
  """
  latent = latent.to(device)
  min_val, max_val = latent.min(), latent.max()
  value_range = max_val - min_val
  if value_range > 0:
    normalised_latent = (latent - min_val) / value_range * 255
  else:
    # Constant latent: dividing by a zero range would produce NaN. Store
    # zeros instead; decoding maps every value back to min_val, which is
    # exactly the constant.
    normalised_latent = torch.zeros_like(latent)
  # Round to the nearest level before the uint8 cast. A bare .byte()
  # truncates, which doubles the worst-case quantisation error and biases
  # every value downwards.
  quantised_latent = normalised_latent.round().clamp(0, 255).squeeze(0).byte()
  # (C, H, W) -> (H, W, C), the layout PIL expects.
  np_latent = quantised_latent.permute(1, 2, 0).cpu().numpy()
  latcomp = Image.fromarray(np_latent, mode="RGBA")
  range_data = { "min_val": min_val.item(), "max_val": max_val.item() }
  json_comment = json.dumps(range_data)
  # The image was created just above and carries no EXIF of its own, so the
  # EXIF block holds only the range metadata.
  exif_dict = {
    "Exif": { piexif.ExifIFD.UserComment: json_comment.encode("utf-16") }
  }
  exif_bytes = piexif.dump(exif_dict)
  filepath = "latcomp.webp"
  latcomp.save(filepath, format="WebP", exif=exif_bytes, lossless=True)
  return filepath

@spaces.GPU
def image_to_latcomp(image):
  """Compress an image into a latcomp file and return the file's path.

  Runs the two encoding steps back to back: ``image_to_latent`` followed by
  ``latent_to_latcomp``.

  Args:
    image: The picture to compress, passed straight to ``image_to_latent``.

  Returns:
    Whatever ``latent_to_latcomp`` returns (the path of the written WebP).
  """
  return latent_to_latcomp(image_to_latent(image))

# Decoding

def latcomp_to_latent(latcomp):
  """Rebuild a latent tensor from a latcomp image.

  Reads the ``min_val`` / ``max_val`` pair stored as UTF-16 JSON in the EXIF
  UserComment tag and uses it to map the 8-bit pixel values back to the
  original latent value range.

  Args:
    latcomp: PIL image produced by ``latent_to_latcomp``. Its EXIF bytes are
      read from ``latcomp.info["exif"]``.

  Returns:
    A float tensor on ``device`` with a leading batch axis of size 1.

  Raises:
    gr.Error: If no image was supplied, or the image lacks the EXIF range
      metadata (or carries malformed metadata), i.e. it is not a latcomp.
      Raising ``gr.Error`` lets the Gradio UI show the reason to the user
      instead of an opaque KeyError / AttributeError.
  """
  if latcomp is None:
    raise gr.Error("Please provide a latcomp image to decompress.")
  exif_bytes = latcomp.info.get("exif")
  if not exif_bytes:
    raise gr.Error("This image has no EXIF metadata, so it is not a latcomp.")
  exif_dict = piexif.load(exif_bytes)
  user_comment = exif_dict.get("Exif", {}).get(piexif.ExifIFD.UserComment)
  if user_comment is None:
    raise gr.Error("This image has no latcomp range metadata in its EXIF data.")
  try:
    metadata = json.loads(user_comment.decode("utf-16"))
    min_val = metadata["min_val"]
    max_val = metadata["max_val"]
  except (ValueError, KeyError, TypeError) as err:
    # ValueError covers both UnicodeDecodeError and json.JSONDecodeError;
    # KeyError / TypeError cover JSON that lacks the expected keys or shape.
    raise gr.Error("The latcomp range metadata is missing or malformed.") from err
  latent = v2.PILToTensor()(latcomp).unsqueeze(0).float().to(device)
  # Inverse of the min-max normalisation applied when the latcomp was written.
  denormalised_latent = (latent / 255) * (max_val - min_val) + min_val
  return denormalised_latent

def latent_to_image(latent):
  """Decode a latent with the module-level VAE and save it as a lossless WebP.

  Args:
    latent: Latent tensor passed directly to ``vae.decode``.

  Returns:
    The path of the written file, always ``"image.webp"`` in the current
    working directory (each call overwrites the previous file).
  """
  with torch.no_grad():
    decoded = vae.decode(latent).sample
  # (x + 1) / 2 undoes the `* 2 - 1` scaling applied before encoding; the
  # clamp keeps the result inside [0, 1]. squeeze(0) drops the batch axis.
  unit_range = ((decoded + 1) / 2).squeeze(0).clamp(0, 1)
  to_uint8 = v2.ToDtype(torch.uint8, scale=True)
  pixels = to_uint8(unit_range.to(device))
  # (C, H, W) -> (H, W, C), the layout PIL expects.
  hwc_array = pixels.permute(1, 2, 0).cpu().numpy()
  filepath = "image.webp"
  Image.fromarray(hwc_array).save(filepath, format="WebP", lossless=True)
  return filepath

@spaces.GPU
def latcomp_to_image(latcomp):
  """Decompress a latcomp image and return the path of the restored picture.

  Runs the two decoding steps back to back: ``latcomp_to_latent`` followed
  by ``latent_to_image``.

  Args:
    latcomp: The latcomp image, passed straight to ``latcomp_to_latent``.

  Returns:
    Whatever ``latent_to_image`` returns (the path of the written WebP).
  """
  return latent_to_image(latcomp_to_latent(latcomp))

# Gradio

comparison_data = {
    "Method": ["Size (KB)"],
    "No Compression": [338],
    "LatComp": [11],
    "WebP": [35],
    "JPEG": [66],
    "TinyPNG": [92],
    "PNG": [107],
    "WebP (Lossless)": [214],
    "PNG (Lossless)": [271],
    "ZIP (Lossless)": [338]
}

df = pd.DataFrame(comparison_data)
styled_df = df.style.background_gradient(subset=['LatComp'], cmap='YlOrRd')

# Build the Gradio interface: an introduction, a size-comparison table, and
# one tab each for compression and decompression.
with gr.Blocks() as app:
  # --- Introduction ---
  gr.Markdown("# LatComp (Latent Compression)")
  # Empty Markdown components like this one are placed between sections.
  gr.Markdown()
  gr.Markdown(
    """
    ## LatComp compression uses an AI model (VAE) and some custom code & math to compress images into a small, reversible format.
    """
  )
  gr.Markdown(
    """
    This work was inspired by **Jeremy Howard** and **Jonathan Whitaker** of [fast.ai](https://www.fast.ai/) and [answer.ai](https://www.answer.ai/).<br>
    While taking the fast.ai course, I was learning about **Variational Autoencoders (VAE)** and began to wonder:<br>
    *Is it possible to represent the latent space as an image, and then reconstruct the original image from that representation?*
    """
  )
  gr.Markdown()
  # --- Size comparison table (rendered from the module-level styled_df) ---
  gr.Markdown("### **Compression Comparison:** A 338 KB image compressed using various methods.")
  gr.Dataframe(styled_df)
  gr.Markdown("**Note:** *Lossless compression means the original image can be perfectly reconstructed.*")
  gr.Markdown()
  # --- Two side-by-side text columns ---
  with gr.Row():
    gr.Markdown(
      """
      ## **Use Cases:**
      - Save storage space
      - Faster file transfers
      - Backups & archives
      """
    )
    gr.Markdown(
      """
      ## **Potential Improvements:**
      - Better/Faster AI model (VAE)
      - Replace custom code & math with an AI model
      - All-in-one AI Model
      """
    )
  gr.Markdown()
  # --- Compression tab: image in, latcomp out ---
  with gr.Tab("Compression"):
    gr.Markdown(
      """
      ## Compress your image into a small and reversible format.
      Images bigger than 512x512 will be resized to reduce GPU memory usage.
      """
    )
    with gr.Row():
      # Left: the input image with its Clear / Compress buttons.
      with gr.Column():
        # type="pil" makes Gradio pass a PIL image to image_to_latcomp.
        input_image = gr.Image(label="Image", type="pil")
        with gr.Row():
          clear_compress_button = gr.ClearButton()
          compress_button = gr.Button("Compress", variant="primary")
      # Right: the resulting latcomp.
      output_latcomp = gr.Image(label="Latcomp")
    # Clickable sample inputs; the files are referenced by relative path.
    # With cache_examples=True and cache_mode="eager" their outputs are
    # computed with `fn` ahead of time instead of on click.
    gr.Examples(
      examples=[["macaw.png"], ["flowers.jpg"], ["newyork.jpg"]],
      inputs=input_image,
      outputs=output_latcomp,
      fn=image_to_latcomp,
      cache_examples=True,
      cache_mode="eager"
    )

  # --- Decompression tab: latcomp in, image out ---
  with gr.Tab("Decompression"):
    gr.Markdown("## Get your original image back from a latcomp.")
    with gr.Row():
      # Left: the latcomp input with its Clear / Decompress buttons.
      with gr.Column():
        # The latcomp is requested as an RGBA PIL image; only upload and
        # clipboard are offered as input sources.
        input_latcomp = gr.Image(label="Latcomp", type="pil", image_mode="RGBA", sources=["upload", "clipboard"])
        with gr.Row():
          clear_decompress_button = gr.ClearButton()
          decompress_button = gr.Button("Decompress", variant="primary")
      # Right: the reconstructed image.
      output_image = gr.Image(label="Image")
    # Sample latcomp files, cached the same way as the compression examples.
    gr.Examples(
      examples=[["macaw_latcomp.webp"], ["flowers_latcomp.webp"], ["newyork_latcomp.webp"]],
      inputs=input_latcomp,
      outputs=output_image,
      fn=latcomp_to_image,
      cache_examples=True,
      cache_mode="eager"
    )

  # --- Event wiring ---
  # Each Clear button resets both the input and the output of its tab; each
  # primary button runs the matching conversion function.
  clear_compress_button.add([input_image, output_latcomp])
  compress_button.click(fn=image_to_latcomp, inputs=input_image, outputs=output_latcomp)
  clear_decompress_button.add([input_latcomp, output_image])
  decompress_button.click(fn=latcomp_to_image, inputs=input_latcomp, outputs=output_image)

# Start the Gradio app.
app.launch()