kadirnar commited on
Commit
df81eb7
·
verified ·
1 Parent(s): 6723bf5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -174
app.py CHANGED
@@ -1,184 +1,217 @@
1
- import gradio as gr
2
- import torch
3
- from diffusers import StableDiffusionPipeline
4
 
5
- from diffusion_webui.utils.model_list import stable_model_list
6
- from diffusion_webui.utils.scheduler_list import (
7
- SCHEDULER_MAPPING,
8
- get_scheduler,
9
- )
10
- import spaces
11
 
 
 
 
 
12
 
13
- class StableDiffusionText2ImageGenerator:
14
- def __init__(self):
15
- self.pipe = None
16
-
17
- def load_model(
18
- self,
19
- model_path,
20
- scheduler,
21
- ):
22
- if self.pipe is None:
23
- self.pipe = StableDiffusionPipeline.from_pretrained(
24
- model_path, safety_checker=None, torch_dtype=torch.float16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  )
26
-
27
- self.pipe = get_scheduler(pipe=self.pipe, scheduler=scheduler)
28
- self.pipe.to("cuda")
29
- self.pipe.enable_xformers_memory_efficient_attention()
30
-
31
- return self.pipe
32
-
33
- @spaces.GPU()
34
- def generate_image(
35
- self,
36
- model_path: str,
37
- prompt: str,
38
- negative_prompt: str,
39
- num_images_per_prompt: int,
40
- scheduler: str,
41
- guidance_scale: int,
42
- num_inference_step: int,
43
- height: int,
44
- width: int,
45
- seed_generator=0,
46
- ):
47
- pipe = self.load_model(
48
- model_path=model_path,
49
- scheduler=scheduler,
 
 
 
50
  )
51
- if seed_generator == 0:
52
- random_seed = torch.randint(0, 1000000, (1,))
53
- generator = torch.manual_seed(random_seed)
54
- else:
55
- generator = torch.manual_seed(seed_generator)
56
-
57
- images = pipe(
58
- prompt=prompt,
59
- height=height,
60
- width=width,
61
- negative_prompt=negative_prompt,
62
- num_images_per_prompt=num_images_per_prompt,
63
- num_inference_steps=num_inference_step,
64
- guidance_scale=guidance_scale,
65
- generator=generator,
66
- ).images
67
-
68
- return images
69
-
70
- def app():
71
- with gr.Blocks():
72
- with gr.Row():
73
- with gr.Column():
74
- text2image_prompt = gr.Textbox(
75
- lines=1,
76
- placeholder="Prompt",
77
- show_label=False,
78
- )
79
-
80
- text2image_negative_prompt = gr.Textbox(
81
- lines=1,
82
- placeholder="Negative Prompt",
83
- show_label=False,
84
- )
85
- with gr.Row():
86
- with gr.Column():
87
- text2image_model_path = gr.Dropdown(
88
- choices=stable_model_list,
89
- value=stable_model_list[0],
90
- label="Text-Image Model Id",
91
- )
92
-
93
- text2image_guidance_scale = gr.Slider(
94
- minimum=0.1,
95
- maximum=15,
96
- step=0.1,
97
- value=7.5,
98
- label="Guidance Scale",
99
- )
100
-
101
- text2image_num_inference_step = gr.Slider(
102
- minimum=1,
103
- maximum=100,
104
- step=1,
105
- value=50,
106
- label="Num Inference Step",
107
- )
108
- text2image_num_images_per_prompt = gr.Slider(
109
- minimum=1,
110
- maximum=30,
111
- step=1,
112
- value=1,
113
- label="Number Of Images",
114
- )
115
- with gr.Row():
116
- with gr.Column():
117
- text2image_scheduler = gr.Dropdown(
118
- choices=list(SCHEDULER_MAPPING.keys()),
119
- value=list(SCHEDULER_MAPPING.keys())[0],
120
- label="Scheduler",
121
- )
122
-
123
- text2image_height = gr.Slider(
124
- minimum=128,
125
- maximum=1280,
126
- step=32,
127
- value=512,
128
- label="Image Height",
129
- )
130
-
131
- text2image_width = gr.Slider(
132
- minimum=128,
133
- maximum=1280,
134
- step=32,
135
- value=512,
136
- label="Image Width",
137
- )
138
- text2image_seed_generator = gr.Slider(
139
- label="Seed(0 for random)",
140
- minimum=0,
141
- maximum=1000000,
142
- value=0,
143
- )
144
- text2image_predict = gr.Button(value="Generator")
145
-
146
- with gr.Column():
147
- output_image = gr.Gallery(
148
- label="Generated images",
149
- show_label=False,
150
- elem_id="gallery",
151
- ).style(grid=(1, 2), height=200)
152
-
153
- text2image_predict.click(
154
- fn=StableDiffusionText2ImageGenerator().generate_image,
155
- inputs=[
156
- text2image_model_path,
157
- text2image_prompt,
158
- text2image_negative_prompt,
159
- text2image_num_images_per_prompt,
160
- text2image_scheduler,
161
- text2image_guidance_scale,
162
- text2image_num_inference_step,
163
- text2image_height,
164
- text2image_width,
165
- text2image_seed_generator,
166
- ],
167
- outputs=output_image,
168
  )
