darknoon committed on
Commit
aa212ba
1 Parent(s): 392a374

Upload from local env

Files changed (1)
  1. app.py +245 -0
app.py ADDED
@@ -0,0 +1,245 @@
+ from typing import Literal
+ import gradio as gr
+ import torch
+ import numpy as np
+ import colorsys
+
+ from diffusers import VQModel
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
+ from abc import ABC, abstractmethod
+ import torch.backends
+ import torch.mps
+ from PIL import Image
+
+
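+ # Pick the best available accelerator: CUDA, then Apple MPS, then CPU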
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+     device = torch.device("mps")
+ else:
+     device = torch.device("cpu")
+
+
+ # Abstract base class: each pipeline round-trips an image through its encoder-decoder pair
+ class ImageRoundtripPipeline(ABC):
+     @abstractmethod
+     def roundtrip_image(self, image, output_type="pil"): ...
+
+
+ class VQImageRoundtripPipeline(ImageRoundtripPipeline):
+     vqvae: VQModel
+     vae_scale_factor: int
+     vqvae_processor: VaeImageProcessor
+
+     def __init__(self):
+         self.vqvae = VQModel.from_pretrained("amused/amused-512", subfolder="vqvae")
+         self.vqvae.eval()
+         self.vqvae.to(device)
+         self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+         self.vqvae_processor = VaeImageProcessor(
+             vae_scale_factor=self.vae_scale_factor, do_normalize=False
+         )
+         print("VQ-GAN model loaded", self.vqvae)
+
+     def roundtrip_image(self, image, output_type="pil"):
+         image = self.vqvae_processor.preprocess(image)
+         device = self.vqvae.device
+         needs_upcasting = (
+             self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
+         )
+
+         batch_size, im_channels, height, width = image.shape
+
+         if needs_upcasting:
+             self.vqvae.float()
+
+         latents = self.vqvae.encode(
+             image.to(dtype=self.vqvae.dtype, device=device)
+         ).latents
+         latents_batch_size, latent_channels, latents_height, latents_width = (
+             latents.shape
+         )
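+         # quantize() returns (z_q, loss, (perplexity, min_encodings, min_encoding_indices));
+         # [2][2] picks out the per-position codebook indices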
+         latents = self.vqvae.quantize(latents)[2][2].reshape(
+             batch_size, latents_height, latents_width
+         )
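+         # With force_not_quantize and an explicit shape, VQModel looks the token indices
+         # up in its codebook (lookup_from_codebook) before decoding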
+         output = self.vqvae.decode(
+             latents,
+             force_not_quantize=True,
+             shape=(
+                 batch_size,
+                 height // self.vae_scale_factor,
+                 width // self.vae_scale_factor,
+                 self.vqvae.config.latent_channels,
+             ),
+         ).sample.clip(0, 1)
+         output = self.vqvae_processor.postprocess(output, output_type)
+
+         if needs_upcasting:
+             self.vqvae.half()
+
+         return output[0], latents.cpu().numpy(), self.vqvae.config.num_vq_embeddings
+
+
+ class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):
+     vqgan: PaellaVQModel
+     vae_scale_factor: int
+     vqvae_processor: VaeImageProcessor
+
+     def __init__(self):
+         self.vqgan = PaellaVQModel.from_pretrained(
+             "warp-ai/wuerstchen", subfolder="vqgan"
+         )
+         self.vqgan.eval()
+         self.vqgan.to(device)
+         self.vae_scale_factor = 4
+         self.vqvae_processor = VaeImageProcessor(
+             vae_scale_factor=self.vae_scale_factor, do_normalize=False
+         )
+         print("Paella VQ-GAN model loaded", self.vqgan)
+
+     def roundtrip_image(self, image, output_type="pil"):
+         image = self.vqvae_processor.preprocess(image)
+         device = self.vqgan.device
+
+         batch_size, im_channels, height, width = image.shape
+
+         latents = self.vqgan.encode(
+             image.to(dtype=self.vqgan.dtype, device=device)
+         ).latents
+         latents_batch_size, latent_channels, latents_height, latents_width = (
+             latents.shape
+         )
+         # latents = latents * self.vqgan.config.scale_factor
+         # Manually quantize so we can inspect the token indices
+         latents_q = self.vqgan.vquantizer(latents)[2][2].reshape(
+             batch_size, latents_height, latents_width
+         )
+         print("latents after quantize", (latents_q.shape, latents_q.dtype))
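+         # Note: decode() below is fed the continuous latents; latents_q is only used for the token visualization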
+         images = self.vqgan.decode(latents).sample.clamp(0, 1)
+         output = self.vqvae_processor.postprocess(images, output_type)
+
+         # if needs_upcasting:
+         #     self.vqgan.half()
+
+         return output[0], latents_q.cpu().numpy(), self.vqgan.config.num_vq_embeddings
+
+
+ pipeline_paella = PaellaImageRoundtripPipeline()
+ pipeline_vq = VQImageRoundtripPipeline()
+
+
+ # Generate a visually distinct RGB color for each codebook entry
+ def generate_unique_colors_hsl(n):
+     colors = []
+     for i in range(n):
+         hue = (i * 4 / n) % 1.0  # Cycle the hue wheel 4 times across the codebook
+         lightness = 0.8 - (i / n) * 0.6  # Decrease lightness from 0.8 to 0.2
+         saturation = 1.0
+         rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
+         rgb = tuple(int(255 * x) for x in rgb)
+         colors.append(rgb)
+     return colors
+
+
+ # Function to create the image from VQGAN tokens
+ def vqgan_tokens_to_image(tokens, codebook_size, downscale_factor):
+     # Generate unique colors for each token in the codebook
+     colors = generate_unique_colors_hsl(codebook_size)
+
+     # Create a lookup table
+     lookup_table = np.array(colors, dtype=np.uint8)
+
+     # Extract the token array (remove the batch dimension)
+     token_array = tokens[0]
+
+     # Map tokens to their RGB colors using the lookup table
+     color_image = lookup_table[token_array]
+
+     # Create a PIL image from the numpy array
+     img = Image.fromarray(color_image, "RGB")
+
+     # Upscale the image using nearest neighbor interpolation
+     img = img.resize(
+         (
+             color_image.shape[1] * downscale_factor,
+             color_image.shape[0] * downscale_factor,
+         ),
+         Image.NEAREST,
+     )
+
+     return img
+
+
+ # This is a Gradio space that lets you round-trip an image through various encoder-decoder pairs, e.g. a VQ-GAN or SDXL's VAE, and check the image quality
+
+
+ # def image_grid_to_string(image_grid):
+ #     """Convert a latent vq index "image" grid to a string, input shape is (1, height, width)"""
+ #     return "\n".join(
+ #         [" ".join([str(int(x)) for x in row]) for row in image_grid.squeeze()]
+ #     )
+
+
+ def describe_shape(shape):
+     return f"Shape: {shape}, num elements: {np.prod(shape)}"
+
+
+ # @spaces.GPU
+ @torch.no_grad()
+ def roundtrip_image(
+     image,
+     model: Literal["vqgan", "paella"],
+     size: Literal["256x256", "512x512", "1024x1024"],
+     output_type="pil",
+ ):
+     if size == "256x256":
+         image = image.resize((256, 256))
+     elif size == "512x512":
+         image = image.resize((512, 512))
+     elif size == "1024x1024":
+         image = image.resize((1024, 1024))
+     else:
+         raise ValueError(f"Unknown size {size}")
+
+     if model == "vqgan":
+         image, latents, codebook_size = pipeline_vq.roundtrip_image(image, output_type)
+         return (
+             image,
+             vqgan_tokens_to_image(
+                 latents, codebook_size, downscale_factor=pipeline_vq.vae_scale_factor
+             ),
+             describe_shape(latents.shape),
+         )
+     elif model == "paella":
+         image, latents, codebook_size = pipeline_paella.roundtrip_image(
+             image, output_type
+         )
+         return (
+             image,
+             vqgan_tokens_to_image(
+                 latents, codebook_size, downscale_factor=pipeline_paella.vae_scale_factor
+             ),
+             describe_shape(latents.shape),
+         )
+     else:
+         raise ValueError(f"Unknown model {model}")
+
+
+ demo = gr.Interface(
+     fn=roundtrip_image,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Dropdown(["vqgan", "paella"], label="Model", value="vqgan"),
+         gr.Dropdown(["256x256", "512x512", "1024x1024"], label="Size", value="512x512"),
+     ],
+     outputs=[
+         gr.Image(label="Reconstructed"),
+         gr.Image(label="Tokens"),
+         gr.Text(label="VQ Shape"),
+     ],
+     title="Image Tokenizer Playground",
+     description="Round-trip an image through an encoder-decoder pair (e.g. a VQ-GAN) to see how much quality the image tokenizer loses for image generation.",
+ )
+
+ demo.launch()