prithivMLmods commited on
Commit
45c21ee
1 Parent(s): 1013a67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +398 -0
app.py CHANGED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import uuid
4
+ import json
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ from PIL import Image
9
+ import spaces
10
+ import torch
11
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
12
+
13
+ # Use environment variables for flexibility
14
+ MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
15
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
16
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
17
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
18
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # Allow generating multiple images at once
19
+
20
+ # Determine device and load model outside of function for efficiency
21
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
+ pipe = StableDiffusionXLPipeline.from_pretrained(
23
+ MODEL_ID,
24
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
25
+ use_safetensors=True,
26
+ add_watermarker=False,
27
+ ).to(device)
28
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
29
+
30
+ # Torch compile for potential speedup (experimental)
31
+ if USE_TORCH_COMPILE:
32
+ pipe.compile()
33
+
34
+ # CPU offloading for larger RAM capacity (experimental)
35
+ if ENABLE_CPU_OFFLOAD:
36
+ pipe.enable_model_cpu_offload()
37
+
38
+ MAX_SEED = np.iinfo(np.int32).max
39
+
40
+ def save_image(img):
41
+ unique_name = str(uuid.uuid4()) + ".png"
42
+ img.save(unique_name)
43
+ return unique_name
44
+
45
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
46
+ if randomize_seed:
47
+ seed = random.randint(0, MAX_SEED)
48
+ return seed
49
+
50
+ @spaces.GPU(duration=35, enable_queue=True)
51
+ def generate(
52
+ prompt: str,
53
+ negative_prompt: str = "",
54
+ use_negative_prompt: bool = False,
55
+ seed: int = 1,
56
+ width: int = 1024,
57
+ height: int = 1024,
58
+ guidance_scale: float = 3,
59
+ num_inference_steps: int = 30,
60
+ randomize_seed: bool = False,
61
+ use_resolution_binning: bool = True,
62
+ num_images: int = 1, # Number of images to generate
63
+ progress=gr.Progress(track_tqdm=True),
64
+ ):
65
+ seed = int(randomize_seed_fn(seed, randomize_seed))
66
+ generator = torch.Generator(device=device).manual_seed(seed)
67
+
68
+ # Improved options handling
69
+ options = {
70
+ "prompt": [prompt] * num_images,
71
+ "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
72
+ "width": width,
73
+ "height": height,
74
+ "guidance_scale": guidance_scale,
75
+ "num_inference_steps": num_inference_steps,
76
+ "generator": generator,
77
+ "output_type": "pil",
78
+ }
79
+
80
+ # Use resolution binning for faster generation with less VRAM usage
81
+ if use_resolution_binning:
82
+ options["use_resolution_binning"] = True
83
+
84
+ # Generate images potentially in batches
85
+ images = []
86
+ for i in range(0, num_images, BATCH_SIZE):
87
+ batch_options = options.copy()
88
+ batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
89
+ if "negative_prompt" in batch_options:
90
+ batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
91
+ images.extend(pipe(**batch_options).images)
92
+
93
+ image_paths = [save_image(img) for img in images]
94
+ return image_paths, seed
95
+
96
+ examples = [
97
+ "a cat eating a piece of cheese",
98
+ "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
99
+ "Ironman VS Hulk, ultrarealistic",
100
+ "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
101
+ "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
102
+ "Kids going to school, Anime style"
103
+ ]
104
+
105
+ css = '''
106
+ .gradio-container{max-width: 700px !important}
107
+ h1{text-align:center}
108
+ footer {
109
+ visibility: hidden
110
+ }
111
+
112
+ .wheel-and-hamster {
113
+ --dur: 1s;
114
+ position: relative;
115
+ width: 12em;
116
+ height: 12em;
117
+ font-size: 14px;
118
+ }
119
+
120
+ .wheel,
121
+ .hamster,
122
+ .hamster div,
123
+ .spoke {
124
+ position: absolute;
125
+ }
126
+
127
+ .wheel,
128
+ .spoke {
129
+ border-radius: 50%;
130
+ top: 0;
131
+ left: 0;
132
+ width: 100%;
133
+ height: 100%;
134
+ }
135
+
136
+ .wheel {
137
+ background: radial-gradient(100% 100% at center,hsla(0,0%,60%,0) 47.8%,hsl(0,0%,60%) 48%);
138
+ z-index: 2;
139
+ }
140
+
141
+ .hamster {
142
+ animation: hamster var(--dur) ease-in-out infinite;
143
+ top: 50%;
144
+ left: calc(50% - 3.5em);
145
+ width: 7em;
146
+ height: 3.75em;
147
+ transform: rotate(4deg) translate(-0.8em,1.85em);
148
+ transform-origin: 50% 0;
149
+ z-index: 1;
150
+ }
151
+
152
+ .hamster__head {
153
+ animation: hamsterHead var(--dur) ease-in-out infinite;
154
+ background: hsl(30,90%,55%);
155
+ border-radius: 70% 30% 0 100% / 40% 25% 25% 60%;
156
+ box-shadow: 0 -0.25em 0 hsl(30,90%,80%) inset,
157
+ 0.75em -1.55em 0 hsl(30,90%,90%) inset;
158
+ top: 0;
159
+ left: -2em;
160
+ width: 2.75em;
161
+ height: 2.5em;
162
+ transform-origin: 100% 50%;
163
+ }
164
+
165
+ .hamster__ear {
166
+ animation: hamsterEar var(--dur) ease-in-out infinite;
167
+ background: hsl(0,90%,85%);
168
+ border-radius: 50%;
169
+ box-shadow: -0.25em 0 hsl(30,90%,55%) inset;
170
+ top: -0.25em;
171
+ right: -0.25em;
172
+ width: 0.75em;
173
+ height: 0.75em;
174
+ transform-origin: 50% 75%;
175
+ }
176
+
177
+ .hamster__eye {
178
+ animation: hamsterEye var(--dur) linear infinite;
179
+ background-color: hsl(0,0%,0%);
180
+ border-radius: 50%;
181
+ top: 0.375em;
182
+ left: 1.25em;
183
+ width: 0.5em;
184
+ height: 0.5em;
185
+ }
186
+
187
+ .hamster__nose {
188
+ background: hsl(0,90%,75%);
189
+ border-radius: 35% 65% 85% 15% / 70% 50% 50% 30%;
190
+ top: 0.75em;
191
+ left: 0;
192
+ width: 0.2em;
193
+ height: 0.25em;
194
+ }
195
+
196
+ .hamster__body {
197
+ animation: hamsterBody var(--dur) ease-in-out infinite;
198
+ background: hsl(30,90%,90%);
199
+ border-radius: 50% 30% 50% 30% / 15% 60% 40% 40%;
200
+ box-shadow: 0.1em 0.75em 0 hsl(30,90%,55%) inset,
201
+ 0.15em -0.5em 0 hsl(30,90%,80%) inset;
202
+ top: 0.25em;
203
+ left: 2em;
204
+ width: 4.5em;
205
+ height: 3em;
206
+ transform-origin: 17% 50%;
207
+ transform-style: preserve-3d;
208
+ }
209
+
210
+ .hamster__limb--fr,
211
+ .hamster__limb--fl {
212
+ clip-path: polygon(0 0,100% 0,70% 80%,60% 100%,0% 100%,40% 80%);
213
+ top: 2em;
214
+ left: 0.5em;
215
+ width: 1em;
216
+ height: 1.5em;
217
+ transform-origin: 50% 0;
218
+ }
219
+
220
+ .hamster__limb--fr {
221
+ animation: hamsterFRLimb var(--dur) linear infinite;
222
+ background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%);
223
+ transform: rotate(15deg) translateZ(-1px);
224
+ }
225
+
226
+ .hamster__limb--fl {
227
+ animation: hamsterFLLimb var(--dur) linear infinite;
228
+ background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%);
229
+ transform: rotate(-60deg) translateZ(-1px);
230
+ }
231
+
232
+ .hamster__limb--br,
233
+ .hamster__limb--bl {
234
+ clip-path: polygon(0 0,100% 0,100% 20%,30% 100%,0% 100%);
235
+ top: 1.25em;
236
+ left: 2.8em;
237
+ width: 1.5em;
238
+ height: 2.5em;
239
+ transform-origin: 33% 10%;
240
+ }
241
+
242
+ .hamster__limb--br {
243
+ animation: hamsterBRLimb var(--dur) linear infinite;
244
+ background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%);
245
+ transform: rotate(-15deg) translateZ(-1px);
246
+ }
247
+
248
+ .hamster__limb--bl {
249
+ animation: hamsterBLLimb var(--dur) linear infinite;
250
+ background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%);
251
+ transform: rotate(60deg) translateZ(-1px);
252
+ }
253
+
254
+ .hamster__tail {
255
+ animation: hamsterTail var(--dur) linear infinite;
256
+ background: hsl(0,90%,85%);
257
+ border-radius: 0.25em 50% 50% 0.25em;
258
+ box-shadow: 0.25em 0 hsl(30,90%,55%) inset;
259
+ top: 1.5em;
260
+ left: 5.5em;
261
+ width: 0.5em;
262
+ height: 0.75em;
263
+ transform: rotate(30deg) translateZ(-1px);
264
+ transform-origin: 0.25em 0.125em;
265
+ }
266
+
267
+ .spoke {
268
+ background: radial-gradient(hsl(0,0%,70%) 25%,hsla(0,0%,60%,0) 26%) center/8px 8px;
269
+ z-index: 0;
270
+ }
271
+
272
+ .spoke--1 {
273
+ animation: spoke var(--dur) linear infinite;
274
+ }
275
+
276
+ .spoke--2 {
277
+ animation: spoke var(--dur) linear infinite;
278
+ transform: rotate(30deg);
279
+ }
280
+
281
+ .spoke--3 {
282
+ animation: spoke var(--dur) linear infinite;
283
+ transform: rotate(60deg);
284
+ }
285
+
286
+ @keyframes hamster {
287
+ 0%,100% { transform: rotate(4deg) translate(-0.8em,1.85em) }
288
+ 50% { transform: rotate(0) translate(-0.8em,1.85em) }
289
+ }
290
+
291
+ @keyframes hamsterHead {
292
+ 0%,100% { transform: rotate(0) }
293
+ 50% { transform: rotate(-8deg) }
294
+ }
295
+
296
+ @keyframes hamsterEar {
297
+ 0%,100% { transform: rotate(0) }
298
+ 50% { transform: rotate(-3deg) }
299
+ }
300
+
301
+ @keyframes hamsterEye {
302
+ 0%,90%,100% { transform: scaleY(1) }
303
+ 95% { transform: scaleY(0) }
304
+ }
305
+
306
+ @keyframes hamsterBody {
307
+ 0%,100% { transform: rotate(0) }
308
+ 50% { transform: rotate(2deg) }
309
+ }
310
+
311
+ @keyframes hamsterFRLimb {
312
+ 0%,100% { transform: rotate(15deg) translateZ(-1px) }
313
+ 50% { transform: rotate(-30deg) translateZ(-1px) }
314
+ }
315
+
316
+ @keyframes hamsterFLLimb {
317
+ 0%,100% { transform: rotate(-60deg) translateZ(-1px) }
318
+ 50% { transform: rotate(-25deg) translateZ(-1px) }
319
+ }
320
+
321
+ @keyframes hamsterBRLimb {
322
+ 0%,100% { transform: rotate(-15deg) translateZ(-1px) }
323
+ 50% { transform: rotate(30deg) translateZ(-1px) }
324
+ }
325
+
326
+ @keyframes hamsterBLLimb {
327
+ 0%,100% { transform: rotate(60deg) translateZ(-1px) }
328
+ 50% { transform: rotate(25deg) translateZ(-1px) }
329
+ }
330
+
331
+ @keyframes hamsterTail {
332
+ 0%,100% { transform: rotate(30deg) translateZ(-1px) }
333
+ 50% { transform: rotate(10deg) translateZ(-1px) }
334
+ }
335
+
336
+ @keyframes spoke {
337
+ 0% { transform: rotate(0) }
338
+ 100% { transform: rotate(1turn) }
339
+ }
340
+ '''
341
+
342
+ html = '''
343
+ <div id="loading-animation" style="display: flex; justify-content: center; align-items: center; height: 100vh;">
344
+ <div class="wheel-and-hamster">
345
+ <div class="wheel"></div>
346
+ <div class="hamster">
347
+ <div class="hamster__body">
348
+ <div class="hamster__head">
349
+ <div class="hamster__ear"></div>
350
+ <div class="hamster__eye"></div>
351
+ <div class="hamster__nose"></div>
352
+ </div>
353
+ <div class="hamster__limb hamster__limb--fr"></div>
354
+ <div class="hamster__limb hamster__limb--fl"></div>
355
+ <div class="hamster__limb hamster__limb--br"></div>
356
+ <div class="hamster__limb hamster__limb--bl"></div>
357
+ <div class="hamster__tail"></div>
358
+ </div>
359
+ </div>
360
+ <div class="spoke spoke--1"></div>
361
+ <div class="spoke spoke--2"></div>
362
+ <div class="spoke spoke--3"></div>
363
+ </div>
364
+ </div>
365
+ <script>
366
+ window.onload = function() {
367
+ document.getElementById("loading-animation").style.display = "none";
368
+ }
369
+ </script>
370
+ '''
371
+
372
+ with gr.Blocks(css=css) as demo:
373
+ gr.HTML(html)
374
+ gr.Markdown("# Flash Attention with SDXL")
375
+ gr.Markdown("Generate images with Flash Attention and SDXL")
376
+
377
+ with gr.Row():
378
+ with gr.Column(scale=55):
379
+ prompt = gr.Textbox(label="Prompt", show_label=False, max_lines=2, placeholder="Enter your prompt").style(container=False)
380
+ negative_prompt = gr.Textbox(label="Negative Prompt", show_label=False, max_lines=2, placeholder="Enter negative prompt").style(container=False)
381
+ with gr.Column(scale=45):
382
+ generate_btn = gr.Button("Generate")
383
+
384
+ with gr.Row():
385
+ image_output = gr.Gallery(label="Generated Images").style(grid=2, height="auto")
386
+ seed_output = gr.Number(label="Seed Used")
387
+
388
+ gr.Examples(examples=examples, inputs=[prompt])
389
+
390
+ inputs = [prompt, negative_prompt, gr.Checkbox(False, label="Use Negative Prompt"), gr.Slider(1, MAX_SEED, value=1, label="Seed"),
391
+ gr.Slider(256, MAX_IMAGE_SIZE, value=1024, label="Width"), gr.Slider(256, MAX_IMAGE_SIZE, value=1024, label="Height"),
392
+ gr.Slider(1, 20, value=7.5, label="Guidance Scale"), gr.Slider(1, 100, value=30, label="Number of Inference Steps"),
393
+ gr.Checkbox(False, label="Randomize Seed"), gr.Checkbox(True, label="Use Resolution Binning"),
394
+ gr.Slider(1, 10, value=1, label="Number of Images")]
395
+
396
+ generate_btn.click(fn=generate, inputs=inputs, outputs=[image_output, seed_output])
397
+
398
+ demo.launch()