rockeycoss committed on
Commit 1e50ca9
1 Parent(s): 0ab1c76
Files changed (3)
  1. README.md +1 -1
  2. app.py +39 -35
  3. requirements.txt +1 -1
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🖼️🖌️
 colorFrom: yellow
 colorTo: yellow
 sdk: gradio
-sdk_version: 4.27.0
+sdk_version: 4.31.1
 app_file: app.py
 pinned: false
 ---
app.py CHANGED
@@ -1,9 +1,10 @@
-
+import gc
 import json
 import webcolors
 import spaces
 import gradio as gr
 import os.path as osp
+from copy import deepcopy
 from PIL import Image, ImageDraw, ImageFont

 import torch
@@ -64,6 +65,10 @@ font = ImageFont.truetype("assets/Arial.ttf", 20)

 device = "cuda"

+def flush():
+    gc.collect()
+    torch.cuda.empty_cache()
+
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder",
 ):
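The new `flush()` helper pairs Python garbage collection with a release of PyTorch's CUDA allocator cache. A minimal standalone sketch of the same pattern (the docstring and comments are mine, not part of the commit):

```python
import gc

import torch


def flush():
    """Release unreferenced Python objects, then empty PyTorch's CUDA cache."""
    gc.collect()              # drop unreachable Python objects first
    torch.cuda.empty_cache()  # return cached, unused GPU blocks to the driver


# The commit calls flush() at the end of generate_image, so memory cached for
# one request is released before the next request starts.
flush()  # safe to call even when no CUDA device is present
```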
@@ -215,6 +220,18 @@ pipeline.scheduler = DPMSolverMultistepScheduler.from_pretrained(

 prompt_format = PromptFormat()

+# move to gpu
+if config.pretrained_vae_model_name_or_path is None:
+    vae = vae.to(device, dtype=torch.float32)
+else:
+    vae = vae.to(device, dtype=inference_dtype)
+text_encoder_one = text_encoder_one.to(device, dtype=inference_dtype)
+text_encoder_two = text_encoder_two.to(device, dtype=inference_dtype)
+byt5_model = byt5_model.to(device)
+unet = unet.to(device, dtype=inference_dtype)
+pipeline = pipeline.to(device)
+
+
 def get_pixels(
     box_sketch_template,
     evt: gr.SelectData
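This hunk moves all device transfers to module level, so the weights land on the GPU once at startup rather than on every call to `generate_image`. A rough sketch of the same startup pattern, with tiny `nn.Module` stand-ins in place of the Space's real VAE, UNet and text encoders:

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
inference_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Tiny stand-ins for the Space's real components (vae, unet, text encoders).
vae = nn.Conv2d(3, 4, 1)
unet = nn.Conv2d(4, 4, 1)

# Do the transfer once at import/startup time, as the hunk above does,
# instead of repeating it inside every request handler.
vae = vae.to(device, dtype=torch.float32)      # the diff keeps the VAE in fp32 by default
unet = unet.to(device, dtype=inference_dtype)  # other components use the inference dtype

@torch.inference_mode()
def handler(x: torch.Tensor) -> torch.Tensor:
    # Handlers can now assume the weights are already resident on the device.
    latents = vae(x.to(device, dtype=torch.float32))
    return unet(latents.to(inference_dtype))

print(handler(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 4, 8, 8])
```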
@@ -318,8 +335,6 @@ def exe_undo(
     return box_sketch_template

 def process_box():
-    global stack
-    global state

     visibilities = []
     for _ in range(MAX_TEXT_BOX + 1):
@@ -330,31 +345,19 @@ def process_box():
     # return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
     return [gr.update(visible=True), *visibilities]

-@spaces.GPU
+@torch.inference_mode()
+@spaces.GPU(enable_queue=True)
 def generate_image(bg_prompt, bg_class, bg_tags, seed, *conditions):
-    print(conditions)
-
-    # 0 load model to cuda
-    global pipeline
-    if config.pretrained_vae_model_name_or_path is None:
-        vae.to(device, dtype=torch.float32)
-    else:
-        vae.to(device, dtype=inference_dtype)
-    text_encoder_one.to(device, dtype=inference_dtype)
-    text_encoder_two.to(device, dtype=inference_dtype)
-    byt5_model.to(device)
-    unet.to(device, dtype=inference_dtype)
-    pipeline = pipeline.to(device)

+    stack_cp = deepcopy(stack)
+    print(f"conditions: {conditions}")
+
     # 1. parse input
-    global state
-    global stack
-
     prompts = []
     colors = []
     font_type = []
     bboxes = []
-    num_boxes = len(stack) if len(stack[-1]) == 4 else len(stack) - 1
+    num_boxes = len(stack_cp) if len(stack_cp[-1]) == 4 else len(stack_cp) - 1
     for i in range(num_boxes):
         prompts.append(conditions[i])
         colors.append(conditions[i + MAX_TEXT_BOX])
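Inside `generate_image`, the `global stack` / `global state` references are replaced by a `deepcopy` of the click stack, and `num_boxes` drops a trailing box that only has a start corner. A small self-contained sketch of that logic, using a made-up `stack` value purely for illustration:

```python
from copy import deepcopy

# Module-level click state shaped like the Space's: a finished box is
# [x0, y0, x1, y1]; a half-drawn one only has its first corner [x0, y0].
stack = [[10, 20, 200, 80], [30, 300]]

def snapshot_boxes():
    # Copy first so a request never mutates the shared list while another runs.
    stack_cp = deepcopy(stack)
    # Same rule as the diff's num_boxes: drop a trailing, incomplete box.
    num_boxes = len(stack_cp) if len(stack_cp[-1]) == 4 else len(stack_cp) - 1
    # Normalize to [x, y, w, h] in 0..1, as the bboxes loop does for a 1024px canvas.
    return [
        [b[0] / 1024, b[1] / 1024, (b[2] - b[0]) / 1024, (b[3] - b[1]) / 1024]
        for b in stack_cp[:num_boxes]
    ]

print(snapshot_boxes())  # [[0.009765625, 0.01953125, 0.185546875, 0.05859375]]
```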
@@ -373,10 +376,10 @@ def generate_image(bg_prompt, bg_class, bg_tags, seed, *conditions):
             raise gr.Error(f"Invalid style for text box {i + 1} !")
         bboxes.append(
             [
-                stack[i][0] / 1024,
-                stack[i][1] / 1024,
-                (stack[i][2] - stack[i][0]) / 1024,
-                (stack[i][3] - stack[i][1]) / 1024,
+                stack_cp[i][0] / 1024,
+                stack_cp[i][1] / 1024,
+                (stack_cp[i][2] - stack_cp[i][0]) / 1024,
+                (stack_cp[i][3] - stack_cp[i][1]) / 1024,
             ]
         )
         styles.append(
@@ -393,14 +396,11 @@ def generate_image(bg_prompt, bg_class, bg_tags, seed, *conditions):
         bg_prompt += " Tags: " + bg_tags
     text_prompt = prompt_format.format_prompt(prompts, styles)

-    print(bg_prompt)
-    print(text_prompt)
+    print(f"bg_prompt: {bg_prompt}")
+    print(f"text_prompt: {text_prompt}")

     # 4. inference
-    if seed == -1:
-        generator = torch.Generator(device=device)
-    else:
-        generator = torch.Generator(device=device).manual_seed(seed)
+    generator = torch.Generator(device=device).manual_seed(int(seed))
     with torch.cuda.amp.autocast():
         image = pipeline(
             prompt=bg_prompt,
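The seed handling is simplified: instead of branching on `seed == -1`, the handler always builds a seeded `torch.Generator`, matching the new non-negative seed slider. A short sketch of that pattern, assuming a CPU fallback for machines without CUDA:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def make_generator(seed: int) -> torch.Generator:
    # Always seed explicitly, so the same slider value reproduces the same image.
    return torch.Generator(device=device).manual_seed(int(seed))

g = make_generator(42)
noise_a = torch.randn(1, 4, 8, 8, generator=g, device=device)
noise_b = torch.randn(1, 4, 8, 8, generator=make_generator(42), device=device)
print(torch.equal(noise_a, noise_b))  # True: identical seeds give identical noise
```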
@@ -411,6 +411,9 @@ def generate_image(bg_prompt, bg_class, bg_tags, seed, *conditions):
             generator=generator,
             text_attn_mask=None,
         ).images[0]
+
+    flush()
+
     return image

 def process_example(bg_prompt, bg_class, bg_tags, color_str, style_str, text_str, box_str, seed):
@@ -534,10 +537,10 @@ def main():
                 choices=font_idx_list,
             ))

-            seed_ = gr.Slider(label="Seed", minimum=-1, maximum=999999999, value=-1, step=1)
-            button_generate = gr.Button("(2) I've finished my texts, colors and styles, generate!", elem_id="main_button", interactive=True)
+            seed_ = gr.Slider(label="Seed", minimum=0, maximum=2147483647, value=42, step=1)
+            button_generate = gr.Button("(2) I've finished my texts, colors and styles, generate!", elem_id="main_button", interactive=True, variant='primary')

-            button_layout.click(process_box, inputs=[], outputs=[post_box, *color_row], queue=False)
+            button_layout.click(process_box, inputs=[], outputs=[post_box, *color_row])

         with gr.Column():
             output_image = gr.Image(label="Output Image", interactive=False)
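The seed slider now spans 0 to 2147483647 with a default of 42, and the layout button's click handler no longer passes `queue=False`. A stripped-down sketch of that wiring with a dummy handler; component names loosely mirror the app, everything else is illustrative:

```python
import gradio as gr

def process_box():
    # Dummy stand-in for the Space's handler: just reveal the hidden column.
    return gr.update(visible=True)

with gr.Blocks() as demo:
    button_layout = gr.Button("(1) I've finished drawing boxes")
    with gr.Column(visible=False) as post_box:
        gr.Markdown("Per-box text, color and style controls go here.")
    seed_ = gr.Slider(label="Seed", minimum=0, maximum=2147483647, value=42, step=1)
    # No queue=False here: the event goes through Gradio's default queue,
    # which GPU-decorated Spaces rely on.
    button_layout.click(process_box, inputs=[], outputs=[post_box])

if __name__ == "__main__":
    demo.launch()
```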
@@ -570,7 +573,7 @@ def main():
                 'LilitaOne, Sensei-Medium, Sensei-Medium, LilitaOne, LilitaOne, LilitaOne',
                 "RSVP to +123-456-7890**********Olivia Wilson**********Baby Shower**********Please Join Us For a**********In Honoring**********23 November, 2021 | 03:00 PM Fauget Hotels",
                 '[0.07112462006079028, 0.6462006079027356, 0.3373860182370821, 0.026747720364741642]; [0.07051671732522796, 0.38662613981762917, 0.37264437689969604, 0.059574468085106386]; [0.07234042553191489, 0.15623100303951368, 0.6547112462006079, 0.12401215805471125]; [0.0662613981762918, 0.06747720364741641, 0.3981762917933131, 0.035866261398176294]; [0.07051671732522796, 0.31550151975683893, 0.22006079027355624, 0.03951367781155015]; [0.06990881458966565, 0.48328267477203646, 0.39878419452887537, 0.1094224924012158]',
-                0,
+                1,
             ],
             [
                 'The image features a white background with a variety of colorful flowers and decorations. There are several pink flowers scattered throughout the scene, with some positioned closer to the top and others near the bottom. A blue flower can also be seen in the middle of the image. The overall composition creates a visually appealing and vibrant display.',
@@ -605,6 +608,7 @@ def main():
             ],
             outputs=[post_box, box_sketch_template, seed_, *color_row, *colors, *styles, *prompts],
             fn=process_example,
+            cache_examples=False,
             run_on_click=True,
             label='Examples',
         )
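`cache_examples=False` keeps Gradio from pre-computing example outputs when the Space builds, which would otherwise run the example function at startup; it is also the combination `run_on_click=True` expects. A minimal `gr.Examples` sketch with a dummy `fn`, not the Space's real `process_example`:

```python
import gradio as gr

def process_example(prompt):
    # Dummy stand-in for the Space's process_example.
    return prompt.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Background prompt")
    out = gr.Textbox(label="Parsed prompt")
    gr.Examples(
        examples=[["a poster with a yellow background"]],
        inputs=[inp],
        outputs=[out],
        fn=process_example,
        cache_examples=False,  # do not pre-compute outputs when the app builds
        run_on_click=True,     # run fn when an example row is clicked
        label="Examples",
    )
```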
 
requirements.txt CHANGED
@@ -7,4 +7,4 @@ torchvision==0.17.0
 deepspeed
 peft
 webcolors
-gradio
+gradio==4.31.1