Nick088 committed on
Commit
0900c59
·
verified ·
1 Parent(s): ff28634

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +470 -0
app.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
# NOTE: diffusers serves Stable Diffusion 2.x through StableDiffusionPipeline
# and SDXL through StableDiffusionXLPipeline. The previously imported names
# (StableDiffusion2Pipeline, StableDiffusionXLBasePipeline) are not exported
# by diffusers, so the import failed before the app could start.
from diffusers import (
    StableDiffusion3Pipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)
import gradio as gr
import os
import random
import transformers
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration
import spaces

# Optional Hugging Face access token; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")

# Half precision is only used on GPU; fall back to float32 on CPU, where
# float16 inference is slow or unsupported for some operators.
_PIPE_DTYPE = torch.float16 if device == "cuda" else torch.float32

# Upper bound of the seed slider shown in the UI.
MAX_SEED = np.iinfo(np.int32).max

# Initialize the pipelines for each sd model
sd3_medium_pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=_PIPE_DTYPE,
    # The token read above was never used; pass it along so a gated
    # checkpoint can be downloaded. None keeps the library default.
    token=HF_TOKEN,
)
sd3_medium_pipe.to(device)

sd2_1_pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=_PIPE_DTYPE
)
sd2_1_pipe.to(device)

sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=_PIPE_DTYPE
)
sdxl_pipe.to(device)

# superprompt-v1: T5 model used to expand the user's prompt
tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model = T5ForConditionalGeneration.from_pretrained(
    "roborovski/superprompt-v1", device_map="auto", torch_dtype="auto"
)
# NOTE(review): device_map="auto" already places the model; this explicit move
# is kept from the original — confirm it is still wanted.
model.to(device)
45
+
46
# Toggle the visibility of the enhanced prompt output
def update_visibility(enhance_prompt):
    """Build a Gradio update whose ``visible`` flag mirrors ``enhance_prompt``.

    Args:
        enhance_prompt: Current value of the prompt-enhancement checkbox.

    Returns:
        The value produced by ``gr.update(visible=enhance_prompt)``.
    """
    visibility_patch = gr.update(visible=enhance_prompt)
    return visibility_patch
