Nick088 commited on
Commit
7bfd773
·
verified ·
1 Parent(s): 9c05349

add theme, compare from 2 to 4 models, sd1.5, height & width diff per model

Browse files
Files changed (1) hide show
  1. app.py +824 -126
app.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
7
  from PIL import Image
8
  import spaces
9
 
10
- HF_TOKEN = os.getenv("HF_TOKEN") # login with hf token to access sd gated models
11
 
12
  if torch.cuda.is_available():
13
  device = "cuda"
@@ -20,25 +20,52 @@ else:
20
  MAX_SEED = np.iinfo(np.int32).max
21
 
22
  # Initialize the pipelines for each sd model
23
- sd3_medium_pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
 
 
 
 
24
  sd3_medium_pipe.enable_model_cpu_offload()
25
 
26
- sd2_1_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
 
 
 
27
  sd2_1_pipe.enable_model_cpu_offload()
28
 
29
- sdxl_pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
 
 
 
30
  sdxl_pipe.enable_model_cpu_offload()
31
 
32
- sdxl_flash_pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash", torch_dtype=torch.float16)
 
 
 
33
  sdxl_flash_pipe.enable_model_cpu_offload()
34
  # Ensure sampler uses "trailing" timesteps for sdxl flash.
35
- sdxl_flash_pipe.scheduler = DPMSolverSinglestepScheduler.from_config(sdxl_flash_pipe.scheduler.config, timestep_spacing="trailing")
 
 
36
 
37
- stable_cascade_prior_pipe = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
38
- stable_cascade_decoder_pipe = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
 
 
 
 
 
39
  stable_cascade_prior_pipe.enable_model_cpu_offload()
40
  stable_cascade_decoder_pipe.enable_model_cpu_offload()
41
 
 
 
 
 
 
 
 
42
  # Helper function to generate images for a single model
43
  @spaces.GPU(duration=80)
