Sergidev commited on
Commit
220ffe7
·
verified ·
1 Parent(s): 9f429dc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -167
app.py CHANGED
@@ -12,17 +12,13 @@ from PIL import Image, PngImagePlugin
12
  from datetime import datetime
13
  from diffusers.models import AutoencoderKL
14
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
15
- from collections import deque
16
- import base64
17
- from io import BytesIO
18
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
  DESCRIPTION = "PonyDiffusion V6 XL"
23
  if not torch.cuda.is_available():
24
- DESCRIPTION += "\n\nRunning on CPU 🥶 This demo does not work on CPU."
25
-
26
  IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
27
  HF_TOKEN = os.getenv("HF_TOKEN")
28
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
@@ -31,6 +27,7 @@ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
31
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
32
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
33
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
 
34
  MODEL = os.getenv(
35
  "MODEL",
36
  "https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
@@ -41,8 +38,6 @@ torch.backends.cudnn.benchmark = False
41
 
42
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
43
 
44
- MAX_HISTORY_SIZE = 10
45
- image_history = deque(maxlen=MAX_HISTORY_SIZE)
46
 
47
  def load_pipeline(model_name):
48
  vae = AutoencoderKL.from_pretrained(
@@ -54,6 +49,7 @@ def load_pipeline(model_name):
54
  if MODEL.endswith(".safetensors")
55
  else StableDiffusionXLPipeline.from_pretrained
56
  )
 
57
  pipe = pipeline(
58
  model_name,
59
  vae=vae,
@@ -64,9 +60,11 @@ def load_pipeline(model_name):
64
  use_auth_token=HF_TOKEN,
65
  variant="fp16",
66
  )
 
67
  pipe.to(device)
68
  return pipe
69
 
 
70
  @spaces.GPU
71
  def generate(
72
  prompt: str,
@@ -84,16 +82,20 @@ def generate(
84
  progress=gr.Progress(track_tqdm=True),
85
  ) -> Image:
86
  generator = utils.seed_everything(seed)
 
87
  width, height = utils.aspect_ratio_handler(
88
- aspect_ratio_selector, custom_width, custom_height,
 
 
89
  )
 
90
  width, height = utils.preprocess_image_dimensions(width, height)
 
91
  backup_scheduler = pipe.scheduler
92
  pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
93
 
94
  if use_upscaler:
95
  upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
96
-
97
  metadata = {
98
  "prompt": prompt,
99
  "negative_prompt": negative_prompt,
@@ -115,7 +117,6 @@ def generate(
115
  }
116
  else:
117
  metadata["use_upscaler"] = None
118
-
119
  logger.info(json.dumps(metadata, indent=4))
120
 
121
  try:
@@ -153,33 +154,12 @@ def generate(
153
  output_type="pil",
154
  ).images
155
 
156
- if images:
157
- image = images[0]
158
- # Create thumbnail
159
- thumbnail = image.copy()
160
- thumbnail.thumbnail((256, 256))
161
-
162
- # Convert thumbnail to base64
163
- buffered = BytesIO()
164
- thumbnail.save(buffered, format="PNG")
165
- img_str = base64.b64encode(buffered.getvalue()).decode()
166
-
167
- # Add to history
168
- image_history.appendleft({
169
- "thumbnail": f"data:image/png;base64,{img_str}",
170
- "prompt": prompt,
171
- "negative_prompt": negative_prompt,
172
- "seed": seed,
173
- "width": width,
174
- "height": height,
175
- })
176
-
177
- if IS_COLAB:
178
  filepath = utils.save_image(image, metadata, OUTPUT_DIR)
179
  logger.info(f"Image saved as {filepath} with metadata")
180
 
181
- return image, metadata, list(image_history)
182
-
183
  except Exception as e:
184
  logger.exception(f"An error occurred: {e}")
185
  raise
@@ -189,6 +169,7 @@ def generate(
189
  pipe.scheduler = backup_scheduler
190
  utils.free_memory()
191
 
 
192
  if torch.cuda.is_available():
193
  pipe = load_pipeline(MODEL)
194
  logger.info("Loaded on Device!")
@@ -197,32 +178,52 @@ else:
197
 
198
  with gr.Blocks(css="style.css") as demo:
199
  title = gr.HTML(
200
- f"""<h1>{DESCRIPTION}</h1>"""
 
201
  )
202
-
203
- with gr.Row():
204
- with gr.Column(scale=2):
205
- prompt = gr.Textbox(
 
 
 
 
 
 
 
 
206
  label="Prompt",
207
  show_label=False,
208
- max_lines=2,
209
  placeholder="Enter your prompt",
 
210
  )
211
- negative_prompt = gr.Textbox(
212
- label="Negative Prompt",
213
- show_label=False,
214
- max_lines=2,
215
- placeholder="Enter a negative prompt",
216
  )
217
-
218
- with gr.Row():
219
- seed = gr.Number(
220
- label="Seed",
221
- value=0,
222
- precision=0,
223
- )
224
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
225
-
 
 
 
 
 
 
 
 
 
 
 
226
  with gr.Row():
227
  custom_width = gr.Slider(
228
  label="Width",
@@ -238,126 +239,125 @@ with gr.Blocks(css="style.css") as demo:
238
  step=8,
239
  value=1024,
240
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  with gr.Row():
243
  guidance_scale = gr.Slider(
244
- label="Guidance Scale", minimum=0, maximum=20, step=0.1, value=7
 
 
 
 
245
  )
246
  num_inference_steps = gr.Slider(
247
- label="Num Inference Steps",
248
  minimum=1,
249
- maximum=100,
250
  step=1,
251
- value=30,
252
- )
253
-
254
- with gr.Row():
255
- sampler = gr.Dropdown(
256
- label="Sampler",
257
- choices=[
258
- "DPM++ 2M SDE Karras",
259
- "DPM++ 2M SDE",
260
- "Euler a",
261
- "Euler",
262
- "DPM++ 2M Karras",
263
- "DPM++ 2M",
264
- "LMS Karras",
265
- "Heun",
266
- "DPM++ SDE Karras",
267
- "DPM++ SDE",
268
- "DPM2 Karras",
269
- "DPM2",
270
- "DPM2 a Karras",
271
- "DPM2 a",
272
- "LMS",
273
- "DDIM",
274
- "PLMS",
275
- ],
276
- value="DPM++ 2M SDE Karras",
277
- )
278
- aspect_ratio_selector = gr.Dropdown(
279
- label="Aspect Ratio",
280
- choices=[
281
- "1024 x 1024",
282
- "1152 x 896",
283
- "896 x 1152",
284
- "1216 x 832",
285
- "832 x 1216",
286
- "1344 x 768",
287
- "768 x 1344",
288
- "1536 x 640",
289
- "640 x 1536",
290
- ],
291
- value="1024 x 1024",
292
- )
293
-
294
- with gr.Row():
295
- use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
296
- upscaler_strength = gr.Slider(
297
- label="Upscaler Strength",
298
- minimum=0,
299
- maximum=1,
300
- step=0.05,
301
- value=0.55,
302
- )
303
- upscale_by = gr.Slider(
304
- label="Upscale By",
305
- minimum=1,
306
- maximum=4,
307
- step=0.1,
308
- value=1.5,
309
  )
310
-
311
- with gr.Column(scale=1):
312
- output_image = gr.Image(label="Generated Image")
313
- output_text = gr.JSON(label="Generation Info")
314
-
315
- with gr.Row():
316
- generate_button = gr.Button("Generate")
317
-
318
- # Add the history component
319
- history = gr.HTML(label="Generation History")
320
-
321
- # Update the generate_button click event
322
- generate_button.click(
323
- generate,
324
- inputs=[
325
- prompt,
326
- negative_prompt,
327
- seed,
328
- custom_width,
329
- custom_height,
330
- guidance_scale,
331
- num_inference_steps,
332
- sampler,
333
- aspect_ratio_selector,
334
- use_upscaler,
335
- upscaler_strength,
336
- upscale_by,
337
- ],
338
- outputs=[output_image, output_text, history],
339
- concurrency_limit=1
340
  )
341
 
342
- # Add a function to update the history display
343
- def update_history(history_data):
344
- html = "<div class='history-container'>"
345
- for item in history_data:
346
- html += f"""
347
- <div class='history-item'>
348
- <img src='{item['thumbnail']}' alt='Generated Image'>
349
- <div class='history-info'>
350
- <p><strong>Prompt:</strong> {item['prompt']}</p>
351
- <p><strong>Negative Prompt:</strong> {item['negative_prompt']}</p>
352
- <p><strong>Seed:</strong> {item['seed']}</p>
353
- <p><strong>Size:</strong> {item['width']}x{item['height']}</p>
354
- </div>
355
- </div>
356
- """
357
- html += "</div>"
358
- return html
359
-
360
- # Connect the update_history function to the history component
361
- history.change(update_history, inputs=[history], outputs=[history])
362
-
363
- demo.launch(debug=True, max_threads=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from datetime import datetime
13
  from diffusers.models import AutoencoderKL
14
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
 
 
 
15
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
  DESCRIPTION = "PonyDiffusion V6 XL"
20
  if not torch.cuda.is_available():
21
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
 
22
  IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
23
  HF_TOKEN = os.getenv("HF_TOKEN")
24
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
 
27
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
28
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
29
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
30
+
31
  MODEL = os.getenv(
32
  "MODEL",
33
  "https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
 
38
 
39
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
 
 
 
41
 
42
  def load_pipeline(model_name):
43
  vae = AutoencoderKL.from_pretrained(
 
49
  if MODEL.endswith(".safetensors")
50
  else StableDiffusionXLPipeline.from_pretrained
51
  )
52
+
53
  pipe = pipeline(
54
  model_name,
55
  vae=vae,
 
60
  use_auth_token=HF_TOKEN,
61
  variant="fp16",
62
  )
63
+
64
  pipe.to(device)
65
  return pipe
66
 
67
+
68
  @spaces.GPU
69
  def generate(
70
  prompt: str,
 
82
  progress=gr.Progress(track_tqdm=True),
83
  ) -> Image:
84
  generator = utils.seed_everything(seed)
85
+
86
  width, height = utils.aspect_ratio_handler(
87
+ aspect_ratio_selector,
88
+ custom_width,
89
+ custom_height,
90
  )
91
+
92
  width, height = utils.preprocess_image_dimensions(width, height)
93
+
94
  backup_scheduler = pipe.scheduler
95
  pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
96
 
97
  if use_upscaler:
98
  upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
 
99
  metadata = {
100
  "prompt": prompt,
101
  "negative_prompt": negative_prompt,
 
117
  }
118
  else:
119
  metadata["use_upscaler"] = None
 
120
  logger.info(json.dumps(metadata, indent=4))
121
 
122
  try:
 
154
  output_type="pil",
155
  ).images
156
 
157
+ if images and IS_COLAB:
158
+ for image in images:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  filepath = utils.save_image(image, metadata, OUTPUT_DIR)
160
  logger.info(f"Image saved as {filepath} with metadata")
161
 
162
+ return images, metadata
 
163
  except Exception as e:
164
  logger.exception(f"An error occurred: {e}")
165
  raise
 
169
  pipe.scheduler = backup_scheduler
170
  utils.free_memory()
171
 
172
+
173
  if torch.cuda.is_available():
174
  pipe = load_pipeline(MODEL)
175
  logger.info("Loaded on Device!")
 
178
 
179
  with gr.Blocks(css="style.css") as demo:
180
  title = gr.HTML(
181
+ f"""<h1><span>{DESCRIPTION}</span></h1>""",
182
+ elem_id="title",
183
  )
184
+ gr.Markdown(
185
+ f"""Gradio demo for ([Pony Diffusion V6]https://civitai.com/models/257749/pony-diffusion-v6-xl/)""",
186
+ elem_id="subtitle",
187
+ )
188
+ gr.DuplicateButton(
189
+ value="Duplicate Space for private use",
190
+ elem_id="duplicate-button",
191
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
192
+ )
193
+ with gr.Group():
194
+ with gr.Row():
195
+ prompt = gr.Text(
196
  label="Prompt",
197
  show_label=False,
198
+ max_lines=5,
199
  placeholder="Enter your prompt",
200
+ container=False,
201
  )
202
+ run_button = gr.Button(
203
+ "Generate",
204
+ variant="primary",
205
+ scale=0
 
206
  )
207
+ result = gr.Gallery(
208
+ label="Result",
209
+ columns=1,
210
+ preview=True,
211
+ show_label=False
212
+ )
213
+ with gr.Accordion(label="Advanced Settings", open=False):
214
+ negative_prompt = gr.Text(
215
+ label="Negative Prompt",
216
+ max_lines=5,
217
+ placeholder="Enter a negative prompt",
218
+ value=""
219
+ )
220
+ aspect_ratio_selector = gr.Radio(
221
+ label="Aspect Ratio",
222
+ choices=config.aspect_ratios,
223
+ value="1024 x 1024",
224
+ container=True,
225
+ )
226
+ with gr.Group(visible=False) as custom_resolution:
227
  with gr.Row():
228
  custom_width = gr.Slider(
229
  label="Width",
 
239
  step=8,
240
  value=1024,
241
  )
242
+ use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
243
+ with gr.Row() as upscaler_row:
244
+ upscaler_strength = gr.Slider(
245
+ label="Strength",
246
+ minimum=0,
247
+ maximum=1,
248
+ step=0.05,
249
+ value=0.55,
250
+ visible=False,
251
+ )
252
+ upscale_by = gr.Slider(
253
+ label="Upscale by",
254
+ minimum=1,
255
+ maximum=1.5,
256
+ step=0.1,
257
+ value=1.5,
258
+ visible=False,
259
+ )
260
 
261
+ sampler = gr.Dropdown(
262
+ label="Sampler",
263
+ choices=config.sampler_list,
264
+ interactive=True,
265
+ value="DPM++ 2M SDE Karras",
266
+ )
267
+ with gr.Row():
268
+ seed = gr.Slider(
269
+ label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
270
+ )
271
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
272
+ with gr.Group():
273
  with gr.Row():
274
  guidance_scale = gr.Slider(
275
+ label="Guidance scale",
276
+ minimum=1,
277
+ maximum=12,
278
+ step=0.1,
279
+ value=7.0,
280
  )
281
  num_inference_steps = gr.Slider(
282
+ label="Number of inference steps",
283
  minimum=1,
284
+ maximum=50,
285
  step=1,
286
+ value=28,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  )
288
+ with gr.Accordion(label="Generation Parameters", open=False):
289
+ gr_metadata = gr.JSON(label="Metadata", show_label=False)
290
+ gr.Examples(
291
+ examples=config.examples,
292
+ inputs=prompt,
293
+ outputs=[result, gr_metadata],
294
+ fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
295
+ cache_examples=CACHE_EXAMPLES,
296
+ )
297
+ use_upscaler.change(
298
+ fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
299
+ inputs=use_upscaler,
300
+ outputs=[upscaler_strength, upscale_by],
301
+ queue=False,
302
+ api_name=False,
303
+ )
304
+ aspect_ratio_selector.change(
305
+ fn=lambda x: gr.update(visible=x == "Custom"),
306
+ inputs=aspect_ratio_selector,
307
+ outputs=custom_resolution,
308
+ queue=False,
309
+ api_name=False,
 
 
 
 
 
 
 
 
310
  )
311
 
312
+ inputs = [
313
+ prompt,
314
+ negative_prompt,
315
+ seed,
316
+ custom_width,
317
+ custom_height,
318
+ guidance_scale,
319
+ num_inference_steps,
320
+ sampler,
321
+ aspect_ratio_selector,
322
+ use_upscaler,
323
+ upscaler_strength,
324
+ upscale_by,
325
+ ]
326
+
327
+ prompt.submit(
328
+ fn=utils.randomize_seed_fn,
329
+ inputs=[seed, randomize_seed],
330
+ outputs=seed,
331
+ queue=False,
332
+ api_name=False,
333
+ ).then(
334
+ fn=generate,
335
+ inputs=inputs,
336
+ outputs=result,
337
+ api_name="run",
338
+ )
339
+ negative_prompt.submit(
340
+ fn=utils.randomize_seed_fn,
341
+ inputs=[seed, randomize_seed],
342
+ outputs=seed,
343
+ queue=False,
344
+ api_name=False,
345
+ ).then(
346
+ fn=generate,
347
+ inputs=inputs,
348
+ outputs=result,
349
+ api_name=False,
350
+ )
351
+ run_button.click(
352
+ fn=utils.randomize_seed_fn,
353
+ inputs=[seed, randomize_seed],
354
+ outputs=seed,
355
+ queue=False,
356
+ api_name=False,
357
+ ).then(
358
+ fn=generate,
359
+ inputs=inputs,
360
+ outputs=[result, gr_metadata],
361
+ api_name=False,
362
+ )
363
+ demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)