49
+
50
+
51
# Define the image generation function for the Arena tab
@spaces.GPU(duration=80)
def generate_arena_images(
    prompt,
    enhance_prompt,
    negative_prompt,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    seed,
    num_images_per_prompt,
    model_choice_1,
    model_choice_2,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images for the same prompt with two models, side by side.

    Args:
        prompt: Text prompt entered by the user.
        enhance_prompt: When truthy, the prompt is first expanded with the
            superprompt-v1 T5 model and the expanded text is used instead.
        negative_prompt: Text describing what should not appear in the image.
        num_inference_steps: Number of denoising steps passed to each pipeline.
        height: Requested image height.
        width: Requested image width.
        guidance_scale: Guidance scale passed to each pipeline.
        seed: RNG seed; ``0`` means "pick a random seed".
        num_images_per_prompt: Number of images to produce per model.
        model_choice_1: Model key for the first gallery.
        model_choice_2: Model key for the second gallery.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Returns:
        A tuple ``(images_1, images_2, prompt)`` where ``prompt`` is the text
        actually sent to the pipelines (expanded when enhancement is on).

    Raises:
        ValueError: Propagated from ``generate_single_image`` for an unknown
            model key.
    """
    if seed == 0:
        seed = random.randint(1, 2**32 - 1)

    if enhance_prompt:
        # Seed the sampler so the expanded prompt is reproducible per seed.
        transformers.set_seed(seed)

        input_text = f"Expand the following prompt to add more detail: {prompt}"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

        outputs = model.generate(
            input_ids,
            max_new_tokens=512,
            repetition_penalty=1.2,
            do_sample=True,
            temperature=0.7,
            top_p=1,
            top_k=50,
        )
        prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Give every model its own generator seeded with the same value. Sharing a
    # single generator let the first pipeline advance the RNG state, so the
    # second model's output depended on which model ran before it and did not
    # match what the same seed produces for that model on its own.
    images_1, images_2 = (
        generate_single_image(
            prompt,
            negative_prompt,
            num_inference_steps,
            height,
            width,
            guidance_scale,
            seed,
            num_images_per_prompt,
            choice,
            torch.Generator().manual_seed(seed),
        )
        for choice in (model_choice_1, model_choice_2)
    )

    return images_1, images_2, prompt
116
+
117
+
118
# Helper function to generate images for a single model
def generate_single_image(
    prompt,
    negative_prompt,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    seed,
    num_images_per_prompt,
    model_choice,
    generator,
):
    """Run the pipeline selected by ``model_choice`` and return its images.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        height: Requested image height.
        width: Requested image width.
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Accepted for signature compatibility; not used in this function
            (randomness comes from ``generator``).
        num_images_per_prompt: Number of images to generate.
        model_choice: One of ``"sd3 medium"``, ``"sd2.1"`` or ``"sdxl"``.
        generator: Random generator forwarded to the pipeline.

    Returns:
        The ``images`` attribute of the pipeline result.

    Raises:
        ValueError: If ``model_choice`` is not one of the known keys.
    """
    # Reject unknown keys up front, before touching any pipeline object.
    if model_choice not in ("sd3 medium", "sd2.1", "sdxl"):
        raise ValueError(f"Invalid model choice: {model_choice}")

    # Select the correct pipeline based on the model choice
    pipelines_by_key = {
        "sd3 medium": sd3_medium_pipe,
        "sd2.1": sd2_1_pipe,
        "sdxl": sdxl_pipe,
    }
    selected_pipe = pipelines_by_key[model_choice]

    result = selected_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_images_per_prompt,
    )
    return result.images
153
+
154
# Define the image generation function for the Individual tab
@spaces.GPU(duration=80)
def generate_individual_image(
    prompt,
    enhance_prompt,
    negative_prompt,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    seed,
    num_images_per_prompt,
    model_choice,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images with one selected model.

    Args:
        prompt: Text prompt entered by the user.
        enhance_prompt: When truthy, the prompt is expanded with the
            superprompt-v1 T5 model before image generation.
        negative_prompt: Text describing what should not appear in the image.
        num_inference_steps: Number of denoising steps.
        height: Requested image height.
        width: Requested image width.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: RNG seed; ``0`` means "pick a random seed".
        num_images_per_prompt: Number of images to produce.
        model_choice: Model key understood by ``generate_single_image``.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Returns:
        A tuple ``(images, prompt)`` where ``prompt`` is the text actually
        sent to the pipeline (expanded when enhancement is on).
    """
    if seed == 0:
        seed = random.randint(1, 2**32 - 1)

    if enhance_prompt:
        transformers.set_seed(seed)

        instruction = f"Expand the following prompt to add more detail: {prompt}"
        encoded = tokenizer(instruction, return_tensors="pt")
        token_ids = encoded.input_ids.to(device)

        sampling_options = {
            "max_new_tokens": 512,
            "repetition_penalty": 1.2,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 1,
            "top_k": 50,
        }
        generated = model.generate(token_ids, **sampling_options)
        prompt = tokenizer.decode(generated[0], skip_special_tokens=True)

    rng = torch.Generator().manual_seed(seed)

    images = generate_single_image(
        prompt,
        negative_prompt,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        seed,
        num_images_per_prompt,
        model_choice,
        rng,
    )

    return images, prompt
205
+
206
+
207
# Create the Gradio interface
# Each example row is [prompt text, prompt-enhancement checkbox value].
examples = [
    [example_prompt, True]
    for example_prompt in (
        "A white car racing fast to the moon.",
        "A woman in a red dress singing on top of a building.",
        "An astronaut on mars in a futuristic cyborg suit.",
    )
]

# Page-level CSS: cap the container width and center the heading.
css = "\n".join(
    [
        "",
        ".gradio-container{max-width: 1000px !important}",
        "h1{text-align:center}",
        "",
    ]
)
218
def _build_advanced_options():
    """Create the "Advanced options" accordion shared by both tabs.

    Must be called while a ``gr.Blocks`` context is active so the components
    attach to the tab currently being built.

    Returns:
        A tuple ``(negative_prompt, num_inference_steps, guidance_scale,
        width, height, seed, num_images_per_prompt)`` of the created
        components.
    """
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="Describe what you don't want in the image",
                value="deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
                placeholder="Ugly, bad anatomy...",
            )
        with gr.Row():
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference",
                minimum=1,
                maximum=50,
                value=25,
                step=1,
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.",
                minimum=0.0,
                maximum=10.0,
                value=7.5,
                step=0.1,
            )
        with gr.Row():
            width = gr.Slider(
                label="Width",
                info="Width of the Image",
                minimum=256,
                maximum=1344,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                info="Height of the Image",
                minimum=256,
                maximum=1344,
                step=32,
                value=1024,
            )
        with gr.Row():
            seed = gr.Slider(
                value=42,
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                label="Seed",
                info="A starting point to initiate the generation process, put 0 for a random one",
            )
            num_images_per_prompt = gr.Slider(
                label="Images Per Prompt",
                info="Number of Images to generate with the settings",
                minimum=1,
                maximum=4,
                step=1,
                value=2,
            )
    return (
        negative_prompt,
        num_inference_steps,
        guidance_scale,
        width,
        height,
        seed,
        num_images_per_prompt,
    )