169
-
170
- import gradio as gr
171
-
172
-
173
- def diffusion_app():
174
- app = gr.Blocks()
175
- with app:
176
  with gr.Row():
177
- with gr.Column():
178
- StableDiffusionText2ImageGenerator.app()
179
-
180
- app.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  if __name__ == "__main__":
184
- diffusion_app()
 
 
 
 
1
 
 
 
 
 
 
 
2
 
3
+ import os
4
+ import random
5
+ import uuid
6
+ import json
7
 
8
+ import gradio as gr
9
+ import numpy as np
10
+ from PIL import Image
11
+ import spaces
12
+ import torch
13
+ from diffusers import StableDiffusionXLPipeline, DPMSolverSinglestepScheduler
14
+
15
+ # Use environment variables for flexibility
16
+ MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
17
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
18
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
19
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
20
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # Allow generating multiple images at once
21
+
22
+ # Determine device and load model outside of function for efficiency
23
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
+ pipe = StableDiffusionXLPipeline.from_single_file(
25
+ "https://huggingface.co/kadirnar/Black-Hole/blob/main/tachyon.safetensors",
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
+ use_safetensors=True,
28
+ add_watermarker=False,
29
+ ).to(device)
30
+ pipe.scheduler = DPMSolverSinglestepScheduler(use_karras_sigmas=True).from_config(pipe.scheduler.config)
31
+
32
+ # Torch compile for potential speedup (experimental)
33
+ if USE_TORCH_COMPILE:
34
+ pipe.compile()
35
+
36
+ # CPU offloading for larger RAM capacity (experimental)
37
+ if ENABLE_CPU_OFFLOAD:
38
+ pipe.enable_model_cpu_offload()
39
+
40
+ MAX_SEED = np.iinfo(np.int32).max
41
+
42
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
43
+ if randomize_seed:
44
+ seed = random.randint(0, MAX_SEED)
45
+ return seed
46
+
47
+ @spaces.GPU()
48
+ def generate(
49
+ prompt: str,
50
+ negative_prompt: str = "",
51
+ use_negative_prompt: bool = False,
52
+ seed: int = 1,
53
+ width: int = 1024,
54
+ height: int = 1024,
55
+ guidance_scale: float = 3,
56
+ num_inference_steps: int = 30,
57
+ randomize_seed: bool = False,
58
+ use_resolution_binning: bool = True,
59
+ num_images: int = 1, # Number of images to generate
60
+ progress=gr.Progress(track_tqdm=True),
61
+ ):
62
+ seed = int(randomize_seed_fn(seed, randomize_seed))
63
+ generator = torch.Generator(device=device).manual_seed(seed)
64
+
65
+ # Improved options handling
66
+ options = {
67
+ "prompt": [prompt] * num_images,
68
+ "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
69
+ "width": width,
70
+ "height": height,
71
+ "guidance_scale": guidance_scale,
72
+ "num_inference_steps": num_inference_steps,
73
+ "generator": generator,
74
+ "output_type": "pil",
75
+ }
76
+
77
+ # Use resolution binning for faster generation with less VRAM usage
78
+ if use_resolution_binning:
79
+ options["use_resolution_binning"] = True
80
+
81
+ # Generate images potentially in batches
82
+ images = []
83
+ for i in range(0, num_images, BATCH_SIZE):
84
+ batch_options = options.copy()
85
+ batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
86
+ if "negative_prompt" in batch_options:
87
+ batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
88
+ images.extend(pipe(**batch_options).images)
89
+
90
+ return images
91
+
92
+ examples = [
93
+ "a cat eating a piece of cheese",
94
+ "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
95
+ "Ironman VS Hulk, ultrarealistic",
96
+ "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
97
+ "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
98
+ "Kids going to school, Anime style"
99
+ ]
100
+
101
+ css = '''
102
+ .gradio-container{max-width: 700px !important}
103
+ h1{text-align:center}
104
+ footer {
105
+ visibility: hidden
106
+ }
107
+ '''
108
+
109
+ with gr.Blocks(css=css) as demo:
110
+ gr.Markdown("""# Black Hole SDXL-Lightning""")
111
+ with gr.Group():
112
+ with gr.Row():
113
+ prompt = gr.Text(
114
+ label="Prompt",
115
+ show_label=False,
116
+ max_lines=1,
117
+ placeholder="Enter your prompt",
118
+ container=False,
119
  )
