prithivMLmods commited on
Commit
2c889e9
·
verified ·
1 Parent(s): 4f9f20b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +887 -387
app.py CHANGED
@@ -1,248 +1,148 @@
1
  import os
 
2
  import gc
3
- import gradio as gr
4
- import numpy as np
 
5
  import random
 
 
 
 
 
 
6
  import spaces
 
7
  import torch
8
- from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2
9
  from PIL import Image
10
- from pathlib import Path
11
- import concurrent.futures
12
- import threading
13
- from typing import Iterable
14
-
15
- from gradio.themes import Soft
16
- from gradio.themes.utils import colors, fonts, sizes
17
-
18
- colors.orange_red = colors.Color(
19
- name="orange_red",
20
- c50="#FFF0E5",
21
- c100="#FFE0CC",
22
- c200="#FFC299",
23
- c300="#FFA366",
24
- c400="#FF8533",
25
- c500="#FF4500",
26
- c600="#E63E00",
27
- c700="#CC3700",
28
- c800="#B33000",
29
- c900="#992900",
30
- c950="#802200",
31
- )
32
-
33
- class OrangeRedTheme(Soft):
34
- def __init__(
35
- self,
36
- *,
37
- primary_hue: colors.Color | str = colors.gray,
38
- secondary_hue: colors.Color | str = colors.orange_red,
39
- neutral_hue: colors.Color | str = colors.slate,
40
- text_size: sizes.Size | str = sizes.text_lg,
41
- font: fonts.Font | str | Iterable[fonts.Font | str] = (
42
- fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
43
- ),
44
- font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
45
- fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
46
- ),
47
- ):
48
- super().__init__(
49
- primary_hue=primary_hue,
50
- secondary_hue=secondary_hue,
51
- neutral_hue=neutral_hue,
52
- text_size=text_size,
53
- font=font,
54
- font_mono=font_mono,
55
- )
56
- super().set(
57
- background_fill_primary="*primary_50",
58
- background_fill_primary_dark="*primary_900",
59
- body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
60
- body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
61
- button_primary_text_color="white",
62
- button_primary_text_color_hover="white",
63
- button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
64
- button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
65
- button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
66
- button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
67
- button_secondary_text_color="black",
68
- button_secondary_text_color_hover="white",
69
- button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
70
- button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
71
- button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
72
- button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
73
- slider_color="*secondary_500",
74
- slider_color_dark="*secondary_600",
75
- block_title_text_weight="600",
76
- block_border_width="3px",
77
- block_shadow="*shadow_drop_lg",
78
- button_primary_shadow="*shadow_drop_lg",
79
- button_large_padding="11px",
80
- color_accent_soft="*primary_100",
81
- block_label_background_fill="*primary_200",
82
- )
83
 
84
- orange_red_theme = OrangeRedTheme()
 
 
 
85
 
86
- dtype = torch.bfloat16
87
- device = "cuda" if torch.cuda.is_available() else "cpu"
88
 
89
- MAX_SEED = np.iinfo(np.int32).max
 
 
 
 
 
 
 
 
90
  MAX_IMAGE_SIZE = 1024
91
- EXAMPLES_DIR = Path("examples")
92
 
 
 
 
 
 
 
 
 
 
 
 
93
  print("Loading 4B Distilled model (Standard VAE)...")
94
  pipe_standard = Flux2KleinPipeline.from_pretrained(
95
  "black-forest-labs/FLUX.2-klein-4B",
96
  torch_dtype=dtype,
97
- )
98
  pipe_standard.enable_model_cpu_offload()
99
 
100
  print("Loading Small Decoder VAE...")
101
  vae_small = AutoencoderKLFlux2.from_pretrained(
102
  "black-forest-labs/FLUX.2-small-decoder",
103
  torch_dtype=dtype,
104
- )
105
 
106
  print("Loading 4B Distilled model (Small Decoder VAE)...")
107
  pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
108
  "black-forest-labs/FLUX.2-klein-4B",
109
  vae=vae_small,
110
  torch_dtype=dtype,
111
- )
112
  pipe_small_decoder.enable_model_cpu_offload()
113
 
114
  pipe_lock_standard = threading.Lock()
115
  pipe_lock_small = threading.Lock()
116
 
 
117
  def calc_dimensions(pil_img: Image.Image):
118
- """
119
- Given a PIL image return (width, height) snapped to multiples of 8,
120
- fitting within 1024 px on the long side, min 256 px on each side.
121
- Uses round() so we match the reference app exactly.
122
- """
123
  iw, ih = pil_img.size
124
  aspect = iw / ih
125
 
126
- if aspect >= 1: # landscape / square
127
  new_width = 1024
128
  new_height = int(round(1024 / aspect))
129
- else: # portrait
130
  new_height = 1024
131
  new_width = int(round(1024 * aspect))
132
 
133
- # snap to 8-pixel grid with round(), clamp to [256, 1024]
134
  new_width = max(256, min(1024, round(new_width / 8) * 8))
135
  new_height = max(256, min(1024, round(new_height / 8) * 8))
136
  return new_width, new_height
137
 
138
-
139
- def update_dimensions_from_image(image_list):
140
- """
141
- Called by the gallery .upload() event.
142
- Returns updated slider values for width and height.
143
- """
144
- if not image_list:
145
- return 1024, 1024
146
-
147
- # gallery items arrive as PIL images when type="pil"
148
- item = image_list[0]
149
- img = item[0] if isinstance(item, tuple) else item
150
-
151
- if isinstance(img, str):
152
- img = Image.open(img).convert("RGB")
153
- elif not isinstance(img, Image.Image):
154
- return 1024, 1024
155
-
156
- return calc_dimensions(img)
157
-
158
- def parse_and_resize_images(input_images, width: int, height: int):
159
- """
160
- Parse the gallery input and resize every frame to (width, height).
161
- Returns a list[PIL.Image] or None.
162
- """
163
- if input_images is None:
164
  return None
165
-
166
- raw_list = []
167
-
168
- if isinstance(input_images, str):
169
- if os.path.exists(input_images):
170
- raw_list = [Image.open(input_images).convert("RGB")]
171
- elif isinstance(input_images, Image.Image):
172
- raw_list = [input_images.convert("RGB")]
173
- elif isinstance(input_images, list):
174
- for item in input_images:
175
- try:
176
- src = item[0] if isinstance(item, tuple) else item
177
- if isinstance(src, str):
178
- raw_list.append(Image.open(src).convert("RGB"))
179
- elif isinstance(src, Image.Image):
180
- raw_list.append(src.convert("RGB"))
181
- elif hasattr(src, "name"):
182
- raw_list.append(Image.open(src.name).convert("RGB"))
183
- except Exception as e:
184
- print(f"Skipping invalid image: {e}")
185
-
186
- if not raw_list:
187
- return None
188
-
189
- resized = [
190
- img.resize((width, height), Image.LANCZOS)
191
- for img in raw_list
192
- ]
193
- return resized
194
 
195
  def run_pipeline(pipe, lock, kwargs, seed):
196
  with lock:
197
- gen = torch.Generator(device="cpu").manual_seed(seed)
198
  result = pipe(**kwargs, generator=gen).images[0]
199
  return result
200
 
 
 
 
 
 
 
 
201
  @spaces.GPU(duration=120)
202
  def infer(
203
- prompt,
204
- input_images=None,
205
- seed=42,
206
- randomize_seed=False,
207
- width=1024,
208
- height=1024,
209
- num_inference_steps=4,
210
- guidance_scale=1.0,
211
- progress=gr.Progress(track_tqdm=True),
212
  ):
213
  gc.collect()
214
- torch.cuda.empty_cache()
 
215
 
216
  if not prompt or not prompt.strip():
217
- raise gr.Error("Please enter a prompt.")
218
 
219
  if randomize_seed:
220
  seed = random.randint(0, MAX_SEED)
221
 
222
- # ── width / height: derive from the first uploaded image if present ──
223
  image_list = None
224
- if input_images:
225
- # Re-derive dimensions from the actual first image so they are
226
- # always consistent with what the pipeline will receive.
227
- item = (
228
- input_images[0][0]
229
- if isinstance(input_images[0], tuple)
230
- else input_images[0]
231
- )
232
- if isinstance(item, str):
233
- first_pil = Image.open(item).convert("RGB")
234
- elif isinstance(item, Image.Image):
235
- first_pil = item.convert("RGB")
236
- else:
237
- first_pil = None
238
-
239
- if first_pil is not None:
240
  width, height = calc_dimensions(first_pil)
 
 
 
241
 
242
- # parse + resize all images to the final (width, height)
243
- image_list = parse_and_resize_images(input_images, width, height)
244
-
245
- # ensure dims are multiples of 8 even for text-only runs
246
  width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
247
  height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
248
 
