prithivMLmods committed on
Commit
0920937
·
verified ·
1 Parent(s): 3b7b858

standard gradio block [truncated]

Browse files
Files changed (1) hide show
  1. app-truncated.py +387 -887
app-truncated.py CHANGED
@@ -1,148 +1,248 @@
1
  import os
2
- import io
3
  import gc
4
- import uuid
5
- import json
6
- import base64
7
  import random
8
- import zipfile
9
- import threading
10
- import concurrent.futures
11
- from pathlib import Path
12
- from typing import List, Optional
13
-
14
  import spaces
15
- import numpy as np
16
  import torch
17
- from PIL import Image
18
-
19
- from gradio import Server
20
- from fastapi import Request, UploadFile, File, Form
21
- from fastapi.responses import HTMLResponse, JSONResponse, FileResponse, StreamingResponse
22
  from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # --- App Configuration & Directories ---
25
- app = Server()
26
-
27
- BASE_DIR = Path(__file__).resolve().parent
28
- STATIC_DIR = BASE_DIR / "static"
29
- OUTPUT_DIR = BASE_DIR / "outputs"
30
- EXAMPLES_DIR = BASE_DIR / "examples"
31
-
32
- STATIC_DIR.mkdir(exist_ok=True)
33
- OUTPUT_DIR.mkdir(exist_ok=True)
34
-
35
- MAX_SEED = np.iinfo(np.int32).max
36
- MAX_IMAGE_SIZE = 1024
37
 
38
  dtype = torch.bfloat16
39
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
 
41
- if torch.cuda.is_available():
42
- print("current device:", torch.cuda.current_device())
43
- print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
44
- DEVICE_LABEL = torch.cuda.get_device_name(torch.cuda.current_device()).lower()
45
- else:
46
- DEVICE_LABEL = str(device).lower()
47
 
48
- # --- Model Loading ---
49
  print("Loading 4B Distilled model (Standard VAE)...")
50
  pipe_standard = Flux2KleinPipeline.from_pretrained(
51
  "black-forest-labs/FLUX.2-klein-4B",
52
  torch_dtype=dtype,
53
- ).to(device)
54
  pipe_standard.enable_model_cpu_offload()
55
 
56
  print("Loading Small Decoder VAE...")
57
  vae_small = AutoencoderKLFlux2.from_pretrained(
58
  "black-forest-labs/FLUX.2-small-decoder",
59
  torch_dtype=dtype,
60
- ).to(device)
61
 
62
  print("Loading 4B Distilled model (Small Decoder VAE)...")
63
  pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
64
  "black-forest-labs/FLUX.2-klein-4B",
65
  vae=vae_small,
66
  torch_dtype=dtype,
67
- ).to(device)
68
  pipe_small_decoder.enable_model_cpu_offload()
69
 
70
  pipe_lock_standard = threading.Lock()
71
  pipe_lock_small = threading.Lock()
72
 
73
- # --- Utility Functions ---
74
  def calc_dimensions(pil_img: Image.Image):
 
 
 
 
 
75
  iw, ih = pil_img.size
76
  aspect = iw / ih
77
 
78
- if aspect >= 1:
79
  new_width = 1024
80
  new_height = int(round(1024 / aspect))
81
- else:
82
  new_height = 1024
83
  new_width = int(round(1024 * aspect))
84
 
 
85
  new_width = max(256, min(1024, round(new_width / 8) * 8))
86
  new_height = max(256, min(1024, round(new_height / 8) * 8))
87
  return new_width, new_height
88
 
89
- def parse_and_resize_images(image_paths: List[str], width: int, height: int):
90
- if not image_paths:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  return None
92
-
93
- resized = []
94
- for path in image_paths:
95
- try:
96
- img = Image.open(path).convert("RGB")
97
- resized.append(img.resize((width, height), Image.LANCZOS))
98
- except Exception as e:
99
- print(f"Skipping invalid image: {e}")
100
-
101
- return resized if resized else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def run_pipeline(pipe, lock, kwargs, seed):
104
  with lock:
105
- gen = torch.Generator(device="cpu").manual_seed(seed)
106
  result = pipe(**kwargs, generator=gen).images[0]
107
  return result
108
 
109
- def save_image(img: Image.Image, prefix: str = "output") -> str:
110
- filename = f"{prefix}_{uuid.uuid4().hex}.png"
111
- path = OUTPUT_DIR / filename
112
- img.save(path, format="PNG")
113
- return filename
114
-
115
- # --- Inference Function ---
116
  @spaces.GPU(duration=120)
117
  def infer(
118
- prompt: str,
119
- image_paths: List[str] = None,
120
- seed: int = 42,
121
- randomize_seed: bool = False,
122
- width: int = 1024,
123
- height: int = 1024,
124
- num_inference_steps: int = 4,
125
- guidance_scale: float = 1.0,
 
126
  ):
127
  gc.collect()
128
- if torch.cuda.is_available():
129
- torch.cuda.empty_cache()
130
 
131
  if not prompt or not prompt.strip():
132
- raise ValueError("Please enter a prompt.")
133
 
134
  if randomize_seed:
135
  seed = random.randint(0, MAX_SEED)
136
 
 
137
  image_list = None
138
- if image_paths and len(image_paths) > 0:
139
- try:
140
- first_pil = Image.open(image_paths[0]).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  width, height = calc_dimensions(first_pil)
142
- image_list = parse_and_resize_images(image_paths, width, height)
143
- except Exception as e:
144
- print(f"Error processing upload: {e}")
145
 
 
 
 
 
146
  width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
147
  height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
148
 