44
  def generate_single_image(
@@ -68,9 +95,12 @@ def generate_single_image(
68
  pipe = sdxl_flash_pipe
69
  elif model_choice == "stable cascade":
70
  pipe = stable_cascade_prior_pipe
 
 
71
  else:
72
  raise ValueError(f"Invalid model choice: {model_choice}")
73
 
 
74
  if model_choice == "stable cascade":
75
  prior_output = pipe(
76
  prompt=prompt,
@@ -90,6 +120,8 @@ def generate_single_image(
90
  num_inference_steps=decoder_num_inference_steps,
91
  guidance_scale=decoder_guidance_scale,
92
  ).images
 
 
93
  else:
94
  output = pipe(
95
  prompt=prompt,
@@ -104,21 +136,35 @@ def generate_single_image(
104
 
105
  return output
106
 
 
107
  # Define the image generation function for the Arena tab
108
  @spaces.GPU(duration=80)
109
  def generate_arena_images(
110
  prompt,
111
  negative_prompt,
 
112
  num_inference_steps_a,
113
  guidance_scale_a,
114
  num_inference_steps_b,
115
  guidance_scale_b,
116
- height,
117
- width,
 
 
 
 
 
 
 
 
 
 
118
  seed,
119
  num_images_per_prompt,
120
  model_choice_a,
121
  model_choice_b,
 
 
122
  prior_num_inference_steps_a,
123
  prior_guidance_scale_a,
124
  decoder_num_inference_steps_a,
@@ -127,48 +173,97 @@ def generate_arena_images(
127
  prior_guidance_scale_b,
128
  decoder_num_inference_steps_b,
129
  decoder_guidance_scale_b,
 
 
 
 
 
 
 
 
130
  progress=gr.Progress(track_tqdm=True),
131
  ):
132
  if seed == 0:
133
- seed = random.randint(1, 2**32 - 1)
134
 
135
  generator = torch.Generator().manual_seed(seed)
136
 
137
- # Generate images for both models
138
- images_a = generate_single_image(
139
- prompt,
140
- negative_prompt,
141
- num_inference_steps_a,
142
- guidance_scale_a,
143
- height,
144
- width,
145
- seed,
146
- num_images_per_prompt,
147
- model_choice_a,
148
- generator,
149
- prior_num_inference_steps_a,
150
- prior_guidance_scale_a,
151
- decoder_num_inference_steps_a,
152
- decoder_guidance_scale_a,
153
- )
154
- images_b = generate_single_image(
155
- prompt,
156
- negative_prompt,
157
- num_inference_steps_b,
158
- guidance_scale_b,
159
- height,
160
- width,
161
- seed,
162
- num_images_per_prompt,
163
- model_choice_b,
164
- generator,
165
- prior_num_inference_steps_b,
166
- prior_guidance_scale_b,
167
- decoder_num_inference_steps_b,
168
- decoder_guidance_scale_b,
169
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- return images_a, images_b
172
 
173
  # Define the image generation function for the Individual tab
174
  @spaces.GPU(duration=80)
@@ -189,7 +284,7 @@ def generate_individual_image(
189
  progress=gr.Progress(track_tqdm=True),
190
  ):
191
  if seed == 0:
192
- seed = random.randint(1, 2**32 - 1)
193
 
194
  generator = torch.Generator().manual_seed(seed)
195
 
@@ -213,51 +308,93 @@ def generate_individual_image(
213
  return output
214
 
215
 
216
- # Create the Gradio interface
217
  examples_arena = [
218
  [
219
  "A woman in a red dress singing on top of a building.",
220
  "deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
 
 
 
 
 
221
  25,
222
  7.5,
223
  25,
224
  7.5,
225
  1024,
226
  1024,
 
 
 
 
 
 
227
  42,
228
  2,
229
  "sd3 medium",
230
  "sdxl",
231
- 25, #prior_num_inference_steps_a
232
- 4.0, #prior_guidance_scale_a
233
- 12, #decoder_num_inference_steps_a
234
- 0.0, #decoder_guidance_scale_a
235
- 25, #prior_num_inference_steps_b
236
- 4.0, #prior_guidance_scale_b
237
- 12, #decoder_num_inference_steps_b
238
- 0.0 #decoder_guidance_scale_b
 
 
 
 
 
 
 
 
 
 
239
  ],
240
  [
241
  "An astronaut on mars in a futuristic cyborg suit.",
242
  "deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
 
 
 
 
 
243
  25,
244
  7.5,
245
  25,
246
  7.5,
247
  1024,
248
  1024,
 
 
 
 
 
 
249
  42,
250
  2,
251
  "sd3 medium",
252
  "sdxl",
253
- 25, #prior_num_inference_steps_a
254
- 4.0, #prior_guidance_scale_a
255
- 12, #decoder_num_inference_steps_a
256
- 0.0, #decoder_guidance_scale_a
257
- 25, #prior_num_inference_steps_b
258
- 4.0, #prior_guidance_scale_b
259
- 12, #decoder_num_inference_steps_b
260
- 0.0 #decoder_guidance_scale_b
 
 
 
 
 
 
 
 
 
 
261
  ],
262
  ]
263
  examples_individual = [
@@ -271,10 +408,10 @@ examples_individual = [
271
  42,
272
  2,
273
  "sdxl",
274
- 25, #prior_num_inference_steps
275
- 4.0, #prior_guidance_scale
276
- 12, #decoder_num_inference_steps
277
- 0.0 #decoder_guidance_scale
278
  ],
279
  [
280
  "An astronaut on mars in a futuristic cyborg suit.",
@@ -286,18 +423,42 @@ examples_individual = [
286
  42,
287
  2,
288
  "sdxl",
289
- 25, #prior_num_inference_steps
290
- 4.0, #prior_guidance_scale
291
- 12, #decoder_num_inference_steps
292
- 0.0 #decoder_guidance_scale
293
  ],
294
  ]
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  css = """
297
- .gradio-container{max-width: 1000px !important}
298
  h1{text-align:center}
 
 
 
 
 
 
299
  """
300
- with gr.Blocks(css=css) as demo:
 
301
  with gr.Row():
302
  with gr.Column():
303
  gr.HTML(
@@ -322,19 +483,78 @@ with gr.Blocks(css=css) as demo:
322
  info="Describe the image you want",
323
  placeholder="A cat...",
324
  )
 
 
 
 
 
325
  model_choice_a = gr.Dropdown(
326
  label="Stable Diffusion Model A",
327
- choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash", "stable cascade"],
 
 
 
 
 
 
 
328
  value="sd3 medium",
329
  )
330
  model_choice_b = gr.Dropdown(
331
  label="Stable Diffusion Model B",
332
- choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash", "stable cascade"],
 
 
 
 
 
 
 
333
  value="sdxl",
334
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  run_button = gr.Button("Run")
336
- result_1 = gr.Gallery(label="Generated Images (Model A)", elem_id="gallery_1")
337
- result_2 = gr.Gallery(label="Generated Images (Model B)", elem_id="gallery_2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  with gr.Accordion("Advanced options", open=False):
339
  negative_prompt = gr.Textbox(
340
  label="Negative Prompt",
@@ -351,7 +571,7 @@ with gr.Blocks(css=css) as demo:
351
  maximum=50,
352
  value=25,
353
  step=1,
354
- visible=True
355
  )
356
  guidance_scale_a = gr.Slider(
357
  label="Guidance Scale (Model A)",
@@ -360,7 +580,7 @@ with gr.Blocks(css=css) as demo:
360
  maximum=10.0,
361
  value=7.5,
362
  step=0.1,
363
- visible=True
364
  )
365
  prior_num_inference_steps_a = gr.Slider(
366
  label="Prior Inference Steps (Model A)",
@@ -369,7 +589,7 @@ with gr.Blocks(css=css) as demo:
369
  maximum=50,
370
  value=25,
371
  step=1,
372
- visible=False
373
  )
374
  prior_guidance_scale_a = gr.Slider(
375
  label="Prior Guidance Scale (Model A)",
@@ -378,7 +598,7 @@ with gr.Blocks(css=css) as demo:
378
  maximum=10.0,
379
  value=4.0,
380
  step=0.1,
381
- visible=False
382
  )
383
  decoder_num_inference_steps_a = gr.Slider(
384
  label="Decoder Inference Steps (Model A)",
@@ -387,7 +607,7 @@ with gr.Blocks(css=css) as demo:
387
  maximum=15,
388
  value=15,
389
  step=1,
390
- visible=False
391
  )
392
  decoder_guidance_scale_a = gr.Slider(
393
  label="Decoder Guidance Scale (Model A)",
@@ -396,7 +616,23 @@ with gr.Blocks(css=css) as demo:
396
  maximum=10.0,
397
  value=0.0,
398
  step=0.1,
399
- visible=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  )
401
  with gr.Column():
402
  num_inference_steps_b = gr.Slider(
@@ -406,7 +642,7 @@ with gr.Blocks(css=css) as demo:
406
  maximum=50,
407
  value=25,
408
  step=1,
409
- visible=True
410
  )
411
  guidance_scale_b = gr.Slider(
412
  label="Guidance Scale (Model B)",
@@ -415,7 +651,7 @@ with gr.Blocks(css=css) as demo:
415
  maximum=10.0,
416
  value=7.5,
417
  step=0.1,
418
- visible=True
419
  )
420
  prior_num_inference_steps_b = gr.Slider(
421
  label="Prior Inference Steps (Model B)",
@@ -424,7 +660,7 @@ with gr.Blocks(css=css) as demo:
424
  maximum=50,
425
  value=25,
426
  step=1,
427
- visible=False
428
  )
429
  prior_guidance_scale_b = gr.Slider(
430
  label="Prior Guidance Scale (Model B)",
@@ -433,7 +669,7 @@ with gr.Blocks(css=css) as demo:
433
  maximum=10.0,
434
  value=4.0,
435
  step=0.1,
436
- visible=False
437
  )
438
  decoder_num_inference_steps_b = gr.Slider(
439
  label="Decoder Inference Steps (Model B)",
@@ -442,7 +678,7 @@ with gr.Blocks(css=css) as demo:
442
  maximum=15,
443
  value=12,
444
  step=1,
445
- visible=False
446
  )
447
  decoder_guidance_scale_b = gr.Slider(
448
  label="Decoder Guidance Scale (Model B)",
@@ -451,25 +687,166 @@ with gr.Blocks(css=css) as demo:
451
  maximum=10.0,
452
  value=0.0,
453
  step=0.1,
454
- visible=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  )
456
- with gr.Row():
457
- width = gr.Slider(
458
- label="Width",
459
- info="Width of the Image",
460
- minimum=256,
461
- maximum=1344,
462
- step=32,
463
- value=1024,
464
- )
465
- height = gr.Slider(
466
- label="Height",
467
- info="Height of the Image",
468
- minimum=256,
469
- maximum=1344,
470
- step=32,
471
- value=1024,
472
- )
473
  with gr.Row():
474
  seed = gr.Slider(
475
  value=42,
@@ -507,6 +884,45 @@ with gr.Blocks(css=css) as demo:
507
  decoder_num_inference_steps_a: gr.update(visible=False),
508
  decoder_guidance_scale_a: gr.update(visible=False),
509
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  else:
511
  return {
512
  num_inference_steps_a: gr.update(visible=True, maximum=50, value=25),
@@ -515,6 +931,8 @@ with gr.Blocks(css=css) as demo:
515
  prior_guidance_scale_a: gr.update(visible=False),
516
  decoder_num_inference_steps_a: gr.update(visible=False),
517
  decoder_guidance_scale_a: gr.update(visible=False),
 
 
518
  }
519
 
520
  def toggle_visibility_arena_b(model_choice_b):
@@ -536,6 +954,28 @@ with gr.Blocks(css=css) as demo:
536
  decoder_num_inference_steps_b: gr.update(visible=False),
537
  decoder_guidance_scale_b: gr.update(visible=False),
538
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  else:
540
  return {
541
  num_inference_steps_b: gr.update(visible=True, maximum=50, value=25),
@@ -544,6 +984,114 @@ with gr.Blocks(css=css) as demo:
544
  prior_guidance_scale_b: gr.update(visible=False),
545
  decoder_num_inference_steps_b: gr.update(visible=False),
546
  decoder_guidance_scale_b: gr.update(visible=False),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  }
548
 
549
  model_choice_a.change(
@@ -555,8 +1103,10 @@ with gr.Blocks(css=css) as demo:
555
  prior_num_inference_steps_a,
556
  prior_guidance_scale_a,
557
  decoder_num_inference_steps_a,
558
- decoder_guidance_scale_a
559
- ]
 
 
560
  )
561
  model_choice_b.change(
562
  toggle_visibility_arena_b,
@@ -567,26 +1117,110 @@ with gr.Blocks(css=css) as demo:
567
  prior_num_inference_steps_b,
568
  prior_guidance_scale_b,
569
  decoder_num_inference_steps_b,
570
- decoder_guidance_scale_b
571
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
572
  )
573
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
 
575
  gr.Examples(
576
  examples=examples_arena,
577
  inputs=[
578
  prompt,
579
  negative_prompt,
 
580
  num_inference_steps_a,
581
  guidance_scale_a,
582
  num_inference_steps_b,
583
  guidance_scale_b,
584
- height,
585
- width,
 
 
 
 
 
 
 
 
 
 
586
  seed,
587
  num_images_per_prompt,
588
  model_choice_a,
589
  model_choice_b,
 
 
590
  prior_num_inference_steps_a,
591
  prior_guidance_scale_a,
592
  decoder_num_inference_steps_a,
@@ -595,8 +1229,16 @@ with gr.Blocks(css=css) as demo:
595
  prior_guidance_scale_b,
596
  decoder_num_inference_steps_b,
597
  decoder_guidance_scale_b,
 
 
 
 
 
 
 
 
598
  ],
599
- outputs=[result_1, result_2],
600
  fn=generate_arena_images,
601
  )
602
 
@@ -609,16 +1251,29 @@ with gr.Blocks(css=css) as demo:
609
  inputs=[
610
  prompt,
611
  negative_prompt,
 
612
  num_inference_steps_a,
613
  guidance_scale_a,
614
  num_inference_steps_b,
615
  guidance_scale_b,
616
- height,
617
- width,
 
 
 
 
 
 
 
 
 
 
618
  seed,
619
  num_images_per_prompt,
620
  model_choice_a,
621
  model_choice_b,
 
 
622
  prior_num_inference_steps_a,
623
  prior_guidance_scale_a,
624
  decoder_num_inference_steps_a,
@@ -627,8 +1282,16 @@ with gr.Blocks(css=css) as demo:
627
  prior_guidance_scale_b,
628
  decoder_num_inference_steps_b,
629
  decoder_guidance_scale_b,
 
 
 
 
 
 
 
 
630
  ],
631
- outputs=[result_1, result_2],
632
  )
633
 
634
  with gr.TabItem("Individual"):
@@ -641,11 +1304,20 @@ with gr.Blocks(css=css) as demo:
641
  )
642
  model_choice = gr.Dropdown(
643
  label="Stable Diffusion Model",
644
- choices=["sd3 medium", "sd2.1", "sdxl", "sdxl flash", "stable cascade"],
 
 
 
 
 
 
 
645
  value="sd3 medium",
646
  )
647
  run_button = gr.Button("Run")
648
- result = gr.Gallery(label="Generated AI Images", elem_id="gallery")
 
 
649
  with gr.Accordion("Advanced options", open=False):
650
  with gr.Row():
651
  negative_prompt = gr.Textbox(
@@ -662,7 +1334,7 @@ with gr.Blocks(css=css) as demo:
662
  maximum=50,
663
  value=25,
664
  step=1,
665
- visible=True
666
  )
667
  guidance_scale = gr.Slider(
668
  label="Guidance Scale",
@@ -671,7 +1343,7 @@ with gr.Blocks(css=css) as demo:
671
  maximum=10.0,
672
  value=7.5,
673
  step=0.1,
674
- visible=True
675
  )
676
  prior_num_inference_steps = gr.Slider(
677
  label="Prior Inference Steps",
@@ -680,7 +1352,7 @@ with gr.Blocks(css=css) as demo:
680
  maximum=50,
681
  value=25,
682
  step=1,
683
- visible=False
684
  )
685
  prior_guidance_scale = gr.Slider(
686
  label="Prior Guidance Scale",
@@ -689,7 +1361,7 @@ with gr.Blocks(css=css) as demo:
689
  maximum=10.0,
690
  value=4.0,
691
  step=0.1,
692
- visible=False
693
  )
694
  decoder_num_inference_steps = gr.Slider(
695
  label="Decoder Inference Steps",
@@ -698,7 +1370,7 @@ with gr.Blocks(css=css) as demo:
698
  maximum=15,
699
  value=12,
700
  step=1,
701
- visible=False
702
  )
703
  decoder_guidance_scale = gr.Slider(
704
  label="Decoder Guidance Scale",
@@ -707,7 +1379,7 @@ with gr.Blocks(css=css) as demo:
707
  maximum=10.0,
708
  value=0.0,
709
  step=0.1,
710
- visible=False
711
  )
712
  with gr.Row():
713
  width = gr.Slider(
@@ -763,6 +1435,28 @@ with gr.Blocks(css=css) as demo:
763
  decoder_num_inference_steps: gr.update(visible=False),
764
  decoder_guidance_scale: gr.update(visible=False),
765
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
766
  else:
767
  return {
768
  num_inference_steps: gr.update(visible=True, maximum=50, value=25),
@@ -771,6 +1465,8 @@ with gr.Blocks(css=css) as demo:
771
  prior_guidance_scale: gr.update(visible=False),
772
  decoder_num_inference_steps: gr.update(visible=False),
773
  decoder_guidance_scale: gr.update(visible=False),
 
 
774
  }
775
 
776
  model_choice.change(
@@ -782,8 +1478,10 @@ with gr.Blocks(css=css) as demo:
782
  prior_num_inference_steps,
783
  prior_guidance_scale,
784
  decoder_num_inference_steps,
785
- decoder_guidance_scale
786
- ]
 
 
787
  )
788
 
789
  gr.Examples(
 
7
  from PIL import Image
8
  import spaces
9
 
10
+ HF_TOKEN = os.getenv("HF_TOKEN") # login with hf token to access sd gated models
11
 
12
  if torch.cuda.is_available():
13
  device = "cuda"
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
 
22
  # Initialize the pipelines for each sd model
23
+
24
+ # sd3 medium
25
+ sd3_medium_pipe = StableDiffusion3Pipeline.from_pretrained(
26
+ "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
27
+ )
28
  sd3_medium_pipe.enable_model_cpu_offload()
29
 
30
+ # sd 2.1
31
+ sd2_1_pipe = StableDiffusionPipeline.from_pretrained(
32
+ "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
33
+ )
34
  sd2_1_pipe.enable_model_cpu_offload()
35
 
36
+ # sdxl
37
+ sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
38
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
39
+ )
40
  sdxl_pipe.enable_model_cpu_offload()
41
 
42
+ # sdxl flash
43
+ sdxl_flash_pipe = StableDiffusionXLPipeline.from_pretrained(
44
+ "sd-community/sdxl-flash", torch_dtype=torch.float16
45
+ )
46
  sdxl_flash_pipe.enable_model_cpu_offload()
47
  # Ensure sampler uses "trailing" timesteps for sdxl flash.
48
+ sdxl_flash_pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
49
+ sdxl_flash_pipe.scheduler.config, timestep_spacing="trailing"
50
+ )
51
 
52
+ # stable cascade
53
+ stable_cascade_prior_pipe = StableCascadePriorPipeline.from_pretrained(
54
+ "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
55
+ )
56
+ stable_cascade_decoder_pipe = StableCascadeDecoderPipeline.from_pretrained(
57
+ "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16
58
+ )
59
  stable_cascade_prior_pipe.enable_model_cpu_offload()
60
  stable_cascade_decoder_pipe.enable_model_cpu_offload()
61
 
62
+ # sd 1.5
63
+ sd1_5_pipe = StableDiffusionPipeline.from_pretrained(
64
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
65
+ )
66
+ sd1_5_pipe.enable_model_cpu_offload()
67
+
68
+
69
  # Helper function to generate images for a single model
70
  @spaces.GPU(duration=80)
71
  def generate_single_image(
 
95
  pipe = sdxl_flash_pipe
96
  elif model_choice == "stable cascade":
97
  pipe = stable_cascade_prior_pipe
98
+ elif model_choice == "sd1.5":
99
+ pipe = sd1_5_pipe
100
  else:
101
  raise ValueError(f"Invalid model choice: {model_choice}")
102
 
103
+ # stable cascade has 2 different type of pipelines
104
  if model_choice == "stable cascade":
105
  prior_output = pipe(
106
  prompt=prompt,
 
120
  num_inference_steps=decoder_num_inference_steps,
121
  guidance_scale=decoder_guidance_scale,
122
  ).images
123
+
124
+ # the rest of the models have similar pipeline
125
  else:
126
  output = pipe(
127
  prompt=prompt,
 
136
 
137
  return output
138
 
139
+
140
  # Define the image generation function for the Arena tab
141
  @spaces.GPU(duration=80)
142
  def generate_arena_images(
143
  prompt,
144
  negative_prompt,
145
+ num_models_to_compare,
146
  num_inference_steps_a,
147
  guidance_scale_a,
148
  num_inference_steps_b,
149
  guidance_scale_b,
150
+ num_inference_steps_c,
151
+ guidance_scale_c,
152
+ num_inference_steps_d,
153
+ guidance_scale_d,
154
+ height_a,
155
+ width_a,
156
+ height_b,
157
+ width_b,
158
+ height_c,
159
+ width_c,
160
+ height_d,
161
+ width_d,
162
  seed,
163
  num_images_per_prompt,
164
  model_choice_a,
165
  model_choice_b,
166
+ model_choice_c,
167
+ model_choice_d,
168
  prior_num_inference_steps_a,
169
  prior_guidance_scale_a,
170
  decoder_num_inference_steps_a,
 
173
  prior_guidance_scale_b,
174
  decoder_num_inference_steps_b,
175
  decoder_guidance_scale_b,
176
+ prior_num_inference_steps_c,
177
+ prior_guidance_scale_c,
178
+ decoder_num_inference_steps_c,
179
+ decoder_guidance_scale_c,
180
+ prior_num_inference_steps_d,
181
+ prior_guidance_scale_d,
182
+ decoder_num_inference_steps_d,
183
+ decoder_guidance_scale_d,
184
  progress=gr.Progress(track_tqdm=True),
185
  ):
186
  if seed == 0:
187
+ seed = random.randint(1, MAX_SEED)
188
 
189
  generator = torch.Generator().manual_seed(seed)
190
 
191
+ # Generate images for selected models
192
+ images = []
193
+ if num_models_to_compare >= 2:
194
+ images_a = generate_single_image(
195
+ prompt,
196
+ negative_prompt,
197
+ num_inference_steps_a,
198
+ guidance_scale_a,
199
+ height_a,
200
+ width_a,
201
+ seed,
202
+ num_images_per_prompt,
203
+ model_choice_a,
204
+ generator,
205
+ prior_num_inference_steps_a,
206
+ prior_guidance_scale_a,
207
+ decoder_num_inference_steps_a,
208
+ decoder_guidance_scale_a,
209
+ )
210
+ images.append(images_a)
211
+ images_b = generate_single_image(
212
+ prompt,
213
+ negative_prompt,
214
+ num_inference_steps_b,
215
+ guidance_scale_b,
216
+ height_b,
217
+ width_b,
218
+ seed,
219
+ num_images_per_prompt,
220
+ model_choice_b,
221
+ generator,
222
+ prior_num_inference_steps_b,
223
+ prior_guidance_scale_b,
224
+ decoder_num_inference_steps_b,
225
+ decoder_guidance_scale_b,
226
+ )
227
+ images.append(images_b)
228
+ if num_models_to_compare >= 3:
229
+ images_c = generate_single_image(
230
+ prompt,
231
+ negative_prompt,
232
+ num_inference_steps_c,
233
+ guidance_scale_c,
234
+ height_c,
235
+ width_c,
236
+ seed,
237
+ num_images_per_prompt,
238
+ model_choice_c,
239
+ generator,
240
+ prior_num_inference_steps_c,
241
+ prior_guidance_scale_c,
242
+ decoder_num_inference_steps_c,
243
+ decoder_guidance_scale_c,
244
+ )
245
+ images.append(images_c)
246
+ if num_models_to_compare >= 4:
247
+ images_d = generate_single_image(
248
+ prompt,
249
+ negative_prompt,
250
+ num_inference_steps_d,
251
+ guidance_scale_d,
252
+ height_d,
253
+ width_d,
254
+ seed,
255
+ num_images_per_prompt,
256
+ model_choice_d,
257
+ generator,
258
+ prior_num_inference_steps_d,
259
+ prior_guidance_scale_d,
260
+ decoder_num_inference_steps_d,
261
+ decoder_guidance_scale_d,
262
+ )
263
+ images.append(images_d)
264
+
265
+ return images
266
 
 
267
 
268
  # Define the image generation function for the Individual tab
269
  @spaces.GPU(duration=80)
 
284
  progress=gr.Progress(track_tqdm=True),
285
  ):
286
  if seed == 0:
287
+ seed = random.randint(1, MAX_SEED)
288
 
289
  generator = torch.Generator().manual_seed(seed)
290
 
 
308
  return output
309
 
310
 
311
+ # Gradio interface
312
  examples_arena = [
313
  [
314
  "A woman in a red dress singing on top of a building.",
315
  "deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
316
+ 2, # num_models_to_compare
317
+ 25,
318
+ 7.5,
319
+ 25,
320
+ 7.5,
321
  25,
322
  7.5,
323
  25,
324
  7.5,
325
  1024,
326
  1024,
327
+ 1024,
328
+ 1024,
329
+ 1024,
330
+ 1024,
331
+ 1024,
332
+ 1024,
333
  42,
334
  2,
335
  "sd3 medium",
336
  "sdxl",
337
+ "sd3 medium",
338
+ "sdxl",
339
+ 25, # prior_num_inference_steps_a
340
+ 4.0, # prior_guidance_scale_a
341
+ 12, # decoder_num_inference_steps_a
342
+ 0.0, # decoder_guidance_scale_a
343
+ 25, # prior_num_inference_steps_b
344
+ 4.0, # prior_guidance_scale_b
345
+ 12, # decoder_num_inference_steps_b
346
+ 0.0, # decoder_guidance_scale_b
347
+ 25, # prior_num_inference_steps_c
348
+ 4.0, # prior_guidance_scale_c
349
+ 12, # decoder_num_inference_steps_c
350
+ 0.0, # decoder_guidance_scale_c
351
+ 25, # prior_num_inference_steps_d
352
+ 4.0, # prior_guidance_scale_d
353
+ 12, # decoder_num_inference_steps_d
354
+ 0.0, # decoder_guidance_scale_d
355
  ],
356
  [
357
  "An astronaut on mars in a futuristic cyborg suit.",
358
  "deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
359
+ 2, # num_models_to_compare
360
+ 25,
361
+ 7.5,
362
+ 25,
363
+ 7.5,
364
  25,
365
  7.5,
366
  25,
367
  7.5,
368
  1024,
369
  1024,
370
+ 1024,
371
+ 1024,
372
+ 1024,
373
+ 1024,
374
+ 1024,
375
+ 1024,
376
  42,
377
  2,
378
  "sd3 medium",
379
  "sdxl",
380
+ "sd3 medium",
381
+ "sdxl",
382
+ 25, # prior_num_inference_steps_a
383
+ 4.0, # prior_guidance_scale_a
384
+ 12, # decoder_num_inference_steps_a
385
+ 0.0, # decoder_guidance_scale_a
386
+ 25, # prior_num_inference_steps_b
387
+ 4.0, # prior_guidance_scale_b
388
+ 12, # decoder_num_inference_steps_b
389
+ 0.0, # decoder_guidance_scale_b
390
+ 25, # prior_num_inference_steps_c
391
+ 4.0, # prior_guidance_scale_c
392
+ 12, # decoder_num_inference_steps_c
393
+ 0.0, # decoder_guidance_scale_c
394
+ 25, # prior_num_inference_steps_d
395
+ 4.0, # prior_guidance_scale_d
396
+ 12, # decoder_num_inference_steps_d
397
+ 0.0, # decoder_guidance_scale_d
398
  ],
399
  ]
400
  examples_individual = [
 
408
  42,
409
  2,
410
  "sdxl",
411
+ 25, # prior_num_inference_steps
412
+ 4.0, # prior_guidance_scale
413
+ 12, # decoder_num_inference_steps
414
+ 0.0, # decoder_guidance_scale
415
  ],
416
  [
417
  "An astronaut on mars in a futuristic cyborg suit.",
 
423
  42,
424
  2,
425
  "sdxl",
426
+ 25, # prior_num_inference_steps
427
+ 4.0, # prior_guidance_scale
428
+ 12, # decoder_num_inference_steps
429
+ 0.0, # decoder_guidance_scale
430
  ],
431
  ]
432
 
433
+ theme = gr.themes.Soft(
434
+ primary_hue="emerald",
435
+ secondary_hue="blue",
436
+ ).set(
437
+ border_color_primary='*neutral_300',
438
+ block_border_width='1px',
439
+ block_border_width_dark='1px',
440
+ block_title_border_color='*secondary_100',
441
+ block_title_border_color_dark='*secondary_200',
442
+ input_background_fill_focus='*secondary_300',
443
+ input_border_color_focus='*secondary_500',
444
+ input_border_width='1px',
445
+ input_border_width_dark='1px',
446
+ slider_color='*secondary_500',
447
+ slider_color_dark='*secondary_600'
448
+ )
449
+
450
  css = """
451
+ .gradio-container{max-width: 1400px !important}
452
  h1{text-align:center}
453
+ .extra-option {
454
+ display: none;
455
+ }
456
+ .extra-option.visible {
457
+ display: block;
458
+ }
459
  """
460
+
461
+ with gr.Blocks(theme=theme, css=css) as demo:
462
  with gr.Row():
463
  with gr.Column():
464
  gr.HTML(
 
483
  info="Describe the image you want",
484
  placeholder="A cat...",
485
  )
486
+ num_models_to_compare = gr.Dropdown(
487
+ label="How many models to compare",
488
+ choices=[2, 3, 4],
489
+ value=2,
490
+ )
491
  model_choice_a = gr.Dropdown(
492
  label="Stable Diffusion Model A",
493
+ choices=[
494
+ "sd3 medium",
495
+ "sd2.1",
496
+ "sdxl",
497
+ "sdxl flash",
498
+ "stable cascade",
499
+ "sd1.5",
500
+ ],
501
  value="sd3 medium",
502
  )
503
  model_choice_b = gr.Dropdown(
504
  label="Stable Diffusion Model B",
505
+ choices=[
506
+ "sd3 medium",
507
+ "sd2.1",
508
+ "sdxl",
509
+ "sdxl flash",
510
+ "stable cascade",
511
+ "sd1.5",
512
+ ],
513
  value="sdxl",
514
  )
515
+ model_choice_c = gr.Dropdown(
516
+ label="Stable Diffusion Model C",
517
+ choices=[
518
+ "sd3 medium",
519
+ "sd2.1",
520
+ "sdxl",
521
+ "sdxl flash",
522
+ "stable cascade",
523
+ "sd1.5",
524
+ ],
525
+ value="sdxl flash",
526
+ visible=False,
527
+ )
528
+ model_choice_d = gr.Dropdown(
529
+ label="Stable Diffusion Model D",
530
+ choices=[
531
+ "sd3 medium",
532
+ "sd2.1",
533
+ "sdxl",
534
+ "sdxl flash",
535
+ "stable cascade",
536
+ "sd1.5",
537
+ ],
538
+ value="sd2.1",
539
+ visible=False,
540
+ )
541
  run_button = gr.Button("Run")
542
+ result_1 = gr.Gallery(
543
+ label="Generated Images (Model A)", elem_id="gallery_1"
544
+ )
545
+ result_2 = gr.Gallery(
546
+ label="Generated Images (Model B)", elem_id="gallery_2"
547
+ )
548
+ result_3 = gr.Gallery(
549
+ label="Generated Images (Model C)",
550
+ elem_id="gallery_3",
551
+ visible=False,
552
+ )
553
+ result_4 = gr.Gallery(
554
+ label="Generated Images (Model D)",
555
+ elem_id="gallery_4",
556
+ visible=False,
557
+ )
558
  with gr.Accordion("Advanced options", open=False):
559
  negative_prompt = gr.Textbox(
560
  label="Negative Prompt",
 
571
  maximum=50,
572
  value=25,
573
  step=1,
574
+ visible=True,
575
  )
576
  guidance_scale_a = gr.Slider(
577
  label="Guidance Scale (Model A)",
 
580
  maximum=10.0,
581
  value=7.5,
582
  step=0.1,
583
+ visible=True,
584
  )
585
  prior_num_inference_steps_a = gr.Slider(
586
  label="Prior Inference Steps (Model A)",
 
589
  maximum=50,
590
  value=25,
591
  step=1,
592
+ visible=False,
593
  )
594
  prior_guidance_scale_a = gr.Slider(
595
  label="Prior Guidance Scale (Model A)",
 
598
  maximum=10.0,
599
  value=4.0,
600
  step=0.1,
601
+ visible=False,
602
  )
603
  decoder_num_inference_steps_a = gr.Slider(
604
  label="Decoder Inference Steps (Model A)",
 
607
  maximum=15,
608
  value=15,
609
  step=1,
610
+ visible=False,
611
  )
612
  decoder_guidance_scale_a = gr.Slider(
613
  label="Decoder Guidance Scale (Model A)",
 
616
  maximum=10.0,
617
  value=0.0,
618
  step=0.1,
619
+ visible=False,
620
+ )
621
+ width_a = gr.Slider(
622
+ label="Width (Model A)",
623
+ info="Width of the Image",
624
+ minimum=256,
625
+ maximum=1344,
626
+ step=32,
627
+ value=1024,
628
+ )
629
+ height_a = gr.Slider(
630
+ label="Height (Model A)",
631
+ info="Height of the Image",
632
+ minimum=256,
633
+ maximum=1344,
634
+ step=32,
635
+ value=1024,
636
  )
637
  with gr.Column():
638
  num_inference_steps_b = gr.Slider(
 
642
  maximum=50,
643
  value=25,
644
  step=1,
645
+ visible=True,
646
  )
647
  guidance_scale_b = gr.Slider(
648
  label="Guidance Scale (Model B)",
 
651
  maximum=10.0,
652
  value=7.5,
653
  step=0.1,
654
+ visible=True,
655
  )
656
  prior_num_inference_steps_b = gr.Slider(
657
  label="Prior Inference Steps (Model B)",
 
660
  maximum=50,
661
  value=25,
662
  step=1,
663
+ visible=False,
664
  )
665
  prior_guidance_scale_b = gr.Slider(
666
  label="Prior Guidance Scale (Model B)",
 
669
  maximum=10.0,
670
  value=4.0,
671
  step=0.1,
672
+ visible=False,
673
  )
674
  decoder_num_inference_steps_b = gr.Slider(
675
  label="Decoder Inference Steps (Model B)",
 
678
  maximum=15,
679
  value=12,
680
  step=1,
681
+ visible=False,
682
  )
683
  decoder_guidance_scale_b = gr.Slider(
684
  label="Decoder Guidance Scale (Model B)",
 
687
  maximum=10.0,
688
  value=0.0,
689
  step=0.1,
690
+ visible=False,
691
+ )
692
+ width_b = gr.Slider(
693
+ label="Width (Model B)",
694
+ info="Width of the Image",
695
+ minimum=256,
696
+ maximum=1344,
697
+ step=32,
698
+ value=1024,
699
+ )
700
+ height_b = gr.Slider(
701
+ label="Height (Model B)",
702
+ info="Height of the Image",
703
+ minimum=256,
704
+ maximum=1344,
705
+ step=32,
706
+ value=1024,
707
+ )
708
+ with gr.Column(visible=False) as model_c_options:
709
+ num_inference_steps_c = gr.Slider(
710
+ label="Inference Steps (Model C)",
711
+ info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
712
+ minimum=1,
713
+ maximum=50,
714
+ value=25,
715
+ step=1,
716
+ visible=True,
717
+ )
718
+ guidance_scale_c = gr.Slider(
719
+ label="Guidance Scale (Model C)",
720
+ info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
721
+ minimum=0.0,
722
+ maximum=10.0,
723
+ value=7.5,
724
+ step=0.1,
725
+ visible=True,
726
+ )
727
+ prior_num_inference_steps_c = gr.Slider(
728
+ label="Prior Inference Steps (Model C)",
729
+ info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
730
+ minimum=1,
731
+ maximum=50,
732
+ value=25,
733
+ step=1,
734
+ visible=False,
735
+ )
736
+ prior_guidance_scale_c = gr.Slider(
737
+ label="Prior Guidance Scale (Model C)",
738
+ info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
739
+ minimum=0.0,
740
+ maximum=10.0,
741
+ value=4.0,
742
+ step=0.1,
743
+ visible=False,
744
+ )
745
+ decoder_num_inference_steps_c = gr.Slider(
746
+ label="Decoder Inference Steps (Model C)",
747
+ info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
748
+ minimum=1,
749
+ maximum=15,
750
+ value=12,
751
+ step=1,
752
+ visible=False,
753
+ )
754
+ decoder_guidance_scale_c = gr.Slider(
755
+ label="Decoder Guidance Scale (Model C)",
756
+ info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
757
+ minimum=0.0,
758
+ maximum=10.0,
759
+ value=0.0,
760
+ step=0.1,
761
+ visible=False,
762
+ )
763
+ width_c = gr.Slider(
764
+ label="Width (Model C)",
765
+ info="Width of the Image",
766
+ minimum=256,
767
+ maximum=1344,
768
+ step=32,
769
+ value=1024,
770
+ )
771
+ height_c = gr.Slider(
772
+ label="Height (Model C)",
773
+ info="Height of the Image",
774
+ minimum=256,
775
+ maximum=1344,
776
+ step=32,
777
+ value=1024,
778
+ )
779
+ with gr.Column(visible=False) as model_d_options:
780
+ num_inference_steps_d = gr.Slider(
781
+ label="Inference Steps (Model D)",
782
+ info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
783
+ minimum=1,
784
+ maximum=50,
785
+ value=25,
786
+ step=1,
787
+ visible=True,
788
+ )
789
+ guidance_scale_d = gr.Slider(
790
+ label="Guidance Scale (Model D)",
791
+ info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
792
+ minimum=0.0,
793
+ maximum=10.0,
794
+ value=7.5,
795
+ step=0.1,
796
+ visible=True,
797
+ )
798
+ prior_num_inference_steps_d = gr.Slider(
799
+ label="Prior Inference Steps (Model D)",
800
+ info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
801
+ minimum=1,
802
+ maximum=50,
803
+ value=25,
804
+ step=1,
805
+ visible=False,
806
+ )
807
+ prior_guidance_scale_d = gr.Slider(
808
+ label="Prior Guidance Scale (Model D)",
809
+ info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
810
+ minimum=0.0,
811
+ maximum=10.0,
812
+ value=4.0,
813
+ step=0.1,
814
+ visible=False,
815
+ )
816
+ decoder_num_inference_steps_d = gr.Slider(
817
+ label="Decoder Inference Steps (Model D)",
818
+ info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
819
+ minimum=1,
820
+ maximum=15,
821
+ value=12,
822
+ step=1,
823
+ visible=False,
824
+ )
825
+ decoder_guidance_scale_d = gr.Slider(
826
+ label="Decoder Guidance Scale (Model D)",
827
+ info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
828
+ minimum=0.0,
829
+ maximum=10.0,
830
+ value=0.0,
831
+ step=0.1,
832
+ visible=False,
833
+ )
834
+ width_d = gr.Slider(
835
+ label="Width (Model D)",
836
+ info="Width of the Image",
837
+ minimum=256,
838
+ maximum=1344,
839
+ step=32,
840
+ value=1024,
841
+ )
842
+ height_d = gr.Slider(
843
+ label="Height (Model D)",
844
+ info="Height of the Image",
845
+ minimum=256,
846
+ maximum=1344,
847
+ step=32,
848
+ value=1024,
849
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
850
  with gr.Row():
851
  seed = gr.Slider(
852
  value=42,
 
884
  decoder_num_inference_steps_a: gr.update(visible=False),
885
  decoder_guidance_scale_a: gr.update(visible=False),
886
  }
887
+ elif model_choice_a == "sd1.5":
888
+ return {
889
+ num_inference_steps_a: gr.update(visible=True, maximum=50, value=25),
890
+ guidance_scale_a: gr.update(visible=True, maximum=10.0, value=7.5),
891
+ prior_guidance_scale_a: gr.update(visible=True),
892
+ decoder_num_inference_steps_a: gr.update(visible=True),
893
+ decoder_guidance_scale_a: gr.update(visible=True),
894
+ }
895
+ elif model_choice_a == "sdxl flash":
896
+ return {
897
+ num_inference_steps_a: gr.update(visible=True, maximum=15, value=8),
898
+ guidance_scale_a: gr.update(visible=True, maximum=6.0, value=3.5),
899
+ prior_num_inference_steps_a: gr.update(visible=False),
900
+ prior_guidance_scale_a: gr.update(visible=False),
901
+ decoder_num_inference_steps_a: gr.update(visible=False),
902
+ decoder_guidance_scale_a: gr.update(visible=False),
903
+ }
904
+ elif model_choice_a == "sd1.5":
905
+ return {
906
+ num_inference_steps_a: gr.update(visible=True, maximum=50, value=25),
907
+ guidance_scale_a: gr.update(visible=True, maximum=10.0, value=7.5),
908
+ prior_num_inference_steps_a: gr.update(visible=False),
909
+ prior_guidance_scale_a: gr.update(visible=False),
910
+ decoder_num_inference_steps_a: gr.update(visible=False),
911
+ decoder_guidance_scale_a: gr.update(visible=False),
912
+ width_a: gr.update(value=512, maximum=768),
913
+ height_a: gr.update(value=512, maximum=768),
914
+ }
915
+ elif model_choice_a == "sd2.1":
916
+ return {
917
+ num_inference_steps_a: gr.update(visible=True, maximum=50, value=25),
918
+ guidance_scale_a: gr.update(visible=True, maximum=10.0, value=7.5),
919
+ prior_num_inference_steps_a: gr.update(visible=False),
920
+ prior_guidance_scale_a: gr.update(visible=False),
921
+ decoder_num_inference_steps_a: gr.update(visible=False),
922
+ decoder_guidance_scale_a: gr.update(visible=False),
923
+ width_a: gr.update(value=768, maximum=1024),
924
+ height_a: gr.update(value=768, maximum=1024),
925
+ }
926
  else:
927
  return {
928
  num_inference_steps_a: gr.update(visible=True, maximum=50, value=25),
 
931
  prior_guidance_scale_a: gr.update(visible=False),
932
  decoder_num_inference_steps_a: gr.update(visible=False),
933
  decoder_guidance_scale_a: gr.update(visible=False),
934
+ width_a: gr.update(maximum=1344),
935
+ height_a: gr.update(maximum=1344),
936
  }
937
 
938
  def toggle_visibility_arena_b(model_choice_b):
 
954
  decoder_num_inference_steps_b: gr.update(visible=False),
955
  decoder_guidance_scale_b: gr.update(visible=False),
956
  }
957
+ elif model_choice_b == "sd1.5":
958
+ return {
959
+ num_inference_steps_b: gr.update(visible=True, maximum=50, value=25),
960
+ guidance_scale_b: gr.update(visible=True, maximum=10.0, value=7.5),
961
+ prior_num_inference_steps_b: gr.update(visible=False),
962
+ prior_guidance_scale_b: gr.update(visible=False),
963
+ decoder_num_inference_steps_b: gr.update(visible=False),
964
+ decoder_guidance_scale_b: gr.update(visible=False),
965
+ width_b: gr.update(value=512, maximum=768),
966
+ height_b: gr.update(value=512, maximum=768),
967
+ }
968
+ elif model_choice_b == "sd2.1":
969
+ return {
970
+ num_inference_steps_b: gr.update(visible=True, maximum=50, value=25),
971
+ guidance_scale_b: gr.update(visible=True, maximum=10.0, value=7.5),
972
+ prior_num_inference_steps_b: gr.update(visible=False),
973
+ prior_guidance_scale_b: gr.update(visible=False),
974
+ decoder_num_inference_steps_b: gr.update(visible=False),
975
+ decoder_guidance_scale_b: gr.update(visible=False),
976
+ width_b: gr.update(value=768, maximum=1024),
977
+ height_b: gr.update(value=768, maximum=1024),
978
+ }
979
  else:
980
  return {
981
  num_inference_steps_b: gr.update(visible=True, maximum=50, value=25),
 
984
  prior_guidance_scale_b: gr.update(visible=False),
985
  decoder_num_inference_steps_b: gr.update(visible=False),
986
  decoder_guidance_scale_b: gr.update(visible=False),
987
+ width_b: gr.update(maximum=1344),
988
+ height_b: gr.update(maximum=1344),
989
+ }
990
+
991
+ def toggle_visibility_arena_c(model_choice_c):
992
+ if model_choice_c == "stable cascade":
993
+ return {
994
+ num_inference_steps_c: gr.update(visible=False),
995
+ guidance_scale_c: gr.update(visible=False),
996
+ prior_num_inference_steps_c: gr.update(visible=True),
997
+ prior_guidance_scale_c: gr.update(visible=True),
998
+ decoder_num_inference_steps_c: gr.update(visible=True),
999
+ decoder_guidance_scale_c: gr.update(visible=True),
1000
+ }
1001
+ elif model_choice_c == "sdxl flash":
1002
+ return {
1003
+ num_inference_steps_c: gr.update(visible=True, maximum=15, value=8),
1004
+ guidance_scale_c: gr.update(visible=True, maximum=6.0, value=3.5),
1005
+ prior_num_inference_steps_c: gr.update(visible=False),
1006
+ prior_guidance_scale_c: gr.update(visible=False),
1007
+ decoder_num_inference_steps_c: gr.update(visible=False),
1008
+ decoder_guidance_scale_c: gr.update(visible=False),
1009
+ }
1010
+ elif model_choice_c == "sd1.5":
1011
+ return {
1012
+ num_inference_steps_c: gr.update(visible=True, maximum=50, value=25),
1013
+ guidance_scale_c: gr.update(visible=True, maximum=10.0, value=7.5),
1014
+ prior_num_inference_steps_c: gr.update(visible=False),
1015
+ prior_guidance_scale_c: gr.update(visible=False),
1016
+ decoder_num_inference_steps_c: gr.update(visible=False),
1017
+ decoder_guidance_scale_c: gr.update(visible=False),
1018
+ width_c: gr.update(value=512, maximum=768),
1019
+ height_c: gr.update(value=512, maximum=768),
1020
+ }
1021
+ elif model_choice_c == "sd2.1":
1022
+ return {
1023
+ num_inference_steps_c: gr.update(visible=True, maximum=50, value=25),
1024
+ guidance_scale_c: gr.update(visible=True, maximum=10.0, value=7.5),
1025
+ prior_num_inference_steps_c: gr.update(visible=False),
1026
+ prior_guidance_scale_c: gr.update(visible=False),
1027
+ decoder_num_inference_steps_c: gr.update(visible=False),
1028
+ decoder_guidance_scale_c: gr.update(visible=False),
1029
+ width_c: gr.update(value=768, maximum=1024),
1030
+ height_c: gr.update(value=768, maximum=1024),
1031
+ }
1032
+ else:
1033
+ return {
1034
+ num_inference_steps_c: gr.update(visible=True, maximum=50, value=25),
1035
+ guidance_scale_c: gr.update(visible=True, maximum=10.0, value=7.5),
1036
+ prior_num_inference_steps_c: gr.update(visible=False),
1037
+ prior_guidance_scale_c: gr.update(visible=False),
1038
+ decoder_num_inference_steps_c: gr.update(visible=False),
1039
+ decoder_guidance_scale_c: gr.update(visible=False),
1040
+ width_c: gr.update(maximum=1344),
1041
+ height_c: gr.update(maximum=1344),
1042
+ }
1043
+
1044
+ def toggle_visibility_arena_d(model_choice_d):
1045
+ if model_choice_d == "stable cascade":
1046
+ return {
1047
+ num_inference_steps_d: gr.update(visible=False),
1048
+ guidance_scale_d: gr.update(visible=False),
1049
+ prior_num_inference_steps_d: gr.update(visible=True),
1050
+ prior_guidance_scale_d: gr.update(visible=True),
1051
+ decoder_num_inference_steps_d: gr.update(visible=True),
1052
+ decoder_guidance_scale_d: gr.update(visible=True),
1053
+ }
1054
+ elif model_choice_d == "sdxl flash":
1055
+ return {
1056
+ num_inference_steps_d: gr.update(visible=True, maximum=15, value=8),
1057
+ guidance_scale_d: gr.update(visible=True, maximum=6.0, value=3.5),
1058
+ prior_num_inference_steps_d: gr.update(visible=False),
1059
+ prior_guidance_scale_d: gr.update(visible=False),
1060
+ decoder_num_inference_steps_d: gr.update(visible=False),
1061
+ decoder_guidance_scale_d: gr.update(visible=False),
1062
+ }
1063
+ elif model_choice_d == "sd1.5":
1064
+ return {
1065
+ num_inference_steps_d: gr.update(visible=True, maximum=50, value=25),
1066
+ guidance_scale_d: gr.update(visible=True, maximum=10.0, value=7.5),
1067
+ prior_num_inference_steps_d: gr.update(visible=False),
1068
+ prior_guidance_scale_d: gr.update(visible=False),
1069
+ decoder_num_inference_steps_d: gr.update(visible=False),
1070
+ decoder_guidance_scale_d: gr.update(visible=False),
1071
+ width_d: gr.update(value=512, maximum=768),
1072
+ height_d: gr.update(value=512, maximum=768),
1073
+ }
1074
+ elif model_choice_d == "sd2.1":
1075
+ return {
1076
+ num_inference_steps_d: gr.update(visible=True, maximum=50, value=25),
1077
+ guidance_scale_d: gr.update(visible=True, maximum=10.0, value=7.5),
1078
+ prior_num_inference_steps_d: gr.update(visible=False),
1079
+ prior_guidance_scale_d: gr.update(visible=False),
1080
+ decoder_num_inference_steps_d: gr.update(visible=False),
1081
+ decoder_guidance_scale_d: gr.update(visible=False),
1082
+ width_d: gr.update(value=768, maximum=1024),
1083
+ height_d: gr.update(value=768, maximum=1024),
1084
+ }
1085
+ else:
1086
+ return {
1087
+ num_inference_steps_d: gr.update(visible=True, maximum=50, value=25),
1088
+ guidance_scale_d: gr.update(visible=True, maximum=10.0, value=7.5),
1089
+ prior_num_inference_steps_d: gr.update(visible=False),
1090
+ prior_guidance_scale_d: gr.update(visible=False),
1091
+ decoder_num_inference_steps_d: gr.update(visible=False),
1092
+ decoder_guidance_scale_d: gr.update(visible=False),
1093
+ width_d: gr.update(maximum=1344),
1094
+ height_d: gr.update(maximum=1344),
1095
  }
1096
 
1097
  model_choice_a.change(
 
1103
  prior_num_inference_steps_a,
1104
  prior_guidance_scale_a,
1105
  decoder_num_inference_steps_a,
1106
+ decoder_guidance_scale_a,
1107
+ width_a,
1108
+ height_a,
1109
+ ],
1110
  )
1111
  model_choice_b.change(
1112
  toggle_visibility_arena_b,
 
1117
  prior_num_inference_steps_b,
1118
  prior_guidance_scale_b,
1119
  decoder_num_inference_steps_b,
1120
+ decoder_guidance_scale_b,
1121
+ width_b,
1122
+ height_b,
1123
+ ],
1124
+ )
1125
+ model_choice_c.change(
1126
+ toggle_visibility_arena_c,
1127
+ inputs=[model_choice_c],
1128
+ outputs=[
1129
+ num_inference_steps_c,
1130
+ guidance_scale_c,
1131
+ prior_num_inference_steps_c,
1132
+ prior_guidance_scale_c,
1133
+ decoder_num_inference_steps_c,
1134
+ decoder_guidance_scale_c,
1135
+ width_c,
1136
+ height_c,
1137
+ ],
1138
+ )
1139
+ model_choice_d.change(
1140
+ toggle_visibility_arena_d,
1141
+ inputs=[model_choice_d],
1142
+ outputs=[
1143
+ num_inference_steps_d,
1144
+ guidance_scale_d,
1145
+ prior_num_inference_steps_d,
1146
+ prior_guidance_scale_d,
1147
+ decoder_num_inference_steps_d,
1148
+ decoder_guidance_scale_d,
1149
+ width_d,
1150
+ height_d,
1151
+ ],
1152
  )
1153
 
1154
+ def toggle_model_options(num_models):
1155
+ if num_models == 2:
1156
+ return {
1157
+ model_choice_c: gr.update(visible=False),
1158
+ model_d_options: gr.update(visible=False),
1159
+ model_choice_d: gr.update(visible=False),
1160
+ result_3: gr.update(visible=False),
1161
+ result_4: gr.update(visible=False),
1162
+ model_c_options: gr.update(visible=False),
1163
+ }
1164
+ elif num_models == 3:
1165
+ return {
1166
+ model_choice_c: gr.update(visible=True),
1167
+ model_d_options: gr.update(visible=False),
1168
+ model_choice_d: gr.update(visible=False),
1169
+ result_3: gr.update(visible=True),
1170
+ result_4: gr.update(visible=False),
1171
+ model_c_options: gr.update(visible=True),
1172
+ }
1173
+ elif num_models == 4:
1174
+ return {
1175
+ model_choice_c: gr.update(visible=True),
1176
+ model_d_options: gr.update(visible=True),
1177
+ model_choice_d: gr.update(visible=True),
1178
+ result_3: gr.update(visible=True),
1179
+ result_4: gr.update(visible=True),
1180
+ model_c_options: gr.update(visible=True),
1181
+ }
1182
+
1183
+ num_models_to_compare.change(
1184
+ toggle_model_options,
1185
+ inputs=[num_models_to_compare],
1186
+ outputs=[
1187
+ model_choice_c,
1188
+ model_d_options,
1189
+ model_choice_d,
1190
+ result_3,
1191
+ result_4,
1192
+ model_c_options,
1193
+ ],
1194
+ )
1195
 
1196
  gr.Examples(
1197
  examples=examples_arena,
1198
  inputs=[
1199
  prompt,
1200
  negative_prompt,
1201
+ num_models_to_compare,
1202
  num_inference_steps_a,
1203
  guidance_scale_a,
1204
  num_inference_steps_b,
1205
  guidance_scale_b,
1206
+ num_inference_steps_c,
1207
+ guidance_scale_c,
1208
+ num_inference_steps_d,
1209
+ guidance_scale_d,
1210
+ height_a,
1211
+ width_a,
1212
+ height_b,
1213
+ width_b,
1214
+ height_c,
1215
+ width_c,
1216
+ height_d,
1217
+ width_d,
1218
  seed,
1219
  num_images_per_prompt,
1220
  model_choice_a,
1221
  model_choice_b,
1222
+ model_choice_c,
1223
+ model_choice_d,
1224
  prior_num_inference_steps_a,
1225
  prior_guidance_scale_a,
1226
  decoder_num_inference_steps_a,
 
1229
  prior_guidance_scale_b,
1230
  decoder_num_inference_steps_b,
1231
  decoder_guidance_scale_b,
1232
+ prior_num_inference_steps_c,
1233
+ prior_guidance_scale_c,
1234
+ decoder_num_inference_steps_c,
1235
+ decoder_guidance_scale_c,
1236
+ prior_num_inference_steps_d,
1237
+ prior_guidance_scale_d,
1238
+ decoder_num_inference_steps_d,
1239
+ decoder_guidance_scale_d,
1240
  ],
1241
+ outputs=[result_1, result_2, result_3, result_4],
1242
  fn=generate_arena_images,
1243
  )
1244
 
 
1251
  inputs=[
1252
  prompt,
1253
  negative_prompt,
1254
+ num_models_to_compare,
1255
  num_inference_steps_a,
1256
  guidance_scale_a,
1257
  num_inference_steps_b,
1258
  guidance_scale_b,
1259
+ num_inference_steps_c,
1260
+ guidance_scale_c,
1261
+ num_inference_steps_d,
1262
+ guidance_scale_d,
1263
+ height_a,
1264
+ width_a,
1265
+ height_b,
1266
+ width_b,
1267
+ height_c,
1268
+ width_c,
1269
+ height_d,
1270
+ width_d,
1271
  seed,
1272
  num_images_per_prompt,
1273
  model_choice_a,
1274
  model_choice_b,
1275
+ model_choice_c,
1276
+ model_choice_d,
1277
  prior_num_inference_steps_a,
1278
  prior_guidance_scale_a,
1279
  decoder_num_inference_steps_a,
 
1282
  prior_guidance_scale_b,
1283
  decoder_num_inference_steps_b,
1284
  decoder_guidance_scale_b,
1285
+ prior_num_inference_steps_c,
1286
+ prior_guidance_scale_c,
1287
+ decoder_num_inference_steps_c,
1288
+ decoder_guidance_scale_c,
1289
+ prior_num_inference_steps_d,
1290
+ prior_guidance_scale_d,
1291
+ decoder_num_inference_steps_d,
1292
+ decoder_guidance_scale_d,
1293
  ],
1294
+ outputs=[result_1, result_2, result_3, result_4],
1295
  )
1296
 
1297
  with gr.TabItem("Individual"):
 
1304
  )
1305
  model_choice = gr.Dropdown(
1306
  label="Stable Diffusion Model",
1307
+ choices=[
1308
+ "sd3 medium",
1309
+ "sd2.1",
1310
+ "sdxl",
1311
+ "sdxl flash",
1312
+ "stable cascade",
1313
+ "sd1.5",
1314
+ ],
1315
  value="sd3 medium",
1316
  )
1317
  run_button = gr.Button("Run")
1318
+ result = gr.Gallery(
1319
+ label="Generated AI Images", elem_id="gallery"
1320
+ )
1321
  with gr.Accordion("Advanced options", open=False):
1322
  with gr.Row():
1323
  negative_prompt = gr.Textbox(
 
1334
  maximum=50,
1335
  value=25,
1336
  step=1,
1337
+ visible=True,
1338
  )
1339
  guidance_scale = gr.Slider(
1340
  label="Guidance Scale",
 
1343
  maximum=10.0,
1344
  value=7.5,
1345
  step=0.1,
1346
+ visible=True,
1347
  )
1348
  prior_num_inference_steps = gr.Slider(
1349
  label="Prior Inference Steps",
 
1352
  maximum=50,
1353
  value=25,
1354
  step=1,
1355
+ visible=False,
1356
  )
1357
  prior_guidance_scale = gr.Slider(
1358
  label="Prior Guidance Scale",
 
1361
  maximum=10.0,
1362
  value=4.0,
1363
  step=0.1,
1364
+ visible=False,
1365
  )
1366
  decoder_num_inference_steps = gr.Slider(
1367
  label="Decoder Inference Steps",
 
1370
  maximum=15,
1371
  value=12,
1372
  step=1,
1373
+ visible=False,
1374
  )
1375
  decoder_guidance_scale = gr.Slider(
1376
  label="Decoder Guidance Scale",
 
1379
  maximum=10.0,
1380
  value=0.0,
1381
  step=0.1,
1382
+ visible=False,
1383
  )
1384
  with gr.Row():
1385
  width = gr.Slider(
 
1435
  decoder_num_inference_steps: gr.update(visible=False),
1436
  decoder_guidance_scale: gr.update(visible=False),
1437
  }
1438
+ elif model_choice == "sd1.5":
1439
+ return {
1440
+ num_inference_steps: gr.update(visible=True, maximum=50, value=25),
1441
+ guidance_scale: gr.update(visible=True, maximum=10.0, value=7.5),
1442
+ prior_num_inference_steps: gr.update(visible=False),
1443
+ prior_guidance_scale: gr.update(visible=False),
1444
+ decoder_num_inference_steps: gr.update(visible=False),
1445
+ decoder_guidance_scale: gr.update(visible=False),
1446
+ width: gr.update(value=512, maximum=768),
1447
+ height: gr.update(value=512, maximum=768),
1448
+ }
1449
+ elif model_choice == "sd2.1":
1450
+ return {
1451
+ num_inference_steps: gr.update(visible=True, maximum=50, value=25),
1452
+ guidance_scale: gr.update(visible=True, maximum=10.0, value=7.5),
1453
+ prior_num_inference_steps: gr.update(visible=False),
1454
+ prior_guidance_scale: gr.update(visible=False),
1455
+ decoder_num_inference_steps: gr.update(visible=False),
1456
+ decoder_guidance_scale: gr.update(visible=False),
1457
+ width: gr.update(value=768, maximum=1024),
1458
+ height: gr.update(value=768, maximum=1024),
1459
+ }
1460
  else:
1461
  return {
1462
  num_inference_steps: gr.update(visible=True, maximum=50, value=25),
 
1465
  prior_guidance_scale: gr.update(visible=False),
1466
  decoder_num_inference_steps: gr.update(visible=False),
1467
  decoder_guidance_scale: gr.update(visible=False),
1468
+ width: gr.update(maximum=1344),
1469
+ height: gr.update(maximum=1344),
1470
  }
1471
 
1472
  model_choice.change(
 
1478
  prior_num_inference_steps,
1479
  prior_guidance_scale,
1480
  decoder_num_inference_steps,
1481
+ decoder_guidance_scale,
1482
+ width,
1483
+ height,
1484
+ ],
1485
  )
1486
 
1487
  gr.Examples(