120
+ run_button = gr.Button("Run", scale=0)
121
+ result = gr.Gallery(elem_id="gallery", label="Result", show_label=False)
122
+
123
+ with gr.Accordion("Advanced options", open=False):
124
+ num_images = gr.Slider(
125
+ label="Number of Images",
126
+ minimum=1,
127
+ maximum=4,
128
+ step=1,
129
+ value=1,
130
+ )
131
+ with gr.Row():
132
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
133
+ negative_prompt = gr.Text(
134
+ label="Negative prompt",
135
+ max_lines=5,
136
+ lines=4,
137
+ placeholder="Enter a negative prompt",
138
+ value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
139
+ visible=True,
140
+ )
141
+ seed = gr.Slider(
142
+ label="Seed",
143
+ minimum=0,
144
+ maximum=MAX_SEED,
145
+ step=1,
146
+ value=0,
147
  )
148
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
149
+ with gr.Row(visible=True):
150
+ width = gr.Slider(
151
+ label="Width",
152
+ minimum=512,
153
+ maximum=MAX_IMAGE_SIZE,
154
+ step=64,
155
+ value=1024,
156
+ )
157
+ height = gr.Slider(
158
+ label="Height",
159
+ minimum=512,
160
+ maximum=MAX_IMAGE_SIZE,
161
+ step=64,
162
+ value=1024,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  )
 
 
 
 
 
 
 
164
  with gr.Row():
165
+ guidance_scale = gr.Slider(
166
+ label="Guidance Scale",
167
+ minimum=0.1,
168
+ maximum=6,
169
+ step=0.1,
170
+ value=3.0,
171
+ )
172
+ num_inference_steps = gr.Slider(
173
+ label="Number of inference steps",
174
+ minimum=1,
175
+ maximum=15,
176
+ step=1,
177
+ value=4,
178
+ )
179
 
180
+ gr.Examples(
181
+ examples=examples,
182
+ inputs=prompt,
183
+ cache_examples=False
184
+ )
185
+
186
+ use_negative_prompt.change(
187
+ fn=lambda x: gr.update(visible=x),
188
+ inputs=use_negative_prompt,
189
+ outputs=negative_prompt,
190
+ api_name=False,
191
+ )
192
+
193
+ gr.on(
194
+ triggers=[
195
+ prompt.submit,
196
+ negative_prompt.submit,
197
+ run_button.click,
198
+ ],
199
+ fn=generate,
200
+ inputs=[
201
+ prompt,
202
+ negative_prompt,
203
+ use_negative_prompt,
204
+ seed,
205
+ width,
206
+ height,
207
+ guidance_scale,
208
+ num_inference_steps,
209
+ randomize_seed,
210
+ num_images
211
+ ],
212
+ outputs=[result, seed],
213
+ api_name="run",
214
+ )
215
 
216
  if __name__ == "__main__":
217
+ demo.queue().launch()