@@ -156,830 +256,230 @@ def infer(
156
  if image_list is not None:
157
  shared_kwargs["image"] = image_list
158
 
 
 
159
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
160
- future_std = executor.submit(run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed)
161
- future_small = executor.submit(run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed)
162
-
 
 
 
163
  concurrent.futures.wait(
164
  [future_std, future_small],
165
  return_when=concurrent.futures.ALL_COMPLETED,
166
  )
167
 
 
 
168
  out_standard = future_std.result()
169
  out_small = future_small.result()
170
 
171
  gc.collect()
172
- if torch.cuda.is_available():
173
- torch.cuda.empty_cache()
174
 
175
  return out_standard, out_small, seed
176
 
177
 
178
- # --- FastAPI Endpoints ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  def get_example_items():
180
- return [
181
- {
182
- "urls": ["/example-file/I1.jpg", "/example-file/I2.jpg"],
183
- "prompt": "Make her wear these glasses in Image 2."
184
- },
185
- {
186
- "urls": ["/example-file/1.jpg"],
187
- "prompt": "Change the weather to stormy."
188
- },
189
- {
190
- "urls": ["/example-file/2.jpg"],
191
- "prompt": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition."
192
- },
193
- {
194
- "urls": ["/example-file/3.jpg"],
195
- "prompt": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent."
196
- },
197
- {
198
- "urls": ["/example-file/4.jpg"],
199
- "prompt": "Make the texture high-resolution."
200
- }
201
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- @app.get("/example-file/{filename}")
204
- async def example_file(filename: str):
205
- path = EXAMPLES_DIR / filename
206
- if not path.exists():
207
- return JSONResponse({"error": "Example not found"}, status_code=404)
208
- return FileResponse(path)
209
-
210
- @app.get("/download/{filename}")
211
- async def download_file(filename: str):
212
- path = OUTPUT_DIR / filename
213
- if not path.exists():
214
- return JSONResponse({"error": "File not found"}, status_code=404)
215
- return FileResponse(path, filename=filename, media_type="image/png")
216
-
217
- @app.get("/api/download-zip")
218
- async def download_zip(std: str, small: str):
219
- """Packages both generated images into a single ZIP file and streams it."""
220
- std_name = Path(std).name
221
- small_name = Path(small).name
222
-
223
- std_path = OUTPUT_DIR / std_name
224
- small_path = OUTPUT_DIR / small_name
225
-
226
- if not std_path.exists() or not small_path.exists():
227
- return JSONResponse({"error": "Generated files not found"}, status_code=404)
228
-
229
- memory_file = io.BytesIO()
230
- with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
231
- zf.write(std_path, arcname=f"Standard_Decoder_{std_name}")
232
- zf.write(small_path, arcname=f"Small_Decoder_{small_name}")
233
-
234
- memory_file.seek(0)
235
-
236
- return StreamingResponse(
237
- memory_file,
238
- media_type="application/zip",
239
- headers={"Content-Disposition": f"attachment; filename=Flux2_Comparison_{uuid.uuid4().hex[:6]}.zip"}
240
- )
241
 
242
- @app.post("/api/compare")
243
- async def compare_images(
244
- prompt: str = Form(...),
245
- seed: str = Form("0"),
246
- randomize_seed: str = Form("true"),
247
- width: str = Form("1024"),
248
- height: str = Form("1024"),
249
- steps: str = Form("4"),
250
- guidance: str = Form("1.0"),
251
- images: Optional[List[UploadFile]] = File(None),
252
- ):
253
- temp_paths = []
254
- try:
255
- image_paths = []
256
- if images:
257
- for upload in images:
258
- if not upload.filename: continue
259
- suffix = Path(upload.filename).suffix or ".png"
260
- temp_path = OUTPUT_DIR / f"upload_{uuid.uuid4().hex}{suffix}"
261
- content = await upload.read()
262
- with open(temp_path, "wb") as f:
263
- f.write(content)
264
- temp_paths.append(str(temp_path))
265
- image_paths.append(str(temp_path))
266
-
267
- result_std, result_small, used_seed = infer(
268
- prompt=prompt,
269
- image_paths=image_paths,
270
- seed=int(seed),
271
- randomize_seed=(randomize_seed.lower() == "true"),
272
- width=int(width),
273
- height=int(height),
274
- num_inference_steps=int(steps),
275
- guidance_scale=float(guidance),
276
  )
277
 
278
- std_filename = save_image(result_std, prefix="std")
279
- small_filename = save_image(result_small, prefix="small")
280
-
281
- return JSONResponse({
282
- "success": True,
283
- "seed": used_seed,
284
- "std_url": f"/download/{std_filename}",
285
- "small_url": f"/download/{small_filename}",
286
- "std_filename": std_filename,
287
- "small_filename": small_filename,
288
- "device": DEVICE_LABEL,
289
- })
290
-
291
- except Exception as e:
292
- return JSONResponse({"success": False, "error": str(e)}, status_code=500)
293
- finally:
294
- for p in temp_paths:
295
- if os.path.exists(p):
296
- os.remove(p)
297
-
298
- # --- Frontend ---
299
- @app.get("/", response_class=HTMLResponse)
300
- async def homepage(request: Request):
301
- examples = get_example_items()
302
- examples_json = json.dumps(examples)
303
-
304
- return f"""
305
- <!DOCTYPE html>
306
- <html lang="en">
307
- <head>
308
- <meta charset="UTF-8" />
309
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
310
- <title>Flux.2-4B-Decoder-Comparator</title>
311
- <link href="https://fonts.googleapis.com/css2?family=Ubuntu:wght@300;400;500;700&display=swap" rel="stylesheet">
312
- <style>
313
- :root {{
314
- --ub-aubergine: #2C001E;
315
- --ub-aubergine-dark: #1f0015;
316
- --ub-orange: #E95420;
317
- --ub-orange-hover: #c4461a;
318
- --ub-panel: #3D3D3D;
319
- --ub-panel-light: #4f4f4f;
320
- --ub-border: rgba(255,255,255,0.1);
321
- --ub-text: #FFFFFF;
322
- --ub-muted: #b0b0b0;
323
- --ub-input: #2b2b2b;
324
- --panel-radius: 8px;
325
- }}
326
-
327
- * {{ box-sizing: border-box; font-family: 'Ubuntu', sans-serif; }}
328
-
329
- body {{
330
- margin: 0; padding: 0;
331
- background: var(--ub-aubergine);
332
- color: var(--ub-text);
333
- min-height: 100vh;
334
- display: flex;
335
- flex-direction: column;
336
- }}
337
-
338
- .topbar {{
339
- background: var(--ub-aubergine-dark);
340
- padding: 16px 24px;
341
- border-bottom: 1px solid var(--ub-border);
342
- text-align: center;
343
- font-weight: 700;
344
- letter-spacing: 0.5px;
345
- color: var(--ub-orange);
346
- }}
347
-
348
- .container {{
349
- max-width: 1300px;
350
- margin: 0 auto;
351
- padding: 30px 20px;
352
- flex: 1;
353
- width: 100%;
354
- }}
355
-
356
- .header-text {{
357
- text-align: center;
358
- margin-bottom: 30px;
359
- }}
360
- .header-text h1 {{
361
- margin: 0 0 10px 0;
362
- font-size: 2.2rem;
363
- }}
364
- .header-text p {{
365
- color: var(--ub-muted);
366
- margin: 0;
367
- }}
368
-
369
- /* FIXED LAYOUT GRID */
370
- .layout {{
371
- display: grid;
372
- grid-template-columns: 420px 1fr;
373
- gap: 24px;
374
- align-items: stretch;
375
- height: 650px;
376
- }}
377
-
378
- .panel {{
379
- background: var(--ub-panel);
380
- border-radius: var(--panel-radius);
381
- box-shadow: 0 8px 24px rgba(0,0,0,0.2);
382
- display: flex;
383
- flex-direction: column;
384
- overflow: hidden;
385
- height: 100%;
386
- }}
387
-
388
- .panel-header {{
389
- padding: 16px 20px;
390
- background: rgba(0,0,0,0.2);
391
- border-bottom: 1px solid var(--ub-border);
392
- font-weight: 500;
393
- font-size: 1.1rem;
394
- flex-shrink: 0;
395
- display: flex;
396
- justify-content: space-between;
397
- align-items: center;
398
- }}
399
-
400
- .panel-body-scroll {{
401
- flex: 1;
402
- padding: 20px;
403
- overflow-y: auto;
404
- display: flex;
405
- flex-direction: column;
406
- }}
407
-
408
- /* Input Forms */
409
- .form-group {{ margin-bottom: 20px; flex-shrink: 0; }}
410
- .label {{
411
- display: block; font-weight: 500; font-size: 14px;
412
- color: var(--ub-muted); margin-bottom: 8px;
413
- }}
414
-
415
- .textarea, .input {{
416
- width: 100%;
417
- background: var(--ub-input);
418
- border: 1px solid var(--ub-border);
419
- color: var(--ub-text);
420
- padding: 12px;
421
- border-radius: 4px;
422
- outline: none;
423
- font-size: 14px;
424
- }}
425
- .textarea:focus, .input:focus {{ border-color: var(--ub-orange); }}
426
- .textarea {{ min-height: 100px; resize: vertical; }}
427
-
428
- /* Upload Zone */
429
- .upload-zone {{
430
- background: var(--ub-input);
431
- border: 1px dashed var(--ub-muted);
432
- border-radius: 4px;
433
- padding: 15px;
434
- text-align: center;
435
- cursor: pointer;
436
- transition: background 0.2s, border-color 0.2s;
437
- min-height: 100px;
438
- display: flex;
439
- flex-direction: column;
440
- justify-content: center;
441
- align-items: center;
442
- }}
443
- .upload-zone:hover, .upload-zone.dragover {{
444
- border-color: var(--ub-orange);
445
- background: rgba(233,84,32,0.05);
446
- }}
447
- .upload-zone input[type="file"] {{ display: none; }}
448
- .upload-text {{ pointer-events: none; color: var(--ub-muted); }}
449
-
450
- .preview-grid {{
451
- display: none;
452
- grid-template-columns: repeat(auto-fill, minmax(70px, 1fr));
453
- gap: 10px;
454
- width: 100%;
455
- }}
456
- .thumb {{
457
- position: relative; aspect-ratio: 1;
458
- border-radius: 4px; overflow: hidden;
459
- border: 1px solid var(--ub-border);
460
- }}
461
- .thumb img {{ width: 100%; height: 100%; object-fit: cover; display: block; }}
462
- .thumb-remove {{
463
- position: absolute; top: 4px; right: 4px;
464
- background: rgba(0,0,0,0.7); color: white;
465
- border: none; border-radius: 50%; width: 20px; height: 20px;
466
- display: flex; align-items: center; justify-content: center;
467
- cursor: pointer; font-size: 12px;
468
- }}
469
-
470
- .add-more-btn {{
471
- display: flex; align-items: center; justify-content: center;
472
- border: 2px dashed var(--ub-muted); border-radius: 4px;
473
- color: var(--ub-muted); font-size: 26px; cursor: pointer;
474
- aspect-ratio: 1; transition: all 0.2s; background: transparent;
475
- }}
476
- .add-more-btn:hover {{
477
- border-color: var(--ub-orange); color: var(--ub-orange);
478
- background: rgba(233,84,32,0.05);
479
- }}
480
-
481
- /* Advanced Accordion */
482
- .advanced-toggle {{
483
- width: 100%; background: none; border: none; color: var(--ub-orange);
484
- text-align: left; padding: 10px 0; font-weight: 500; cursor: pointer;
485
- display: flex; justify-content: space-between; align-items: center;
486
- flex-shrink: 0;
487
- }}
488
- .advanced-icon {{ font-weight: bold; font-size: 18px; line-height: 1; }}
489
- .advanced-body {{ display: none; padding-top: 10px; flex-shrink: 0; }}
490
- .advanced-body.open {{ display: block; }}
491
- .grid-2 {{ display: grid; grid-template-columns: 1fr 1fr; gap: 15px; }}
492
-
493
- /* Status Container */
494
- .status-container {{
495
- margin-top: 20px; margin-bottom: 20px;
496
- border: 1px solid var(--ub-border); border-radius: 4px;
497
- background: #200014; display: flex; flex-direction: column;
498
- flex: 1; min-height: 100px; max-height: 200px;
499
- }}
500
- .status-header {{
501
- padding: 8px 12px; font-size: 11px; font-weight: 700;
502
- color: var(--ub-muted); background: rgba(0,0,0,0.4);
503
- border-bottom: 1px solid var(--ub-border); text-transform: uppercase;
504
- letter-spacing: 0.5px; flex-shrink: 0;
505
- }}
506
- .status-log {{
507
- padding: 10px; font-family: 'Courier New', Courier, monospace;
508
- font-size: 12px; color: #eeeeee; overflow-y: auto;
509
- flex: 1; display: flex; flex-direction: column; gap: 4px;
510
- }}
511
- .log-time {{ color: #777; margin-right: 8px; }}
512
- .log-info {{ color: #5bc0eb; }}
513
- .log-success {{ color: #9bc53d; }}
514
- .log-error {{ color: #ff5e5b; }}
515
-
516
- /* Buttons */
517
- .btn {{
518
- width: 100%; padding: 14px; border: none; border-radius: 4px;
519
- font-size: 16px; font-weight: 700; cursor: pointer;
520
- transition: opacity 0.2s, background 0.2s; flex-shrink: 0;
521
- }}
522
- .btn-primary {{
523
- background: var(--ub-orange); color: white;
524
- box-shadow: 0 4px 12px rgba(233,84,32,0.3);
525
- }}
526
- .btn-primary:hover {{ background: var(--ub-orange-hover); }}
527
- .btn:disabled {{ opacity: 0.6; cursor: not-allowed; }}
528
-
529
- /* Top-Right Download Icon */
530
- .action-icon {{
531
- display: none; background: none; border: none; color: var(--ub-muted);
532
- cursor: pointer; padding: 4px; transition: color 0.2s;
533
- }}
534
- .action-icon:hover {{ color: var(--ub-orange); }}
535
-
536
- /* SLIDER CONTAINER */
537
- .panel-body-slider {{
538
- flex: 1; display: flex; flex-direction: column;
539
- padding: 0; position: relative;
540
- }}
541
- .slider-stage {{
542
- position: absolute; top: 0; left: 0; right: 0; bottom: 0;
543
- background: #111; overflow: hidden; display: flex;
544
- align-items: center; justify-content: center;
545
- }}
546
- .slider-empty {{ color: var(--ub-muted); text-align: center; z-index: 1; }}
547
-
548
- .slider-img {{
549
- position: absolute; top: 0; left: 0; width: 100%; height: 100%;
550
- object-fit: contain; display: none; user-select: none; -webkit-user-drag: none;
551
- }}
552
- #imgSmall {{ clip-path: inset(0 50% 0 0); }}
553
-
554
- .slider-handle {{
555
- position: absolute; left: 50%; top: 0; bottom: 0;
556
- width: 4px; background: var(--ub-orange); cursor: ew-resize; display: none; z-index: 10;
557
- }}
558
- .slider-handle::after {{
559
- content: '◀ ▶'; position: absolute; top: 50%; left: 50%;
560
- transform: translate(-50%, -50%); width: 40px; height: 30px;
561
- background: var(--ub-orange); color: white; border-radius: 15px;
562
- display: flex; align-items: center; justify-content: center;
563
- font-size: 10px; font-weight: bold; box-shadow: 0 2px 6px rgba(0,0,0,0.5);
564
- }}
565
-
566
- .slider-labels {{
567
- position: absolute; top: 15px; left: 15px; right: 15px;
568
- display: none; justify-content: space-between;
569
- pointer-events: none; z-index: 5;
570
- }}
571
- .badge {{
572
- background: rgba(0,0,0,0.6); color: white; padding: 6px 12px;
573
- border-radius: 20px; font-size: 13px; backdrop-filter: blur(4px);
574
- }}
575
-
576
- /* UPDATED LOADER ANIMATION (Minimalist Single Circle) */
577
- .loader {{
578
- position: absolute; inset: 0;
579
- background: rgba(20, 0, 10, 0.7); /* dark aubergine tint */
580
- backdrop-filter: blur(6px);
581
- display: none; flex-direction: column;
582
- align-items: center; justify-content: center; z-index: 20;
583
- }}
584
- .spinner-single {{
585
- width: 55px; height: 55px;
586
- border: 3px solid rgba(255, 255, 255, 0.1);
587
- border-top-color: var(--ub-orange);
588
- border-radius: 50%;
589
- animation: spin 1s cubic-bezier(0.4, 0.0, 0.2, 1) infinite;
590
- margin-bottom: 20px;
591
- }}
592
- .loader-text {{
593
- font-weight: 500;
594
- font-size: 15px;
595
- color: #ffffff;
596
- letter-spacing: 1px;
597
- animation: pulse 1.5s ease-in-out infinite;
598
- }}
599
- @keyframes pulse {{
600
- 0%, 100% {{ opacity: 1; }}
601
- 50% {{ opacity: 0.5; }}
602
- }}
603
- @keyframes spin {{
604
- to {{ transform: rotate(360deg); }}
605
- }}
606
-
607
- /* Examples */
608
- .examples-section {{ margin-top: 40px; }}
609
- .examples-section h3 {{ border-bottom: 1px solid var(--ub-border); padding-bottom: 10px; }}
610
- .examples-grid {{
611
- display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 20px;
612
- }}
613
- .ex-card {{
614
- background: var(--ub-panel); border-radius: 4px; overflow: hidden;
615
- cursor: pointer; transition: transform 0.2s, box-shadow 0.2s;
616
- }}
617
- .ex-card:hover {{ transform: translateY(-3px); box-shadow: 0 6px 16px rgba(0,0,0,0.3); }}
618
- .ex-card-img-wrap {{ width: 100%; aspect-ratio: 1; display: flex; background: #000; }}
619
- .ex-card-img-wrap img {{ height: 100%; object-fit: cover; }}
620
- .ex-card p {{ padding: 12px; margin: 0; font-size: 13px; color: var(--ub-muted); line-height: 1.4; }}
621
-
622
- @media (max-width: 900px) {{
623
- .layout {{ grid-template-columns: 1fr; height: auto; }}
624
- .panel-body-slider {{ height: 450px; flex: none; }}
625
- .slider-stage {{ position: relative; height: 100%; }}
626
- }}
627
- </style>
628
- </head>
629
- <body>
630
-
631
- <div class="topbar">Flux.2-4B VAE Decoder Comparator</div>
632
-
633
- <div class="container">
634
- <div class="header-text">
635
- <h1>Standard vs. Small Decoder</h1>
636
- <p>Upload an image, enter a prompt, and use the slider to compare outputs in real-time.</p>
637
- </div>
638
-
639
- <div class="layout">
640
- <div class="panel">
641
- <div class="panel-header">Settings</div>
642
- <div class="panel-body-scroll">
643
- <div class="form-group">
644
- <label class="label">Input Images (Optional)</label>
645
- <div class="upload-zone" id="dropZone">
646
- <input type="file" id="fileInput" multiple accept="image/*" />
647
- <div class="upload-text" id="uploadText">Click or Drag & Drop images here</div>
648
- <div class="preview-grid" id="previewGrid"></div>
649
- </div>
650
- </div>
651
-
652
- <div class="form-group">
653
- <label class="label">Prompt</label>
654
- <textarea id="promptInput" class="textarea" placeholder="Describe the edit or generation..."></textarea>
655
- </div>
656
-
657
- <button class="advanced-toggle" id="advToggle">
658
- <span>Advanced Settings</span> <span class="advanced-icon" id="advIcon">+</span>
659
- </button>
660
-
661
- <div class="advanced-body" id="advBody">
662
- <div class="grid-2">
663
- <div class="form-group">
664
- <label class="label">Seed</label>
665
- <input type="number" id="seed" class="input" value="0">
666
- </div>
667
- <div class="form-group">
668
- <label class="label">Steps</label>
669
- <input type="number" id="steps" class="input" value="4">
670
- </div>
671
- <div class="form-group">
672
- <label class="label">Width</label>
673
- <input type="number" id="width" class="input" value="1024" step="8">
674
- </div>
675
- <div class="form-group">
676
- <label class="label">Height</label>
677
- <input type="number" id="height" class="input" value="1024" step="8">
678
- </div>
679
- <div class="form-group" style="grid-column: span 2;">
680
- <label class="label">Guidance Scale</label>
681
- <input type="number" id="guidance" class="input" value="1.0" step="0.1">
682
- </div>
683
- <div class="form-group" style="grid-column: span 2;">
684
- <label style="display:flex; align-items:center; gap:8px; font-size:14px; color:var(--ub-text);">
685
- <input type="checkbox" id="randomize" checked> Randomize Seed
686
- </label>
687
- </div>
688
- </div>
689
- </div>
690
-
691
- <div class="status-container">
692
- <div class="status-header">Execution Log</div>
693
- <div class="status-log" id="statusLog">
694
- <div><span class="log-time">[{DEVICE_LABEL}]</span><span>System Ready...</span></div>
695
- </div>
696
- </div>
697
-
698
- <button class="btn btn-primary" id="runBtn">Run Comparison</button>
699
- </div>
700
- </div>
701
-
702
- <div class="panel">
703
- <div class="panel-header">
704
- <span>Comparison View</span>
705
- <button id="downloadZipBtn" class="action-icon" title="Download Both Images (ZIP)">
706
- <svg width="22" height="22" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" viewBox="0 0 24 24">
707
- <path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4"></path>
708
- <polyline points="7 10 12 15 17 10"></polyline>
709
- <line x1="12" y1="15" x2="12" y2="3"></line>
710
- </svg>
711
- </button>
712
- </div>
713
- <div class="panel-body-slider">
714
- <div class="slider-stage" id="sliderStage">
715
- <div class="slider-empty" id="sliderEmpty">
716
- <svg width="48" height="48" fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" style="margin-bottom:10px; opacity:0.5;">
717
- <path d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z"></path>
718
- </svg>
719
- <div>Results will appear here</div>
720
- </div>
721
-
722
- <img id="imgStd" class="slider-img" alt="Standard Decoder" />
723
- <img id="imgSmall" class="slider-img" alt="Small Decoder" />
724
-
725
- <div class="slider-labels" id="sliderLabels">
726
- <div class="badge">Standard Decoder</div>
727
- <div class="badge">Small Decoder</div>
728
- </div>
729
-
730
- <div class="slider-handle" id="sliderHandle"></div>
731
-
732
- <div class="loader" id="loader">
733
- <div class="spinner-single"></div>
734
- <div class="loader-text">Running both decoders...</div>
735
- </div>
736
- </div>
737
- </div>
738
- </div>
739
- </div>
740
-
741
- <div class="examples-section">
742
- <h3>Examples</h3>
743
- <div class="examples-grid" id="examplesGrid"></div>
744
- </div>
745
- </div>
746
-
747
- <script>
748
- const examples = {examples_json};
749
- let filesState = [];
750
- let currentStdFilename = "";
751
- let currentSmallFilename = "";
752
-
753
- // UI Elements
754
- const dropZone = document.getElementById('dropZone');
755
- const fileInput = document.getElementById('fileInput');
756
- const previewGrid = document.getElementById('previewGrid');
757
- const uploadText = document.getElementById('uploadText');
758
- const promptInput = document.getElementById('promptInput');
759
- const runBtn = document.getElementById('runBtn');
760
- const downloadZipBtn = document.getElementById('downloadZipBtn');
761
-
762
- // Status Log
763
- const statusLog = document.getElementById('statusLog');
764
-
765
- function logMsg(msg, styleClass="") {{
766
- const div = document.createElement('div');
767
- const timeStr = new Date().toLocaleTimeString('en-US', {{hour12:false}});
768
- div.innerHTML = `<span class="log-time">[${{timeStr}}]</span><span class="${{styleClass}}">${{msg}}</span>`;
769
- statusLog.appendChild(div);
770
- statusLog.scrollTop = statusLog.scrollHeight; // auto-scroll to bottom
771
- }}
772
-
773
- // Slider Elements
774
- const sliderStage = document.getElementById('sliderStage');
775
- const imgStd = document.getElementById('imgStd');
776
- const imgSmall = document.getElementById('imgSmall');
777
- const sliderHandle = document.getElementById('sliderHandle');
778
- const sliderLabels = document.getElementById('sliderLabels');
779
- const sliderEmpty = document.getElementById('sliderEmpty');
780
- const loader = document.getElementById('loader');
781
-
782
- // Advanced Toggle logic (+ / -)
783
- document.getElementById('advToggle').onclick = function() {{
784
- const body = document.getElementById('advBody');
785
- body.classList.toggle('open');
786
- document.getElementById('advIcon').innerText = body.classList.contains('open') ? '−' : '+';
787
- }};
788
-
789
- // --- File Upload Logic ---
790
- function renderPreviews() {{
791
- previewGrid.innerHTML = '';
792
- if(filesState.length > 0) {{
793
- uploadText.style.display = 'none';
794
- previewGrid.style.display = 'grid';
795
-
796
- filesState.forEach((f, i) => {{
797
- const div = document.createElement('div');
798
- div.className = 'thumb';
799
- const img = document.createElement('img');
800
- img.src = URL.createObjectURL(f);
801
- const btn = document.createElement('button');
802
- btn.className = 'thumb-remove';
803
- btn.innerText = '×';
804
- btn.onclick = (e) => {{ e.stopPropagation(); filesState.splice(i, 1); renderPreviews(); }};
805
- div.appendChild(img); div.appendChild(btn);
806
- previewGrid.appendChild(div);
807
- }});
808
-
809
- // Append dynamic + button
810
- const addBtn = document.createElement('div');
811
- addBtn.className = 'add-more-btn';
812
- addBtn.innerHTML = '+';
813
- addBtn.onclick = (e) => {{ e.stopPropagation(); fileInput.click(); }};
814
- previewGrid.appendChild(addBtn);
815
-
816
- }} else {{
817
- uploadText.style.display = 'block';
818
- previewGrid.style.display = 'none';
819
- }}
820
- }}
821
-
822
- dropZone.onclick = (e) => {{ if(e.target === dropZone || e.target === uploadText) fileInput.click(); }};
823
- fileInput.onchange = (e) => {{ filesState.push(...Array.from(e.target.files)); renderPreviews(); fileInput.value=''; }};
824
- dropZone.ondragover = (e) => {{ e.preventDefault(); dropZone.classList.add('dragover'); }};
825
- dropZone.ondragleave = () => dropZone.classList.remove('dragover');
826
- dropZone.ondrop = (e) => {{
827
- e.preventDefault(); dropZone.classList.remove('dragover');
828
- if(e.dataTransfer.files.length) {{ filesState.push(...Array.from(e.dataTransfer.files)); renderPreviews(); }}
829
- }};
830
-
831
- // --- Examples Logic ---
832
- async function loadExample(urls, text) {{
833
- filesState = [];
834
- renderPreviews();
835
- promptInput.value = text;
836
- logMsg("Loading example: " + text, "log-info");
837
-
838
- try {{
839
- for(let i=0; i<urls.length; i++) {{
840
- const res = await fetch(urls[i]);
841
- const blob = await res.blob();
842
- const filename = urls[i].split('/').pop();
843
- filesState.push(new File([blob], filename, {{type: blob.type}}));
844
- }}
845
- renderPreviews();
846
-
847
- window.scrollTo({{top: 0, behavior: 'smooth'}});
848
-
849
- setTimeout(() => {{
850
- logMsg("Example loaded. Starting comparison...", "log-info");
851
- runBtn.click();
852
- }}, 500);
853
-
854
- }} catch (e) {{
855
- logMsg("Failed to load example images.", "log-error");
856
- alert('Failed to load example image.');
857
- }}
858
- }}
859
-
860
- const exGrid = document.getElementById('examplesGrid');
861
- examples.forEach(ex => {{
862
- const card = document.createElement('div');
863
- card.className = 'ex-card';
864
-
865
- let imgHTML = '';
866
- if(ex.urls.length > 1) {{
867
- imgHTML = `
868
- <div class="ex-card-img-wrap">
869
- <img src="${{ex.urls[0]}}" style="width:50%; border-right:1px solid #000;">
870
- <img src="${{ex.urls[1]}}" style="width:50%;">
871
- </div>
872
- `;
873
- }} else {{
874
- imgHTML = `<div class="ex-card-img-wrap"><img src="${{ex.urls[0]}}" style="width:100%;"></div>`;
875
- }}
876
-
877
- card.innerHTML = `${{imgHTML}}<p>${{ex.prompt}}</p>`;
878
- card.onclick = () => loadExample(ex.urls, ex.prompt);
879
- exGrid.appendChild(card);
880
- }});
881
-
882
- // --- Image Slider Logic ---
883
- let isDragging = false;
884
-
885
- function updateSlider(clientX) {{
886
- const rect = sliderStage.getBoundingClientRect();
887
- let pos = Math.max(0, Math.min(clientX - rect.left, rect.width));
888
- let percent = (pos / rect.width) * 100;
889
-
890
- sliderHandle.style.left = percent + '%';
891
- imgSmall.style.clipPath = `inset(0 ${{100 - percent}}% 0 0)`;
892
- }}
893
-
894
- sliderHandle.addEventListener('mousedown', () => isDragging = true);
895
- window.addEventListener('mouseup', () => isDragging = false);
896
- window.addEventListener('mousemove', (e) => {{
897
- if (!isDragging) return;
898
- updateSlider(e.clientX);
899
- }});
900
-
901
- sliderHandle.addEventListener('touchstart', () => isDragging = true);
902
- window.addEventListener('touchend', () => isDragging = false);
903
- window.addEventListener('touchmove', (e) => {{
904
- if (!isDragging) return;
905
- updateSlider(e.touches[0].clientX);
906
- }});
907
-
908
- // --- Download Zip Logic ---
909
- downloadZipBtn.onclick = () => {{
910
- if(!currentStdFilename || !currentSmallFilename) return;
911
- logMsg("Initiating ZIP download...", "log-info");
912
- window.location.href = `/api/download-zip?std=${{currentStdFilename}}&small=${{currentSmallFilename}}`;
913
- }};
914
-
915
- // --- Form Submission ---
916
- runBtn.onclick = async () => {{
917
- const prompt = promptInput.value.trim();
918
- if(!prompt) {{
919
- logMsg("Validation failed: Prompt is empty.", "log-error");
920
- return alert("Enter a prompt");
921
- }}
922
-
923
- logMsg("Initializing generation sequence...", "log-info");
924
-
925
- const fd = new FormData();
926
- fd.append('prompt', prompt);
927
- fd.append('seed', document.getElementById('seed').value);
928
- fd.append('randomize_seed', document.getElementById('randomize').checked);
929
- fd.append('width', document.getElementById('width').value);
930
- fd.append('height', document.getElementById('height').value);
931
- fd.append('steps', document.getElementById('steps').value);
932
- fd.append('guidance', document.getElementById('guidance').value);
933
-
934
- filesState.forEach(f => fd.append('images', f));
935
-
936
- loader.style.display = 'flex';
937
- runBtn.disabled = true;
938
- downloadZipBtn.style.display = 'none';
939
-
940
- logMsg("Sending request to backend. Running both VAE models...", "log-info");
941
-
942
- try {{
943
- const res = await fetch('/api/compare', {{ method: 'POST', body: fd }});
944
- const data = await res.json();
945
-
946
- if(data.success) {{
947
- logMsg(`Success! Inference completed. Used seed: ${{data.seed}}`, "log-success");
948
-
949
- currentStdFilename = data.std_filename;
950
- currentSmallFilename = data.small_filename;
951
-
952
- imgStd.src = data.std_url;
953
- imgSmall.src = data.small_url;
954
-
955
- imgStd.onload = () => {{
956
- sliderEmpty.style.display = 'none';
957
- imgStd.style.display = 'block';
958
- imgSmall.style.display = 'block';
959
- sliderHandle.style.display = 'block';
960
- sliderLabels.style.display = 'flex';
961
- downloadZipBtn.style.display = 'block'; // Reveal download button
962
-
963
- // Reset slider to center
964
- const rect = sliderStage.getBoundingClientRect();
965
- updateSlider(rect.left + rect.width / 2);
966
- }};
967
- }} else {{
968
- logMsg("Error processing request: " + data.error, "log-error");
969
- alert('Error: ' + data.error);
970
- }}
971
- }} catch(e) {{
972
- logMsg("Network or server connection failed.", "log-error");
973
- alert('Failed to connect to server.');
974
- }} finally {{
975
- loader.style.display = 'none';
976
- runBtn.disabled = false;
977
- logMsg("Sequence finished. Ready for next input.", "");
978
- }}
979
- }};
980
- </script>
981
- </body>
982
- </html>
983
- """
984
 
985
- app.launch()
 
 
 
 
 
 
 
1
  import os
 
2
  import gc
3
+ import gradio as gr
4
+ import numpy as np
 
5
  import random
 
 
 
 
 
 
6
  import spaces
 
7
  import torch
 
 
 
 
 
8
  from diffusers import Flux2KleinPipeline, AutoencoderKLFlux2
9
+ from PIL import Image
10
+ from pathlib import Path
11
+ import concurrent.futures
12
+ import threading
13
+ from typing import Iterable
14
+
15
+ from gradio.themes import Soft
16
+ from gradio.themes.utils import colors, fonts, sizes
17
+
18
+ colors.orange_red = colors.Color(
19
+ name="orange_red",
20
+ c50="#FFF0E5",
21
+ c100="#FFE0CC",
22
+ c200="#FFC299",
23
+ c300="#FFA366",
24
+ c400="#FF8533",
25
+ c500="#FF4500",
26
+ c600="#E63E00",
27
+ c700="#CC3700",
28
+ c800="#B33000",
29
+ c900="#992900",
30
+ c950="#802200",
31
+ )
32
+
33
class OrangeRedTheme(Soft):
    """Gradio ``Soft`` theme variant with gray primary and orange-red accents.

    All constructor arguments are keyword-only and are forwarded unchanged to
    ``Soft.__init__``; afterwards a fixed set of style overrides is applied
    through ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Style overrides layered on top of the Soft defaults.
        overrides = {
            # Page / block backgrounds
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary buttons
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            # Secondary buttons
            "button_secondary_text_color": "black",
            "button_secondary_text_color_hover": "white",
            "button_secondary_background_fill": "linear-gradient(90deg, *primary_300, *primary_300)",
            "button_secondary_background_fill_hover": "linear-gradient(90deg, *primary_400, *primary_400)",
            "button_secondary_background_fill_dark": "linear-gradient(90deg, *primary_500, *primary_600)",
            "button_secondary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
            # Sliders
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            # Blocks, shadows and spacing
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**overrides)
83
 
84
+ orange_red_theme = OrangeRedTheme()
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
# --- Runtime configuration -------------------------------------------------
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
EXAMPLES_DIR = Path("examples")

# --- Pipeline using the VAE bundled with the checkpoint ---------------------
print("Loading 4B Distilled model (Standard VAE)...")
pipe_standard = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    torch_dtype=dtype,
)
pipe_standard.enable_model_cpu_offload()

# --- Stand-alone small-decoder VAE ------------------------------------------
print("Loading Small Decoder VAE...")
vae_small = AutoencoderKLFlux2.from_pretrained(
    "black-forest-labs/FLUX.2-small-decoder",
    torch_dtype=dtype,
)

# --- Same checkpoint, with the small-decoder VAE swapped in -----------------
print("Loading 4B Distilled model (Small Decoder VAE)...")
pipe_small_decoder = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    vae=vae_small,
    torch_dtype=dtype,
)
pipe_small_decoder.enable_model_cpu_offload()

# One lock per pipeline: calls on the same pipeline are serialized, while the
# two pipelines can still be driven from separate threads.
pipe_lock_standard = threading.Lock()
pipe_lock_small = threading.Lock()
116
 
 
117
def calc_dimensions(
    pil_img: "Image.Image",
    max_side: int = 1024,
    min_side: int = 256,
    multiple: int = 8,
):
    """Return a ``(width, height)`` target size derived from an image.

    The long side is set to ``max_side`` and the short side is scaled to keep
    the aspect ratio.  Both sides are then snapped to the nearest multiple of
    ``multiple`` with ``round()`` (ties go to the even multiple) and clamped
    to ``[min_side, max_side]``.  With the default arguments this matches the
    previous hard-coded 1024 / 256 / 8 behaviour.

    Args:
        pil_img: Any object exposing a ``size`` attribute of
            ``(width, height)``, typically a PIL image.
        max_side: Length of the long side and upper clamp for both sides.
        min_side: Lower clamp for both sides.
        multiple: Grid both sides are snapped to.

    Returns:
        Tuple ``(new_width, new_height)``.  A degenerate image (zero or
        negative width/height) yields ``(max_side, max_side)`` instead of
        raising ``ZeroDivisionError``.
    """
    iw, ih = pil_img.size
    if iw <= 0 or ih <= 0:
        # No meaningful aspect ratio; fall back to the square default.
        return max_side, max_side

    aspect = iw / ih

    if aspect >= 1:  # landscape / square
        new_width = max_side
        new_height = int(round(max_side / aspect))
    else:  # portrait
        new_height = max_side
        new_width = int(round(max_side * aspect))

    def _snap(value):
        # Snap to the grid first, then clamp into the allowed range.
        return max(min_side, min(max_side, round(value / multiple) * multiple))

    return _snap(new_width), _snap(new_height)
137
 
138
+
139
def update_dimensions_from_image(image_list):
    """Derive width/height slider values from the first gallery image.

    Wired to the gallery ``.upload()`` event.

    Args:
        image_list: Gallery value; each item is either an image or a tuple
            whose first element is the image.  The image may be a PIL image
            or a file path string.

    Returns:
        ``(width, height)`` from ``calc_dimensions`` for the first image, or
        ``(1024, 1024)`` when the list is empty, the item has an unsupported
        type, or the file cannot be opened as an image.
    """
    fallback = (1024, 1024)
    if not image_list:
        return fallback

    first = image_list[0]
    img = first[0] if isinstance(first, tuple) else first

    if isinstance(img, str):
        try:
            # Only the size is needed: opening the file reads the header, so
            # there is no need to decode or convert the pixel data here.
            with Image.open(img) as opened:
                return calc_dimensions(opened)
        except OSError as exc:
            # Missing or unreadable file: keep the sliders on the defaults
            # rather than failing the upload callback.
            print(f"Could not read uploaded image {img!r}: {exc}")
            return fallback

    if not isinstance(img, Image.Image):
        return fallback

    return calc_dimensions(img)
157
+
158
def parse_and_resize_images(input_images, width: int, height: int):
    """Load the gallery input as RGB images resized to ``(width, height)``.

    Accepted inputs:
      * ``None``: returns ``None``.
      * a path string: loaded only if the path exists.
      * a single PIL image.
      * a list whose items are (optionally wrapped in a tuple, first element
        used) a path string, a PIL image, or an object with a ``name``
        attribute holding a path.  Items that raise while loading are logged
        and skipped.

    Returns:
        A list of resized PIL images, or ``None`` when nothing could be loaded.
    """
    if input_images is None:
        return None

    loaded = []

    if isinstance(input_images, str):
        if os.path.exists(input_images):
            loaded.append(Image.open(input_images).convert("RGB"))
    elif isinstance(input_images, Image.Image):
        loaded.append(input_images.convert("RGB"))
    elif isinstance(input_images, list):
        for entry in input_images:
            try:
                candidate = entry[0] if isinstance(entry, tuple) else entry
                if isinstance(candidate, Image.Image) and not isinstance(candidate, str):
                    loaded.append(candidate.convert("RGB"))
                    continue
                if isinstance(candidate, str):
                    path = candidate
                elif hasattr(candidate, "name"):
                    path = candidate.name
                else:
                    # Unsupported item type: ignored without a message.
                    continue
                loaded.append(Image.open(path).convert("RGB"))
            except Exception as e:
                print(f"Skipping invalid image: {e}")

    if not loaded:
        return None

    target_size = (width, height)
    return [picture.resize(target_size, Image.LANCZOS) for picture in loaded]
194
 
195
def run_pipeline(pipe, lock, kwargs, seed):
    """Run ``pipe`` once under ``lock`` and return its first output image.

    Args:
        pipe: Callable pipeline; invoked with ``**kwargs`` plus ``generator``.
        lock: Lock held for the duration of the call.
        kwargs: Keyword arguments forwarded to the pipeline.
        seed: Seed for the CPU ``torch.Generator`` passed to the pipeline.

    Returns:
        ``pipe(...).images[0]``.
    """
    with lock:
        generator = torch.Generator(device="cpu")
        generator.manual_seed(seed)
        output = pipe(**kwargs, generator=generator)
        return output.images[0]
200
 
 
 
 
 
 
 
 
201
  @spaces.GPU(duration=120)
202
  def infer(
203
+ prompt,
204
+ input_images=None,
205
+ seed=42,
206
+ randomize_seed=False,
207
+ width=1024,
208
+ height=1024,
209
+ num_inference_steps=4,
210
+ guidance_scale=1.0,
211
+ progress=gr.Progress(track_tqdm=True),
212
  ):
213
  gc.collect()
214
+ torch.cuda.empty_cache()
 
215
 
216
  if not prompt or not prompt.strip():
217
+ raise gr.Error("Please enter a prompt.")
218
 
219
  if randomize_seed:
220
  seed = random.randint(0, MAX_SEED)
221
 
222
+ # ── width / height: derive from the first uploaded image if present ──
223
  image_list = None
224
+ if input_images:
225
+ # Re-derive dimensions from the actual first image so they are
226
+ # always consistent with what the pipeline will receive.
227
+ item = (
228
+ input_images[0][0]
229
+ if isinstance(input_images[0], tuple)
230
+ else input_images[0]
231
+ )
232
+ if isinstance(item, str):
233
+ first_pil = Image.open(item).convert("RGB")
234
+ elif isinstance(item, Image.Image):
235
+ first_pil = item.convert("RGB")
236
+ else:
237
+ first_pil = None
238
+
239
+ if first_pil is not None:
240
  width, height = calc_dimensions(first_pil)
 
 
 
241
 
242
+ # parse + resize all images to the final (width, height)
243
+ image_list = parse_and_resize_images(input_images, width, height)
244
+
245
+ # ensure dims are multiples of 8 even for text-only runs
246
  width = max(256, min(MAX_IMAGE_SIZE, round(int(width) / 8) * 8))
247
  height = max(256, min(MAX_IMAGE_SIZE, round(int(height) / 8) * 8))
248
 
 
256
  if image_list is not None:
257
  shared_kwargs["image"] = image_list
258
 
259
+ progress(0.30, desc="Launching both pipelines simultaneously...")
260
+
261
  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
262
+ future_std = executor.submit(
263
+ run_pipeline, pipe_standard, pipe_lock_standard, shared_kwargs, seed
264
+ )
265
+ future_small = executor.submit(
266
+ run_pipeline, pipe_small_decoder, pipe_lock_small, shared_kwargs, seed
267
+ )
268
  concurrent.futures.wait(
269
  [future_std, future_small],
270
  return_when=concurrent.futures.ALL_COMPLETED,
271
  )
272
 
273
+ progress(0.80, desc="✅ Both pipelines done!")
274
+
275
  out_standard = future_std.result()
276
  out_small = future_small.result()
277
 
278
  gc.collect()
279
+ torch.cuda.empty_cache()
 
280
 
281
  return out_standard, out_small, seed
282
 
283
 
284
@spaces.GPU(duration=120)
def infer_example(prompt):
    """Run the decoder comparison for an example prompt with fixed settings.

    Calls ``infer`` without input images at 1024x1024, 4 inference steps,
    guidance scale 1.0 and a randomized seed.

    Returns:
        The three values returned by ``infer``: standard-decoder output,
        small-decoder output, and the seed that was used.
    """
    # NOTE(review): gr.Examples in this file passes inputs=[input_images, prompt]
    # while this function accepts a single argument; if the examples are ever
    # executed through ``fn`` (caching / run-on-click) the arity would not
    # match. Confirm the intended wiring.
    fixed_settings = {
        "input_images": None,
        "seed": 0,
        "randomize_seed": True,
        "width": 1024,
        "height": 1024,
        "num_inference_steps": 4,
        "guidance_scale": 1.0,
    }
    standard_image, small_image, used_seed = infer(prompt=prompt, **fixed_settings)
    return standard_image, small_image, used_seed
297
+
298
+
299
def get_example_items(examples_dir=None):
    """List the example images found in a directory, with their prompts.

    Args:
        examples_dir: Directory to scan.  Defaults to the module-level
            ``EXAMPLES_DIR`` when omitted, which preserves the original
            zero-argument behaviour.

    Returns:
        A list of dicts with keys ``"file"`` (file name), ``"path"`` (string
        path inside the directory) and ``"prompt"``, sorted by file name.
        Only regular files ending in .png/.jpg/.jpeg/.webp (case-insensitive)
        are included.  Returns an empty list when the directory is missing
        or the path is not a directory.
    """
    example_prompts = {
        "1.jpg": "Change the weather to stormy.",
        "2.jpg": "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition.",
        "3.jpg": "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent.",
        "4.jpg": "Make the texture high-resolution.",
    }
    default_prompt = "Edit this image while preserving composition."
    image_suffixes = (".png", ".jpg", ".jpeg", ".webp")

    directory = Path(EXAMPLES_DIR if examples_dir is None else examples_dir)
    # is_dir() (rather than exists()) so a stray file at this path does not
    # make os.listdir raise.
    if not directory.is_dir():
        return []

    return [
        {
            "file": name,
            "path": str(directory / name),
            "prompt": example_prompts.get(name, default_prompt),
        }
        for name in sorted(os.listdir(directory))
        # Skip sub-directories that merely have an image-like name.
        if name.lower().endswith(image_suffixes) and (directory / name).is_file()
    ]
318
+
319
+ EXAMPLE_ITEMS = get_example_items()
320
+
321
+ css = """
322
+ #col-container {
323
+ margin: 0 auto;
324
+ max-width: 1100px;
325
+ }
326
+ #main-title h1 {
327
+ font-size: 2.4em !important;
328
+ }
329
+ .vae-badge {
330
+ font-weight: 700;
331
+ font-size: 0.95em;
332
+ text-align: center;
333
+ padding: 4px 16px;
334
+ border-radius: 20px;
335
+ display: block;
336
+ margin-bottom: 6px;
337
+ }
338
+ """
339
 
340
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
+ with gr.Column(elem_id="col-container"):
343
+
344
+ gr.Markdown(
345
+ "# **Flux.2-4B-Decoder-Comparator**",
346
+ elem_id="main-title",
347
+ )
348
+ gr.Markdown(
349
+ "Compare **FLUX.2-klein-4B** side-by-side with "
350
+ "[small decoder](https://huggingface.co/black-forest-labs/FLUX.2-small-decoder)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  )
352
 
353
+ with gr.Row(equal_height=True):
354
+
355
+ with gr.Column():
356
+ input_images = gr.Gallery(
357
+ label="Input Images",
358
+ type="pil",
359
+ columns=2,
360
+ rows=1,
361
+ height=300,
362
+ allow_preview=True,
363
+ )
364
+
365
+ prompt = gr.Text(
366
+ label="Prompt",
367
+ max_lines=1,
368
+ show_label=True,
369
+ placeholder="e.g., A black cat holding a sign that says hello world...",
370
+ )
371
+
372
+ run_button = gr.Button("Run Comparison", variant="primary")
373
+
374
+ with gr.Column():
375
+ with gr.Row():
376
+ with gr.Column():
377
+ result_standard = gr.Image(
378
+ label="Standard Decoder",
379
+ show_label=True,
380
+ interactive=False,
381
+ format="png",
382
+ height=250,
383
+ )
384
+ with gr.Column():
385
+ result_small = gr.Image(
386
+ label="Small Decoder",
387
+ show_label=True,
388
+ interactive=False,
389
+ format="png",
390
+ height=250,
391
+ )
392
+
393
+ seed_output = gr.Number(label="Seed Used", precision=0, visible=False)
394
+
395
+ with gr.Accordion("Advanced Settings", open=False, visible=False):
396
+ seed = gr.Slider(
397
+ label="Seed",
398
+ minimum=0,
399
+ maximum=MAX_SEED,
400
+ step=1,
401
+ value=0,
402
+ )
403
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
404
+
405
+ with gr.Row():
406
+ width = gr.Slider(
407
+ label="Width",
408
+ minimum=256,
409
+ maximum=MAX_IMAGE_SIZE,
410
+ step=8,
411
+ value=1024,
412
+ )
413
+ height_slider = gr.Slider(
414
+ label="Height",
415
+ minimum=256,
416
+ maximum=MAX_IMAGE_SIZE,
417
+ step=8,
418
+ value=1024,
419
+ )
420
+
421
+ with gr.Row():
422
+ num_inference_steps = gr.Slider(
423
+ label="Inference Steps",
424
+ minimum=1,
425
+ maximum=20,
426
+ step=1,
427
+ value=4,
428
+ )
429
+ guidance_scale = gr.Slider(
430
+ label="Guidance Scale",
431
+ minimum=0.0,
432
+ maximum=10.0,
433
+ step=0.1,
434
+ value=1.0,
435
+ )
436
+
437
+ gr.Examples(
438
+ examples=[
439
+ [["examples/I1.jpg", "examples/I2.jpg"], "Make her wear these glasses in Image 2."],
440
+ [["examples/1.jpg"], "Change the weather to stormy."],
441
+ [["examples/2.jpg"], "Transform the scene into a snowy winter day while preserving the original subject identity, framing, and composition."],
442
+ [["examples/3.jpg"], "Relight the image with soft golden sunset lighting while keeping all structures and subject details consistent."],
443
+ [["examples/4.jpg"], "Make the texture high-resolution."],
444
+ ],
445
+ inputs=[input_images, prompt],
446
+ outputs=[result_standard, result_small, seed_output],
447
+ fn=infer_example,
448
+ cache_examples=False,
449
+ label="Examples",
450
+ )
451
+
452
+ gr.Markdown(
453
+ "[*](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B) "
454
+ "Experimental Space — FLUX.2 [klein] 4B VAE Decoder Comparison."
455
+ )
456
+
457
+ input_images.upload(
458
+ fn=update_dimensions_from_image,
459
+ inputs=[input_images],
460
+ outputs=[width, height_slider],
461
+ )
462
+
463
+ gr.on(
464
+ triggers=[run_button.click, prompt.submit],
465
+ fn=infer,
466
+ inputs=[
467
+ prompt,
468
+ input_images,
469
+ seed,
470
+ randomize_seed,
471
+ width,
472
+ height_slider,
473
+ num_inference_steps,
474
+ guidance_scale,
475
+ ],
476
+ outputs=[result_standard, result_small, seed_output],
477
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
if __name__ == "__main__":
    # Bounded request queue, then launch with the custom theme and CSS.
    launch_options = {
        "theme": orange_red_theme,
        "css": css,
        "mcp_server": True,
        "ssr_mode": False,
        "show_error": True,
    }
    demo.queue(max_size=20).launch(**launch_options)