with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(
                """
                <h1 style='text-align: center'>
                Stable Diffusion Arena
                </h1>
                """
            )
            gr.HTML(
                """
                Made by <a href='https://linktr.ee/Nick088' target='_blank'>Nick088</a>
                <br> <a href="https://discord.gg/osai"> <img src="https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge" alt="Discord"> </a>
                """
            )
    with gr.Tabs():
        # ---- Arena tab: two models rendered side by side ----
        with gr.TabItem("Arena"):
            with gr.Group():
                with gr.Column():
                    prompt = gr.Textbox(
                        label="Prompt",
                        info="Describe the image you want",
                        placeholder="A cat...",
                    )
                    enhance_prompt = gr.Checkbox(
                        label="Prompt Enhancement with SuperPrompt-v1", value=True
                    )
                    model_choice_1 = gr.Dropdown(
                        label="Stable Diffusion Model 1",
                        choices=["sd3 medium", "sd2.1", "sdxl"],
                        value="sd3 medium",
                    )
                    model_choice_2 = gr.Dropdown(
                        label="Stable Diffusion Model 2",
                        choices=["sd3 medium", "sd2.1", "sdxl"],
                        value="sd2.1",
                    )
                    run_button = gr.Button("Run")
                result_1 = gr.Gallery(label="Generated Images (Model 1)", elem_id="gallery_1")
                result_2 = gr.Gallery(label="Generated Images (Model 2)", elem_id="gallery_2")
                better_prompt = gr.Textbox(
                    label="Enhanced Prompt",
                    info="The output of your enhanced prompt used for the Image Generation",
                    visible=True,
                )
                # Hide the enhanced-prompt box whenever enhancement is off.
                enhance_prompt.change(
                    fn=update_visibility, inputs=enhance_prompt, outputs=better_prompt
                )

            (
                negative_prompt,
                num_inference_steps,
                guidance_scale,
                width,
                height,
                seed,
                num_images_per_prompt,
            ) = _build_advanced_options()

            gr.Examples(
                examples=examples,
                inputs=[prompt, enhance_prompt],
                outputs=[result_1, result_2, better_prompt],
                fn=generate_arena_images,
                # The examples provide only 2 of the handler's required
                # arguments, so cached execution would call it with missing
                # parameters. Keep examples as click-to-fill only.
                cache_examples=False,
            )

            gr.on(
                triggers=[
                    prompt.submit,
                    run_button.click,
                ],
                fn=generate_arena_images,
                # Order must match the handler signature, which takes height
                # BEFORE width; the previous list passed them swapped.
                inputs=[
                    prompt,
                    enhance_prompt,
                    negative_prompt,
                    num_inference_steps,
                    height,
                    width,
                    guidance_scale,
                    seed,
                    num_images_per_prompt,
                    model_choice_1,
                    model_choice_2,
                ],
                outputs=[result_1, result_2, better_prompt],
            )

        # ---- Individual tab: a single selected model ----
        with gr.TabItem("Individual"):
            with gr.Group():
                with gr.Column():
                    prompt = gr.Textbox(
                        label="Prompt",
                        info="Describe the image you want",
                        placeholder="A cat...",
                    )
                    enhance_prompt = gr.Checkbox(
                        label="Prompt Enhancement with SuperPrompt-v1", value=True
                    )
                    model_choice = gr.Dropdown(
                        label="Stable Diffusion Model",
                        choices=["sd3 medium", "sd2.1", "sdxl"],
                        value="sd3 medium",
                    )
                    run_button = gr.Button("Run")
                result = gr.Gallery(label="Generated AI Images", elem_id="gallery")
                better_prompt = gr.Textbox(
                    label="Enhanced Prompt",
                    info="The output of your enhanced prompt used for the Image Generation",
                    visible=True,
                )
                # Hide the enhanced-prompt box whenever enhancement is off.
                enhance_prompt.change(
                    fn=update_visibility, inputs=enhance_prompt, outputs=better_prompt
                )

            (
                negative_prompt,
                num_inference_steps,
                guidance_scale,
                width,
                height,
                seed,
                num_images_per_prompt,
            ) = _build_advanced_options()

            gr.Examples(
                examples=examples,
                inputs=[prompt, enhance_prompt],
                outputs=[result, better_prompt],
                fn=generate_individual_image,
                # See the Arena tab: the examples do not supply every handler
                # argument, so they must not be executed for caching.
                cache_examples=False,
            )

            gr.on(
                triggers=[
                    prompt.submit,
                    run_button.click,
                ],
                fn=generate_individual_image,
                # Order must match the handler signature (height, then width).
                inputs=[
                    prompt,
                    enhance_prompt,
                    negative_prompt,
                    num_inference_steps,
                    height,
                    width,
                    guidance_scale,
                    seed,
                    num_images_per_prompt,
                    model_choice,
                ],
                outputs=[result, better_prompt],
            )

demo.queue().launch(share=False)