@@ -256,230 +156,830 @@ def infer(
256
  if image_list is not None:
257
  shared_kwargs["image"] = image_list
258
 
259
- progress(0.30, desc="Launching both pipelines simultaneously...")
260
-
261
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
262
- future_std = executor.submit(
263
- run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed
264
- )
265
- future_small = executor.submit(
266
- run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed
267
- )
268
  concurrent.futures.wait(
269
  [future_std, future_small],
270
  return_when=concurrent.futures.ALL_COMPLETED,
271
  )
272
 
273
- progress(0.80, desc="✅ Both pipelines done!")
274
-
275
  out_standard = future_std.result()
276
  out_small = future_small.result()
277
 
278
  gc.collect()
279
- torch.cuda.empty_cache()
 
280
 
281
  return out_standard, out_small, seed
282
 
283
 
284
- @spaces.GPU(duration=120)
285
- def infer_example(prompt):
286
- out_std, out_small, seed_used = infer(
287
- prompt=prompt,
288
- input_images=None,
289
- seed=0,
290
- randomize_seed=True,
291
- width=1024,
292
- height=1024,
293
- num_inference_steps=4,
294
- guidance_scale=1.0,
295
- )
296
- return out_std, out_small, seed_used
297
-
298
-
299
  def get_example_items():
300
- example_prompts = {
301
- "1.jpg": "Change the weather to stormy.",
302
- "2.jpg": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition.",
303
- "3.jpg": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent.",
304
- "4.jpg": "Make the texture high-resolution.",
305
- }
306
- items = []
307
- if EXAMPLES_DIR.exists():
308
- for name in sorted(os.listdir(EXAMPLES_DIR)):
309
- if name.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
310
- items.append({
311
- "file": name,
312
- "path": str(EXAMPLES_DIR / name),
313
- "prompt": example_prompts.get(
314
- name, "Edit this image while preserving composition."
315
- ),
316
- })
317
- return items
318
-
319
- EXAMPLE_ITEMS = get_example_items()
320
-
321
- css = """
322
- #col-container {
323
- margin: 0 auto;
324
- max-width: 1100px;
325
- }
326
- #main-title h1 {
327
- font-size: 2.4em !important;
328
- }
329
- .vae-badge {
330
- font-weight: 700;
331
- font-size: 0.95em;
332
- text-align: center;
333
- padding: 4px 16px;
334
- border-radius: 20px;
335
- display: block;
336
- margin-bottom: 6px;
337
- }
338
- """
339
-
340
- with gr.Blocks() as demo:
341
-
342
- with gr.Column(elem_id="col-container"):
343
-
344
- gr.Markdown(
345
- "# **Flux.2-4B-Decoder-Comparator**",
346
- elem_id="main-title",
347
- )
348
- gr.Markdown(
349
- "Compare **FLUX.2-klein-4B** side-by-side with "
350
- "[small decoder](https://huggingface.co/black-forest-labs/FLUX.2-small-decoder)."
351
- )
352
 
353
- with gr.Row(equal_height=True):
354
-
355
- with gr.Column():
356
- input_images = gr.Gallery(
357
- label="Input Images",
358
- type="pil",
359
- columns=2,
360
- rows=1,
361
- height=300,
362
- allow_preview=True,
363
- )
364
-
365
- prompt = gr.Text(
366
- label="Prompt",
367
- max_lines=1,
368
- show_label=True,
369
- placeholder="e.g., A black cat holding a sign that says hello world...",
370
- )
371
-
372
- run_button = gr.Button("Run Comparison", variant="primary")
373
-
374
- with gr.Column():
375
- with gr.Row():
376
- with gr.Column():
377
- result_standard = gr.Image(
378
- label="Standard Decoder",
379
- show_label=True,
380
- interactive=False,
381
- format="png",
382
- height=250,
383
- )
384
- with gr.Column():
385
- result_small = gr.Image(
386
- label="Small Decoder",
387
- show_label=True,
388
- interactive=False,
389
- format="png",
390
- height=250,
391
- )
392
-
393
- seed_output = gr.Number(label="Seed Used", precision=0, visible=False)
394
-
395
- with gr.Accordion("Advanced Settings", open=False, visible=False):
396
- seed = gr.Slider(
397
- label="Seed",
398
- minimum=0,
399
- maximum=MAX_SEED,
400
- step=1,
401
- value=0,
402
- )
403
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
404
-
405
- with gr.Row():
406
- width = gr.Slider(
407
- label="Width",
408
- minimum=256,
409
- maximum=MAX_IMAGE_SIZE,
410
- step=8,
411
- value=1024,
412
- )
413
- height_slider = gr.Slider(
414
- label="Height",
415
- minimum=256,
416
- maximum=MAX_IMAGE_SIZE,
417
- step=8,
418
- value=1024,
419
- )
420
-
421
- with gr.Row():
422
- num_inference_steps = gr.Slider(
423
- label="Inference Steps",
424
- minimum=1,
425
- maximum=20,
426
- step=1,
427
- value=4,
428
- )
429
- guidance_scale = gr.Slider(
430
- label="Guidance Scale",
431
- minimum=0.0,
432
- maximum=10.0,
433
- step=0.1,
434
- value=1.0,
435
- )
436
-
437
- gr.Examples(
438
- examples=[
439
- [["examples/I1.jpg", "examples/I2.jpg"], "Make her wear these glasses in Image 2."],
440
- [["examples/1.jpg"], "Change the weather to stormy."],
441
- [["examples/2.jpg"], "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition."],
442
- [["examples/3.jpg"], "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent."],
443
- [["examples/4.jpg"], "Make the texture high-resolution."],
444
- ],
445
- inputs=[input_images, prompt],
446
- outputs=[result_standard, result_small, seed_output],
447
- fn=infer_example,
448
- cache_examples=False,
449
- label="Examples",
450
- )
451
 
452
- gr.Markdown(
453
- "[*](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B) "
454
- "Experimental Space FLUX.2 [klein] 4B VAE Decoder Comparison."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  )
456
 
457
- input_images.upload(
458
- fn=update_dimensions_from_image,
459
- inputs=[input_images],
460
- outputs=[width, height_slider],
461
- )
462
-
463
- gr.on(
464
- triggers=[run_button.click, prompt.submit],
465
- fn=infer,
466
- inputs=[
467
- prompt,
468
- input_images,
469
- seed,
470
- randomize_seed,
471
- width,
472
- height_slider,
473
- num_inference_steps,
474
- guidance_scale,
475
- ],
476
- outputs=[result_standard, result_small, seed_output],
477
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
- if __name__ == "__main__":
480
- demo.queue(max_size=20).launch(
481
- theme=orange_red_theme, css=css,
482
- mcp_server=True,
483
- ssr_mode=False,
484
- show_error=True,
485
- )
 
1
  import os
2
+ import io
3
  import gc
4
+ import uuid
5
+ import json
6
+ import base64
7
  import random
8
+ import zipfile
9
+ import threading
10
+ import concurrent.futures
11
+ from pathlib import Path
12
+ from typing import List, Optional
13
+
14
  import spaces
15
+ import numpy as np
16
  import torch
 
17
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ from gradio import Server
20
+ from fastapi import Request, UploadFile, File, Form
21
+ from fastapi.responses import HTMLResponse, JSONResponse, FileResponse, StreamingResponse
22
+ from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2
23
 
24
+ # --- App Configuration & Directories ---
25
+ app = Server()
26
 
27
+ BASE_DIR = Path(__file__).resolve().parent
28
+ STATIC_DIR = BASE_DIR / "static"
29
+ OUTPUT_DIR = BASE_DIR / "outputs"
30
+ EXAMPLES_DIR = BASE_DIR / "examples"
31
+
32
+ STATIC_DIR.mkdir(exist_ok=True)
33
+ OUTPUT_DIR.mkdir(exist_ok=True)
34
+
35
+ MAX_SEED = np.iinfo(np.int32).max
36
  MAX_IMAGE_SIZE = 1024
 
37
 
38
+ dtype = torch.bfloat16
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+
41
+ if torch.cuda.is_available():
42
+ print("current device:", torch.cuda.current_device())
43
+ print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
44
+ DEVICE_LABEL = torch.cuda.get_device_name(torch.cuda.current_device()).lower()
45
+ else:
46
+ DEVICE_LABEL = str(device).lower()
47
+
48
+ # --- Model Loading ---
49
  print("Loading 4B Distilled model (Standard VAE)...")
50
  pipe_standard = Flux2KleinPipeline.from_pretrained(
51
  "black-forest-labs/FLUX.2-klein-4B",
52
  torch_dtype=dtype,
53
+ ).to(device)
54
  pipe_standard.enable_model_cpu_offload()
55
 
56
  print("Loading Small Decoder VAE...")
57
  vae_small = AutoencoderKLFlux2.from_pretrained(
58
  "black-forest-labs/FLUX.2-small-decoder",
59
  torch_dtype=dtype,
60
+ ).to(device)
61
 
62
  print("Loading 4B Distilled model (Small Decoder VAE)...")
63
  pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
64
  "black-forest-labs/FLUX.2-klein-4B",
65
  vae=vae_small,
66
  torch_dtype=dtype,
67
+ ).to(device)
68
  pipe_small_decoder.enable_model_cpu_offload()
69
 
70
  pipe_lock_standard = threading.Lock()
71
  pipe_lock_small = threading.Lock()
72
 
73
+ # --- Utility Functions ---
74
  def calc_dimensions(pil_img: Image.Image):
 
 
 
 
 
75
  iw, ih = pil_img.size
76
  aspect = iw / ih
77
 
78
+ if aspect >= 1:
79
  new_width = 1024
80
  new_height = int(round(1024 / aspect))
81
+ else:
82
  new_height = 1024
83
  new_width = int(round(1024 * aspect))
84
 
 
85
  new_width = max(256, min(1024, round(new_width / 8) * 8))
86
  new_height = max(256, min(1024, round(new_height / 8) * 8))
87
  return new_width, new_height
88
 
89
+ def parse_and_resize_images(image_paths: List[str], width: int, height: int):
90
+ if not image_paths:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  return None
92
+
93
+ resized = []
94
+ for path in image_paths:
95
+ try:
96
+ img = Image.open(path).convert("RGB")
97
+ resized.append(img.resize((width, height), Image.LANCZOS))
98
+ except Exception as e:
99
+ print(f"Skipping invalid image: {e}")
100
+
101
+ return resized if resized else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def run_pipeline(pipe, lock, kwargs, seed):
104
  with lock:
105
+ gen = torch.Generator(device="cpu").manual_seed(seed)
106
  result = pipe(**kwargs, generator=gen).images[0]
107
  return result
108
 
109
+ def save_image(img: Image.Image, prefix: str = "output") -> str:
110
+ filename = f"{prefix}_{uuid.uuid4().hex}.png"
111
+ path = OUTPUT_DIR / filename
112
+ img.save(path, format="PNG")
113
+ return filename
114
+
115
+ # --- Inference Function ---
116
  @spaces.GPU(duration=120)
117
  def infer(
118
+ prompt: str,
119
+ image_paths: List[str] = None,
120
+ seed: int = 42,
121
+ randomize_seed: bool = False,
122
+ width: int = 1024,
123
+ height: int = 1024,
124
+ num_inference_steps: int = 4,
125
+ guidance_scale: float = 1.0,
 
126
  ):
127
  gc.collect()
128
+ if torch.cuda.is_available():
129
+ torch.cuda.empty_cache()
130
 
131
  if not prompt or not prompt.strip():
132
+ raise ValueError("Please enter a prompt.")
133
 
134
  if randomize_seed:
135
  seed = random.randint(0, MAX_SEED)
136
 
 
137
  image_list = None
138
+ if image_paths and len(image_paths) > 0:
139
+ try:
140
+ first_pil = Image.open(image_paths[0]).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  width, height = calc_dimensions(first_pil)
142
+ image_list = parse_and_resize_images(image_paths, width, height)
143
+ except Exception as e:
144
+ print(f"Error processing upload: {e}")
145
 
 
 
 
 
146
  width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
147
  height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
148
 
 
156
  if image_list is not None:
157
  shared_kwargs["image"] = image_list
158
 
 
 
159
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
160
+ future_std = executor.submit(run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed)
161
+ future_small = executor.submit(run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed)
162
+
 
 
 
163
  concurrent.futures.wait(
164
  [future_std, future_small],
165
  return_when=concurrent.futures.ALL_COMPLETED,
166
  )
167
 
 
 
168
  out_standard = future_std.result()
169
  out_small = future_small.result()
170
 
171
  gc.collect()
172
+ if torch.cuda.is_available():
173
+ torch.cuda.empty_cache()
174
 
175
  return out_standard, out_small, seed
176
 
177
 
178
+ # --- FastAPI Endpoints ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  def get_example_items():
180
+ return [
181
+ {
182
+ "urls": ["/example-file/I1.jpg", "/example-file/I2.jpg"],
183
+ "prompt": "Make her wear these glasses in Image 2."
184
+ },
185
+ {
186
+ "urls": ["/example-file/1.jpg"],
187
+ "prompt": "Change the weather to stormy."
188
+ },
189
+ {
190
+ "urls": ["/example-file/2.jpg"],
191
+ "prompt": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition."
192
+ },
193
+ {
194
+ "urls": ["/example-file/3.jpg"],
195
+ "prompt": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent."
196
+ },
197
+ {
198
+ "urls": ["/example-file/4.jpg"],
199
+ "prompt": "Make the texture high-resolution."
200
+ }
201
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ @app.get("/example-file/{filename}")
204
+ async def example_file(filename: str):
205
+ path = EXAMPLES_DIR / filename
206
+ if not path.exists():
207
+ return JSONResponse({"error": "Example not found"}, status_code=404)
208
+ return FileResponse(path)
209
+
210
+ @app.get("/download/{filename}")
211
+ async def download_file(filename: str):
212
+ path = OUTPUT_DIR / filename
213
+ if not path.exists():
214
+ return JSONResponse({"error": "File not found"}, status_code=404)
215
+ return FileResponse(path, filename=filename, media_type="image/png")
216
+
217
+ @app.get("/api/download-zip")
218
+ async def download_zip(std: str, small: str):
219
+ """Packages both generated images into a single ZIP file and streams it."""
220
+ std_name = Path(std).name
221
+ small_name = Path(small).name
222
+
223
+ std_path = OUTPUT_DIR / std_name
224
+ small_path = OUTPUT_DIR / small_name
225
+
226
+ if not std_path.exists() or not small_path.exists():
227
+ return JSONResponse({"error": "Generated files not found"}, status_code=404)
228
+
229
+ memory_file = io.BytesIO()
230
+ with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
231
+ zf.write(std_path, arcname=f"Standard_Decoder_{std_name}")
232
+ zf.write(small_path, arcname=f"Small_Decoder_{small_name}")
233
+
234
+ memory_file.seek(0)
235
+
236
+ return StreamingResponse(
237
+ memory_file,
238
+ media_type="application/zip",
239
+ headers={"Content-Disposition": f"attachment; filename=Flux2_Comparison_{uuid.uuid4().hex[:6]}.zip"}
240
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
+ @app.post("/api/compare")
243
+ async def compare_images(
244
+ prompt: str = Form(...),
245
+ seed: str = Form("0"),
246
+ randomize_seed: str = Form("true"),
247
+ width: str = Form("1024"),
248
+ height: str = Form("1024"),
249
+ steps: str = Form("4"),
250
+ guidance: str = Form("1.0"),
251
+ images: Optional[List[UploadFile]] = File(None),
252
+ ):
253
+ temp_paths = []
254
+ try:
255
+ image_paths = []
256
+ if images:
257
+ for upload in images:
258
+ if not upload.filename: continue
259
+ suffix = Path(upload.filename).suffix or ".png"
260
+ temp_path = OUTPUT_DIR / f"upload_{uuid.uuid4().hex}{suffix}"
261
+ content = await upload.read()
262
+ with open(temp_path, "wb") as f:
263
+ f.write(content)
264
+ temp_paths.append(str(temp_path))
265
+ image_paths.append(str(temp_path))
266
+
267
+ result_std, result_small, used_seed = infer(
268
+ prompt=prompt,
269
+ image_paths=image_paths,
270
+ seed=int(seed),
271
+ randomize_seed=(randomize_seed.lower() == "true"),
272
+ width=int(width),
273
+ height=int(height),
274
+ num_inference_steps=int(steps),
275
+ guidance_scale=float(guidance),
276
  )
277
 
278
+ std_filename = save_image(result_std, prefix="std")
279
+ small_filename = save_image(result_small, prefix="small")
280
+
281
+ return JSONResponse({
282
+ "success": True,
283
+ "seed": used_seed,
284
+ "std_url": f"/download/{std_filename}",
285
+ "small_url": f"/download/{small_filename}",
286
+ "std_filename": std_filename,
287
+ "small_filename": small_filename,
288
+ "device": DEVICE_LABEL,
289
+ })
290
+
291
+ except Exception as e:
292
+ return JSONResponse({"success": False, "error": str(e)}, status_code=500)
293
+ finally:
294
+ for p in temp_paths:
295
+ if os.path.exists(p):
296
+ os.remove(p)
297
+
298
+ # --- Frontend ---
299
+ @app.get("/", response_class=HTMLResponse)
300
+ async def homepage(request: Request):
301
+ examples = get_example_items()
302
+ examples_json = json.dumps(examples)
303
+
304
+ return f"""
305
+ <!DOCTYPE html>
306
+ <html lang="en">
307
+ <head>
308
+ <meta charset="UTF-8" />
309
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
310
+ <title>Flux.2-4B-Decoder-Comparator</title>
311
+ <link href="https://fonts.googleapis.com/css2?family=Ubuntu:wght@300;400;500;700&display=swap" rel="stylesheet">
312
+ <style>
313
+ :root {{
314
+ --ub-aubergine: #2C001E;
315
+ --ub-aubergine-dark: #1f0015;
316
+ --ub-orange: #E95420;
317
+ --ub-orange-hover: #c4461a;
318
+ --ub-panel: #3D3D3D;
319
+ --ub-panel-light: #4f4f4f;
320
+ --ub-border: rgba(255,255,255,0.1);
321
+ --ub-text: #FFFFFF;
322
+ --ub-muted: #b0b0b0;
323
+ --ub-input: #2b2b2b;
324
+ --panel-radius: 8px;
325
+ }}
326
+
327
+ * {{ box-sizing: border-box; font-family: 'Ubuntu', sans-serif; }}
328
+
329
+ body {{
330
+ margin: 0; padding: 0;
331
+ background: var(--ub-aubergine);
332
+ color: var(--ub-text);
333
+ min-height: 100vh;
334
+ display: flex;
335
+ flex-direction: column;
336
+ }}
337
+
338
+ .topbar {{
339
+ background: var(--ub-aubergine-dark);
340
+ padding: 16px 24px;
341
+ border-bottom: 1px solid var(--ub-border);
342
+ text-align: center;
343
+ font-weight: 700;
344
+ letter-spacing: 0.5px;
345
+ color: var(--ub-orange);
346
+ }}
347
+
348
+ .container {{
349
+ max-width: 1300px;
350
+ margin: 0 auto;
351
+ padding: 30px 20px;
352
+ flex: 1;
353
+ width: 100%;
354
+ }}
355
+
356
+ .header-text {{
357
+ text-align: center;
358
+ margin-bottom: 30px;
359
+ }}
360
+ .header-text h1 {{
361
+ margin: 0 0 10px 0;
362
+ font-size: 2.2rem;
363
+ }}
364
+ .header-text p {{
365
+ color: var(--ub-muted);
366
+ margin: 0;
367
+ }}
368
+
369
+ /* FIXED LAYOUT GRID */
370
+ .layout {{
371
+ display: grid;
372
+ grid-template-columns: 420px 1fr;
373
+ gap: 24px;
374
+ align-items: stretch;
375
+ height: 650px;
376
+ }}
377
+
378
+ .panel {{
379
+ background: var(--ub-panel);
380
+ border-radius: var(--panel-radius);
381
+ box-shadow: 0 8px 24px rgba(0,0,0,0.2);
382
+ display: flex;
383
+ flex-direction: column;
384
+ overflow: hidden;
385
+ height: 100%;
386
+ }}
387
+
388
+ .panel-header {{
389
+ padding: 16px 20px;
390
+ background: rgba(0,0,0,0.2);
391
+ border-bottom: 1px solid var(--ub-border);
392
+ font-weight: 500;
393
+ font-size: 1.1rem;
394
+ flex-shrink: 0;
395
+ display: flex;
396
+ justify-content: space-between;
397
+ align-items: center;
398
+ }}
399
+
400
+ .panel-body-scroll {{
401
+ flex: 1;
402
+ padding: 20px;
403
+ overflow-y: auto;
404
+ display: flex;
405
+ flex-direction: column;
406
+ }}
407
+
408
+ /* Input Forms */
409
+ .form-group {{ margin-bottom: 20px; flex-shrink: 0; }}
410
+ .label {{
411
+ display: block; font-weight: 500; font-size: 14px;
412
+ color: var(--ub-muted); margin-bottom: 8px;
413
+ }}
414
+
415
+ .textarea, .input {{
416
+ width: 100%;
417
+ background: var(--ub-input);
418
+ border: 1px solid var(--ub-border);
419
+ color: var(--ub-text);
420
+ padding: 12px;
421
+ border-radius: 4px;
422
+ outline: none;
423
+ font-size: 14px;
424
+ }}
425
+ .textarea:focus, .input:focus {{ border-color: var(--ub-orange); }}
426
+ .textarea {{ min-height: 100px; resize: vertical; }}
427
+
428
+ /* Upload Zone */
429
+ .upload-zone {{
430
+ background: var(--ub-input);
431
+ border: 1px dashed var(--ub-muted);
432
+ border-radius: 4px;
433
+ padding: 15px;
434
+ text-align: center;
435
+ cursor: pointer;
436
+ transition: background 0.2s, border-color 0.2s;
437
+ min-height: 100px;
438
+ display: flex;
439
+ flex-direction: column;
440
+ justify-content: center;
441
+ align-items: center;
442
+ }}
443
+ .upload-zone:hover, .upload-zone.dragover {{
444
+ border-color: var(--ub-orange);
445
+ background: rgba(233,84,32,0.05);
446
+ }}
447
+ .upload-zone input[type="file"] {{ display: none; }}
448
+ .upload-text {{ pointer-events: none; color: var(--ub-muted); }}
449
+
450
+ .preview-grid {{
451
+ display: none;
452
+ grid-template-columns: repeat(auto-fill, minmax(70px, 1fr));
453
+ gap: 10px;
454
+ width: 100%;
455
+ }}
456
+ .thumb {{
457
+ position: relative; aspect-ratio: 1;
458
+ border-radius: 4px; overflow: hidden;
459
+ border: 1px solid var(--ub-border);
460
+ }}
461
+ .thumb img {{ width: 100%; height: 100%; object-fit: cover; display: block; }}
462
+ .thumb-remove {{
463
+ position: absolute; top: 4px; right: 4px;
464
+ background: rgba(0,0,0,0.7); color: white;
465
+ border: none; border-radius: 50%; width: 20px; height: 20px;
466
+ display: flex; align-items: center; justify-content: center;
467
+ cursor: pointer; font-size: 12px;
468
+ }}
469
+
470
+ .add-more-btn {{
471
+ display: flex; align-items: center; justify-content: center;
472
+ border: 2px dashed var(--ub-muted); border-radius: 4px;
473
+ color: var(--ub-muted); font-size: 26px; cursor: pointer;
474
+ aspect-ratio: 1; transition: all 0.2s; background: transparent;
475
+ }}
476
+ .add-more-btn:hover {{
477
+ border-color: var(--ub-orange); color: var(--ub-orange);
478
+ background: rgba(233,84,32,0.05);
479
+ }}
480
+
481
+ /* Advanced Accordion */
482
+ .advanced-toggle {{
483
+ width: 100%; background: none; border: none; color: var(--ub-orange);
484
+ text-align: left; padding: 10px 0; font-weight: 500; cursor: pointer;
485
+ display: flex; justify-content: space-between; align-items: center;
486
+ flex-shrink: 0;
487
+ }}
488
+ .advanced-icon {{ font-weight: bold; font-size: 18px; line-height: 1; }}
489
+ .advanced-body {{ display: none; padding-top: 10px; flex-shrink: 0; }}
490
+ .advanced-body.open {{ display: block; }}
491
+ .grid-2 {{ display: grid; grid-template-columns: 1fr 1fr; gap: 15px; }}
492
+
493
+ /* Status Container */
494
+ .status-container {{
495
+ margin-top: 20px; margin-bottom: 20px;
496
+ border: 1px solid var(--ub-border); border-radius: 4px;
497
+ background: #200014; display: flex; flex-direction: column;
498
+ flex: 1; min-height: 100px; max-height: 200px;
499
+ }}
500
+ .status-header {{
501
+ padding: 8px 12px; font-size: 11px; font-weight: 700;
502
+ color: var(--ub-muted); background: rgba(0,0,0,0.4);
503
+ border-bottom: 1px solid var(--ub-border); text-transform: uppercase;
504
+ letter-spacing: 0.5px; flex-shrink: 0;
505
+ }}
506
+ .status-log {{
507
+ padding: 10px; font-family: 'Courier New', Courier, monospace;
508
+ font-size: 12px; color: #eeeeee; overflow-y: auto;
509
+ flex: 1; display: flex; flex-direction: column; gap: 4px;
510
+ }}
511
+ .log-time {{ color: #777; margin-right: 8px; }}
512
+ .log-info {{ color: #5bc0eb; }}
513
+ .log-success {{ color: #9bc53d; }}
514
+ .log-error {{ color: #ff5e5b; }}
515
+
516
+ /* Buttons */
517
+ .btn {{
518
+ width: 100%; padding: 14px; border: none; border-radius: 4px;
519
+ font-size: 16px; font-weight: 700; cursor: pointer;
520
+ transition: opacity 0.2s, background 0.2s; flex-shrink: 0;
521
+ }}
522
+ .btn-primary {{
523
+ background: var(--ub-orange); color: white;
524
+ box-shadow: 0 4px 12px rgba(233,84,32,0.3);
525
+ }}
526
+ .btn-primary:hover {{ background: var(--ub-orange-hover); }}
527
+ .btn:disabled {{ opacity: 0.6; cursor: not-allowed; }}
528
+
529
+ /* Top-Right Download Icon */
530
+ .action-icon {{
531
+ display: none; background: none; border: none; color: var(--ub-muted);
532
+ cursor: pointer; padding: 4px; transition: color 0.2s;
533
+ }}
534
+ .action-icon:hover {{ color: var(--ub-orange); }}
535
+
536
+ /* SLIDER CONTAINER */
537
+ .panel-body-slider {{
538
+ flex: 1; display: flex; flex-direction: column;
539
+ padding: 0; position: relative;
540
+ }}
541
+ .slider-stage {{
542
+ position: absolute; top: 0; left: 0; right: 0; bottom: 0;
543
+ background: #111; overflow: hidden; display: flex;
544
+ align-items: center; justify-content: center;
545
+ }}
546
+ .slider-empty {{ color: var(--ub-muted); text-align: center; z-index: 1; }}
547
+
548
+ .slider-img {{
549
+ position: absolute; top: 0; left: 0; width: 100%; height: 100%;
550
+ object-fit: contain; display: none; user-select: none; -webkit-user-drag: none;
551
+ }}
552
+ #imgSmall {{ clip-path: inset(0 50% 0 0); }}
553
+
554
+ .slider-handle {{
555
+ position: absolute; left: 50%; top: 0; bottom: 0;
556
+ width: 4px; background: var(--ub-orange); cursor: ew-resize; display: none; z-index: 10;
557
+ }}
558
+ .slider-handle::after {{
559
+ content: '◀ ▶'; position: absolute; top: 50%; left: 50%;
560
+ transform: translate(-50%, -50%); width: 40px; height: 30px;
561
+ background: var(--ub-orange); color: white; border-radius: 15px;
562
+ display: flex; align-items: center; justify-content: center;
563
+ font-size: 10px; font-weight: bold; box-shadow: 0 2px 6px rgba(0,0,0,0.5);
564
+ }}
565
+
566
+ .slider-labels {{
567
+ position: absolute; top: 15px; left: 15px; right: 15px;
568
+ display: none; justify-content: space-between;
569
+ pointer-events: none; z-index: 5;
570
+ }}
571
+ .badge {{
572
+ background: rgba(0,0,0,0.6); color: white; padding: 6px 12px;
573
+ border-radius: 20px; font-size: 13px; backdrop-filter: blur(4px);
574
+ }}
575
+
576
+ /* UPDATED LOADER ANIMATION (Minimalist Single Circle) */
577
+ .loader {{
578
+ position: absolute; inset: 0;
579
+ background: rgba(20, 0, 10, 0.7); /* dark aubergine tint */
580
+ backdrop-filter: blur(6px);
581
+ display: none; flex-direction: column;
582
+ align-items: center; justify-content: center; z-index: 20;
583
+ }}
584
+ .spinner-single {{
585
+ width: 55px; height: 55px;
586
+ border: 3px solid rgba(255, 255, 255, 0.1);
587
+ border-top-color: var(--ub-orange);
588
+ border-radius: 50%;
589
+ animation: spin 1s cubic-bezier(0.4, 0.0, 0.2, 1) infinite;
590
+ margin-bottom: 20px;
591
+ }}
592
+ .loader-text {{
593
+ font-weight: 500;
594
+ font-size: 15px;
595
+ color: #ffffff;
596
+ letter-spacing: 1px;
597
+ animation: pulse 1.5s ease-in-out infinite;
598
+ }}
599
+ @keyframes pulse {{
600
+ 0%, 100% {{ opacity: 1; }}
601
+ 50% {{ opacity: 0.5; }}
602
+ }}
603
+ @keyframes spin {{
604
+ to {{ transform: rotate(360deg); }}
605
+ }}
606
+
607
+ /* Examples */
608
+ .examples-section {{ margin-top: 40px; }}
609
+ .examples-section h3 {{ border-bottom: 1px solid var(--ub-border); padding-bottom: 10px; }}
610
+ .examples-grid {{
611
+ display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 20px;
612
+ }}
613
+ .ex-card {{
614
+ background: var(--ub-panel); border-radius: 4px; overflow: hidden;
615
+ cursor: pointer; transition: transform 0.2s, box-shadow 0.2s;
616
+ }}
617
+ .ex-card:hover {{ transform: translateY(-3px); box-shadow: 0 6px 16px rgba(0,0,0,0.3); }}
618
+ .ex-card-img-wrap {{ width: 100%; aspect-ratio: 1; display: flex; background: #000; }}
619
+ .ex-card-img-wrap img {{ height: 100%; object-fit: cover; }}
620
+ .ex-card p {{ padding: 12px; margin: 0; font-size: 13px; color: var(--ub-muted); line-height: 1.4; }}
621
+
622
+ @media (max-width: 900px) {{
623
+ .layout {{ grid-template-columns: 1fr; height: auto; }}
624
+ .panel-body-slider {{ height: 450px; flex: none; }}
625
+ .slider-stage {{ position: relative; height: 100%; }}
626
+ }}
627
+ </style>
628
+ </head>
629
+ <body>
630
+
631
+ <div class="topbar">Flux.2-4B VAE Decoder Comparator</div>
632
+
633
+ <div class="container">
634
+ <div class="header-text">
635
+ <h1>Standard vs. Small Decoder</h1>
636
+ <p>Upload an image, enter a prompt, and use the slider to compare outputs in real-time.</p>
637
+ </div>
638
+
639
+ <div class="layout">
640
+ <div class="panel">
641
+ <div class="panel-header">Settings</div>
642
+ <div class="panel-body-scroll">
643
+ <div class="form-group">
644
+ <label class="label">Input Images (Optional)</label>
645
+ <div class="upload-zone" id="dropZone">
646
+ <input type="file" id="fileInput" multiple accept="image/*" />
647
+ <div class="upload-text" id="uploadText">Click or Drag & Drop images here</div>
648
+ <div class="preview-grid" id="previewGrid"></div>
649
+ </div>
650
+ </div>
651
+
652
+ <div class="form-group">
653
+ <label class="label">Prompt</label>
654
+ <textarea id="promptInput" class="textarea" placeholder="Describe the edit or generation..."></textarea>
655
+ </div>
656
+
657
+ <button class="advanced-toggle" id="advToggle">
658
+ <span>Advanced Settings</span> <span class="advanced-icon" id="advIcon">+</span>
659
+ </button>
660
+
661
+ <div class="advanced-body" id="advBody">
662
+ <div class="grid-2">
663
+ <div class="form-group">
664
+ <label class="label">Seed</label>
665
+ <input type="number" id="seed" class="input" value="0">
666
+ </div>
667
+ <div class="form-group">
668
+ <label class="label">Steps</label>
669
+ <input type="number" id="steps" class="input" value="4">
670
+ </div>
671
+ <div class="form-group">
672
+ <label class="label">Width</label>
673
+ <input type="number" id="width" class="input" value="1024" step="8">
674
+ </div>
675
+ <div class="form-group">
676
+ <label class="label">Height</label>
677
+ <input type="number" id="height" class="input" value="1024" step="8">
678
+ </div>
679
+ <div class="form-group" style="grid-column: span 2;">
680
+ <label class="label">Guidance Scale</label>
681
+ <input type="number" id="guidance" class="input" value="1.0" step="0.1">
682
+ </div>
683
+ <div class="form-group" style="grid-column: span 2;">
684
+ <label style="display:flex; align-items:center; gap:8px; font-size:14px; color:var(--ub-text);">
685
+ <input type="checkbox" id="randomize" checked> Randomize Seed
686
+ </label>
687
+ </div>
688
+ </div>
689
+ </div>
690
+
691
+ <div class="status-container">
692
+ <div class="status-header">Execution Log</div>
693
+ <div class="status-log" id="statusLog">
694
+ <div><span class="log-time">[{DEVICE_LABEL}]</span><span>System Ready...</span></div>
695
+ </div>
696
+ </div>
697
+
698
+ <button class="btn btn-primary" id="runBtn">Run Comparison</button>
699
+ </div>
700
+ </div>
701
+
702
+ <div class="panel">
703
+ <div class="panel-header">
704
+ <span>Comparison View</span>
705
+ <button id="downloadZipBtn" class="action-icon" title="Download Both Images (ZIP)">
706
+ <svg width="22" height="22" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" viewBox="0 0 24 24">
707
+ <path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"></path>
708
+ <polyline points="7 10 12 15 17 10"></polyline>
709
+ <line x1="12" y1="15" x2="12" y2="3"></line>
710
+ </svg>
711
+ </button>
712
+ </div>
713
+ <div class="panel-body-slider">
714
+ <div class="slider-stage" id="sliderStage">
715
+ <div class="slider-empty" id="sliderEmpty">
716
+ <svg width="48" height="48" fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" style="margin-bottom:10px; opacity:0.5;">
717
+ <path d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z"></path>
718
+ </svg>
719
+ <div>Results will appear here</div>
720
+ </div>
721
+
722
+ <img id="imgStd" class="slider-img" alt="Standard Decoder" />
723
+ <img id="imgSmall" class="slider-img" alt="Small Decoder" />
724
+
725
+ <div class="slider-labels" id="sliderLabels">
726
+ <div class="badge">Standard Decoder</div>
727
+ <div class="badge">Small Decoder</div>
728
+ </div>
729
+
730
+ <div class="slider-handle" id="sliderHandle"></div>
731
+
732
+ <div class="loader" id="loader">
733
+ <div class="spinner-single"></div>
734
+ <div class="loader-text">Running both decoders...</div>
735
+ </div>
736
+ </div>
737
+ </div>
738
+ </div>
739
+ </div>
740
+
741
+ <div class="examples-section">
742
+ <h3>Examples</h3>
743
+ <div class="examples-grid" id="examplesGrid"></div>
744
+ </div>
745
+ </div>
746
+
747
+ <script>
748
+ const examples = {examples_json};
749
+ let filesState = [];
750
+ let currentStdFilename = "";
751
+ let currentSmallFilename = "";
752
+
753
+ // UI Elements
754
+ const dropZone = document.getElementById('dropZone');
755
+ const fileInput = document.getElementById('fileInput');
756
+ const previewGrid = document.getElementById('previewGrid');
757
+ const uploadText = document.getElementById('uploadText');
758
+ const promptInput = document.getElementById('promptInput');
759
+ const runBtn = document.getElementById('runBtn');
760
+ const downloadZipBtn = document.getElementById('downloadZipBtn');
761
+
762
+ // Status Log
763
+ const statusLog = document.getElementById('statusLog');
764
+
765
+ function logMsg(msg, styleClass="") {{
766
+ const div = document.createElement('div');
767
+ const timeStr = new Date().toLocaleTimeString('en-US', {{hour12:false}});
768
+ div.innerHTML = `<span class="log-time">[${{timeStr}}]</span><span class="${{styleClass}}">${{msg}}</span>`;
769
+ statusLog.appendChild(div);
770
+ statusLog.scrollTop = statusLog.scrollHeight; // auto-scroll to bottom
771
+ }}
772
+
773
+ // Slider Elements
774
+ const sliderStage = document.getElementById('sliderStage');
775
+ const imgStd = document.getElementById('imgStd');
776
+ const imgSmall = document.getElementById('imgSmall');
777
+ const sliderHandle = document.getElementById('sliderHandle');
778
+ const sliderLabels = document.getElementById('sliderLabels');
779
+ const sliderEmpty = document.getElementById('sliderEmpty');
780
+ const loader = document.getElementById('loader');
781
+
782
+ // Advanced Toggle logic (+ / -)
783
+ document.getElementById('advToggle').onclick = function() {{
784
+ const body = document.getElementById('advBody');
785
+ body.classList.toggle('open');
786
+ document.getElementById('advIcon').innerText = body.classList.contains('open') ? '−' : '+';
787
+ }};
788
+
789
+ // --- File Upload Logic ---
790
+ function renderPreviews() {{
791
+ previewGrid.innerHTML = '';
792
+ if(filesState.length > 0) {{
793
+ uploadText.style.display = 'none';
794
+ previewGrid.style.display = 'grid';
795
+
796
+ filesState.forEach((f, i) => {{
797
+ const div = document.createElement('div');
798
+ div.className = 'thumb';
799
+ const img = document.createElement('img');
800
+ img.src = URL.createObjectURL(f);
801
+ const btn = document.createElement('button');
802
+ btn.className = 'thumb-remove';
803
+ btn.innerText = '×';
804
+ btn.onclick = (e) => {{ e.stopPropagation(); filesState.splice(i, 1); renderPreviews(); }};
805
+ div.appendChild(img); div.appendChild(btn);
806
+ previewGrid.appendChild(div);
807
+ }});
808
+
809
+ // Append dynamic + button
810
+ const addBtn = document.createElement('div');
811
+ addBtn.className = 'add-more-btn';
812
+ addBtn.innerHTML = '+';
813
+ addBtn.onclick = (e) => {{ e.stopPropagation(); fileInput.click(); }};
814
+ previewGrid.appendChild(addBtn);
815
+
816
+ }} else {{
817
+ uploadText.style.display = 'block';
818
+ previewGrid.style.display = 'none';
819
+ }}
820
+ }}
821
+
822
+ dropZone.onclick = (e) => {{ if(e.target === dropZone || e.target === uploadText) fileInput.click(); }};
823
+ fileInput.onchange = (e) => {{ filesState.push(...Array.from(e.target.files)); renderPreviews(); fileInput.value=''; }};
824
+ dropZone.ondragover = (e) => {{ e.preventDefault(); dropZone.classList.add('dragover'); }};
825
+ dropZone.ondragleave = () => dropZone.classList.remove('dragover');
826
+ dropZone.ondrop = (e) => {{
827
+ e.preventDefault(); dropZone.classList.remove('dragover');
828
+ if(e.dataTransfer.files.length) {{ filesState.push(...Array.from(e.dataTransfer.files)); renderPreviews(); }}
829
+ }};
830
+
831
+ // --- Examples Logic ---
832
+ async function loadExample(urls, text) {{
833
+ filesState = [];
834
+ renderPreviews();
835
+ promptInput.value = text;
836
+ logMsg("Loading example: " + text, "log-info");
837
+
838
+ try {{
839
+ for(let i=0; i<urls.length; i++) {{
840
+ const res = await fetch(urls[i]);
841
+ const blob = await res.blob();
842
+ const filename = urls[i].split('/').pop();
843
+ filesState.push(new File([blob], filename, {{type: blob.type}}));
844
+ }}
845
+ renderPreviews();
846
+
847
+ window.scrollTo({{top: 0, behavior: 'smooth'}});
848
+
849
+ setTimeout(() => {{
850
+ logMsg("Example loaded. Starting comparison...", "log-info");
851
+ runBtn.click();
852
+ }}, 500);
853
+
854
+ }} catch (e) {{
855
+ logMsg("Failed to load example images.", "log-error");
856
+ alert('Failed to load example image.');
857
+ }}
858
+ }}
859
+
860
+ const exGrid = document.getElementById('examplesGrid');
861
+ examples.forEach(ex => {{
862
+ const card = document.createElement('div');
863
+ card.className = 'ex-card';
864
+
865
+ let imgHTML = '';
866
+ if(ex.urls.length > 1) {{
867
+ imgHTML = `
868
+ <div class="ex-card-img-wrap">
869
+ <img src="${{ex.urls[0]}}" style="width:50%; border-right:1px solid #000;">
870
+ <img src="${{ex.urls[1]}}" style="width:50%;">
871
+ </div>
872
+ `;
873
+ }} else {{
874
+ imgHTML = `<div class="ex-card-img-wrap"><img src="${{ex.urls[0]}}" style="width:100%;"></div>`;
875
+ }}
876
+
877
+ card.innerHTML = `${{imgHTML}}<p>${{ex.prompt}}</p>`;
878
+ card.onclick = () => loadExample(ex.urls, ex.prompt);
879
+ exGrid.appendChild(card);
880
+ }});
881
+
882
+ // --- Image Slider Logic ---
883
+ let isDragging = false;
884
+
885
+ function updateSlider(clientX) {{
886
+ const rect = sliderStage.getBoundingClientRect();
887
+ let pos = Math.max(0, Math.min(clientX - rect.left, rect.width));
888
+ let percent = (pos / rect.width) * 100;
889
+
890
+ sliderHandle.style.left = percent + '%';
891
+ imgSmall.style.clipPath = `inset(0 ${{100 - percent}}% 0 0)`;
892
+ }}
893
+
894
+ sliderHandle.addEventListener('mousedown', () => isDragging = true);
895
+ window.addEventListener('mouseup', () => isDragging = false);
896
+ window.addEventListener('mousemove', (e) => {{
897
+ if (!isDragging) return;
898
+ updateSlider(e.clientX);
899
+ }});
900
+
901
+ sliderHandle.addEventListener('touchstart', () => isDragging = true);
902
+ window.addEventListener('touchend', () => isDragging = false);
903
+ window.addEventListener('touchmove', (e) => {{
904
+ if (!isDragging) return;
905
+ updateSlider(e.touches[0].clientX);
906
+ }});
907
+
908
+ // --- Download Zip Logic ---
909
+ downloadZipBtn.onclick = () => {{
910
+ if(!currentStdFilename || !currentSmallFilename) return;
911
+ logMsg("Initiating ZIP download...", "log-info");
912
+ window.location.href = `/api/download-zip?std=${{currentStdFilename}}&small=${{currentSmallFilename}}`;
913
+ }};
914
+
915
+ // --- Form Submission ---
916
+ runBtn.onclick = async () => {{
917
+ const prompt = promptInput.value.trim();
918
+ if(!prompt) {{
919
+ logMsg("Validation failed: Prompt is empty.", "log-error");
920
+ return alert("Enter a prompt");
921
+ }}
922
+
923
+ logMsg("Initializing generation sequence...", "log-info");
924
+
925
+ const fd = new FormData();
926
+ fd.append('prompt', prompt);
927
+ fd.append('seed', document.getElementById('seed').value);
928
+ fd.append('randomize_seed', document.getElementById('randomize').checked);
929
+ fd.append('width', document.getElementById('width').value);
930
+ fd.append('height', document.getElementById('height').value);
931
+ fd.append('steps', document.getElementById('steps').value);
932
+ fd.append('guidance', document.getElementById('guidance').value);
933
+
934
+ filesState.forEach(f => fd.append('images', f));
935
+
936
+ loader.style.display = 'flex';
937
+ runBtn.disabled = true;
938
+ downloadZipBtn.style.display = 'none';
939
+
940
+ logMsg("Sending request to backend. Running both VAE models...", "log-info");
941
+
942
+ try {{
943
+ const res = await fetch('/api/compare', {{ method: 'POST', body: fd }});
944
+ const data = await res.json();
945
+
946
+ if(data.success) {{
947
+ logMsg(`Success! Inference completed. Used seed: ${{data.seed}}`, "log-success");
948
+
949
+ currentStdFilename = data.std_filename;
950
+ currentSmallFilename = data.small_filename;
951
+
952
+ imgStd.src = data.std_url;
953
+ imgSmall.src = data.small_url;
954
+
955
+ imgStd.onload = () => {{
956
+ sliderEmpty.style.display = 'none';
957
+ imgStd.style.display = 'block';
958
+ imgSmall.style.display = 'block';
959
+ sliderHandle.style.display = 'block';
960
+ sliderLabels.style.display = 'flex';
961
+ downloadZipBtn.style.display = 'block'; // Reveal download button
962
+
963
+ // Reset slider to center
964
+ const rect = sliderStage.getBoundingClientRect();
965
+ updateSlider(rect.left + rect.width / 2);
966
+ }};
967
+ }} else {{
968
+ logMsg("Error processing request: " + data.error, "log-error");
969
+ alert('Error: ' + data.error);
970
+ }}
971
+ }} catch(e) {{
972
+ logMsg("Network or server connection failed.", "log-error");
973
+ alert('Failed to connect to server.');
974
+ }} finally {{
975
+ loader.style.display = 'none';
976
+ runBtn.disabled = false;
977
+ logMsg("Sequence finished. Ready for next input.", "");
978
+ }}
979
+ }};
980
+ </script>
981
+ </body>
982
+ </html>
983
+ """
984
 
985
+ app.launch()