r3gm commited on
Commit
481fe70
·
verified ·
1 Parent(s): 1a95355

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -27
app.py CHANGED
@@ -21,6 +21,7 @@ from stablepy import (
21
  SD15_TASKS,
22
  SDXL_TASKS,
23
  )
 
24
  # import urllib.parse
25
 
26
  # - **Download SD 1.5 Models**
@@ -88,6 +89,11 @@ load_diffusers_format_model = [
88
  'GraydientPlatformAPI/realcartoon-real17',
89
  ]
90
 
 
 
 
 
 
91
  CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
92
  HF_TOKEN = os.environ.get("HF_READ_TOKEN")
93
 
@@ -302,6 +308,7 @@ model_list = get_model_list(directory_models)
302
  model_list = load_diffusers_format_model + model_list
303
  lora_model_list = get_model_list(directory_loras)
304
  lora_model_list.insert(0, "None")
 
305
  vae_model_list = get_model_list(directory_vaes)
306
  vae_model_list.insert(0, "None")
307
 
@@ -404,6 +411,7 @@ def get_my_lora(link_url):
404
  download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
405
  new_lora_model_list = get_model_list(directory_loras)
406
  new_lora_model_list.insert(0, "None")
 
407
 
408
  return gr.update(
409
  choices=new_lora_model_list
@@ -467,7 +475,7 @@ class GuiSD:
467
  if vae_model:
468
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
469
  if model_type != vae_type:
470
- gr.Info(msg_inc_vae)
471
 
472
  self.model.device = torch.device("cpu")
473
  dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
@@ -482,7 +490,7 @@ class GuiSD:
482
 
483
  yield f"Model loaded: {model_name}"
484
 
485
- @spaces.GPU(duration=59)
486
  @torch.inference_mode()
487
  def generate_pipeline(
488
  self,
@@ -593,7 +601,7 @@ class GuiSD:
593
  vae_model = vae_model if vae_model != "None" else None
594
  loras_list = [lora1, lora2, lora3, lora4, lora5]
595
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
596
- msg_lora = []
597
 
598
  print("Config model:", model_name, vae_model, loras_list)
599
 
@@ -756,11 +764,18 @@ class GuiSD:
756
  for img, seed, image_path, metadata in self.model(**pipe_params):
757
  info_state += ">"
758
  if image_path:
759
- info_state = f"COMPLETED. Seeds: {str(seed)}"
760
  if vae_msg:
761
  info_state = info_state + "<br>" + vae_msg
 
 
 
 
 
 
 
762
  if msg_lora:
763
- info_state = info_state + "<br>" + "<br>".join(msg_lora)
764
 
765
  info_state = info_state + "<br>" + "GENERATION DATA:<br>" + "<br>-------<br>".join(metadata).replace("\n", "<br>")
766
 
@@ -785,32 +800,90 @@ def update_task_options(model_name, task_name):
785
  return gr.update(value=task_name, choices=new_choices)
786
 
787
 
788
- # def sd_gen_generate_pipeline(*args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789
 
790
- # # Load lora in CPU
791
- # status_lora = sd_gen.model.lora_merge(
792
- # lora_A=args[7] if args[7] != "None" else None, lora_scale_A=args[8],
793
- # lora_B=args[9] if args[9] != "None" else None, lora_scale_B=args[10],
794
- # lora_C=args[11] if args[11] != "None" else None, lora_scale_C=args[12],
795
- # lora_D=args[13] if args[13] != "None" else None, lora_scale_D=args[14],
796
- # lora_E=args[15] if args[15] != "None" else None, lora_scale_E=args[16],
797
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798
 
799
- # lora_list = [args[7], args[9], args[11], args[13], args[15]]
800
- # print(status_lora)
801
- # for status, lora in zip(status_lora, lora_list):
802
- # if status:
803
- # gr.Info(f"LoRA loaded: {lora}")
804
- # elif status is not None:
805
- # gr.Warning(f"Failed to load LoRA: {lora}")
806
 
807
- # # if status_lora == [None] * 5 and self.model.lora_memory != [None] * 5:
808
- # # gr.Info(f"LoRAs in cache: {", ".join(str(x) for x in self.model.lora_memory if x is not None)}")
 
 
 
 
 
809
 
810
- # yield from sd_gen.generate_pipeline(*args)
 
 
 
 
 
 
811
 
812
 
813
- # sd_gen_generate_pipeline.zerogpu = True
 
814
  sd_gen = GuiSD()
815
 
816
  with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
@@ -854,6 +927,12 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
854
 
855
  actual_task_info = gr.HTML()
856
 
 
 
 
 
 
 
857
  with gr.Column(scale=1):
858
  steps_gui = gr.Slider(minimum=1, maximum=100, step=1, value=30, label="Steps")
859
  cfg_gui = gr.Slider(minimum=0, maximum=30, step=0.5, value=7.5, label="CFG")
@@ -1172,6 +1251,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1172
  1.0, # cn end
1173
  "Classic",
1174
  "Nearest",
 
 
1175
  ],
1176
  [
1177
  "a digital illustration of a movie poster titled 'Finding Emo', finding nemo parody poster, featuring a depressed cartoon clownfish with black emo hair, eyeliner, and piercings, bored expression, swimming in a dark underwater scene, in the background, movie title in a dripping, grungy font, moody blue and purple color palette",
@@ -1194,6 +1275,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1194
  1.0, # cn end
1195
  "Classic",
1196
  None,
 
 
1197
  ],
1198
  [
1199
  "((masterpiece)), best quality, blonde disco girl, detailed face, realistic face, realistic hair, dynamic pose, pink pvc, intergalactic disco background, pastel lights, dynamic contrast, airbrush, fine detail, 70s vibe, midriff",
@@ -1216,6 +1299,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1216
  1.0, # cn end
1217
  "Classic",
1218
  None,
 
 
1219
  ],
1220
  [
1221
  "cinematic scenery old city ruins",
@@ -1238,6 +1323,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1238
  0.75, # cn end
1239
  "Classic",
1240
  None,
 
 
1241
  ],
1242
  [
1243
  "black and white, line art, coloring drawing, clean line art, black strokes, no background, white, black, free lines, black scribbles, on paper, A blend of comic book art and lineart full of black and white color, masterpiece, high-resolution, trending on Pixiv fan box, palette knife, brush strokes, two-dimensional, planar vector, T-shirt design, stickers, and T-shirt design, vector art, fantasy art, Adobe Illustrator, hand-painted, digital painting, low polygon, soft lighting, aerial view, isometric style, retro aesthetics, 8K resolution, black sketch lines, monochrome, invert color",
@@ -1260,6 +1347,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1260
  1.0, # cn end
1261
  "Compel",
1262
  None,
 
 
1263
  ],
1264
  [
1265
  "1girl,face,curly hair,red hair,white background,",
@@ -1282,6 +1371,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1282
  0.9, # cn end
1283
  "Compel",
1284
  "Latent (antialiased)",
 
 
1285
  ],
1286
  ],
1287
  fn=sd_gen.generate_pipeline,
@@ -1306,6 +1397,8 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1306
  control_net_stop_threshold_gui,
1307
  prompt_syntax_gui,
1308
  upscaler_model_path_gui,
 
 
1309
  ],
1310
  outputs=[result_images, actual_task_info],
1311
  cache_examples=False,
@@ -1385,7 +1478,7 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1385
  queue=True,
1386
  show_progress="minimal",
1387
  ).success(
1388
- fn=sd_gen.generate_pipeline, # fn=sd_gen_generate_pipeline,
1389
  inputs=[
1390
  prompt_gui,
1391
  neg_prompt_gui,
@@ -1489,6 +1582,9 @@ with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
1489
  mode_ip2,
1490
  scale_ip2,
1491
  pag_scale_gui,
 
 
 
1492
  ],
1493
  outputs=[result_images, actual_task_info],
1494
  queue=True,
@@ -1501,4 +1597,4 @@ app.launch(
1501
  show_error=True,
1502
  debug=True,
1503
  allowed_paths=["./images/"],
1504
- )
 
21
  SD15_TASKS,
22
  SDXL_TASKS,
23
  )
24
+ import time
25
  # import urllib.parse
26
 
27
  # - **Download SD 1.5 Models**
 
89
  'GraydientPlatformAPI/realcartoon-real17',
90
  ]
91
 
92
+ DIFFUSERS_FORMAT_LORAS = [
93
+ "nerijs/animation2k-flux",
94
+ "XLabs-AI/flux-RealismLora",
95
+ ]
96
+
97
  CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
98
  HF_TOKEN = os.environ.get("HF_READ_TOKEN")
99
 
 
308
  model_list = load_diffusers_format_model + model_list
309
  lora_model_list = get_model_list(directory_loras)
310
  lora_model_list.insert(0, "None")
311
+ lora_model_list = lora_model_list + DIFFUSERS_FORMAT_LORAS
312
  vae_model_list = get_model_list(directory_vaes)
313
  vae_model_list.insert(0, "None")
314
 
 
411
  download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
412
  new_lora_model_list = get_model_list(directory_loras)
413
  new_lora_model_list.insert(0, "None")
414
+ new_lora_model_list = new_lora_model_list + DIFFUSERS_FORMAT_LORAS
415
 
416
  return gr.update(
417
  choices=new_lora_model_list
 
475
  if vae_model:
476
  vae_type = "SDXL" if "sdxl" in vae_model.lower() else "SD 1.5"
477
  if model_type != vae_type:
478
+ gr.Warning(msg_inc_vae)
479
 
480
  self.model.device = torch.device("cpu")
481
  dtype_model = torch.bfloat16 if model_type == "FLUX" else torch.float16
 
490
 
491
  yield f"Model loaded: {model_name}"
492
 
493
+ # @spaces.GPU(duration=59)
494
  @torch.inference_mode()
495
  def generate_pipeline(
496
  self,
 
601
  vae_model = vae_model if vae_model != "None" else None
602
  loras_list = [lora1, lora2, lora3, lora4, lora5]
603
  vae_msg = f"VAE: {vae_model}" if vae_model else ""
604
+ msg_lora = ""
605
 
606
  print("Config model:", model_name, vae_model, loras_list)
607
 
 
764
  for img, seed, image_path, metadata in self.model(**pipe_params):
765
  info_state += ">"
766
  if image_path:
767
+ info_state = f"COMPLETE. Seeds: {str(seed)}"
768
  if vae_msg:
769
  info_state = info_state + "<br>" + vae_msg
770
+
771
+ for status, lora in zip(self.model.lora_status, self.model.lora_memory):
772
+ if status:
773
+ msg_lora += f"<br>Loaded: {lora}"
774
+ elif status is not None:
775
+ msg_lora += f"<br>Error with: {lora}"
776
+
777
  if msg_lora:
778
+ info_state += msg_lora
779
 
780
  info_state = info_state + "<br>" + "GENERATION DATA:<br>" + "<br>-------<br>".join(metadata).replace("\n", "<br>")
781
 
 
800
  return gr.update(value=task_name, choices=new_choices)
801
 
802
 
803
+ def dynamic_gpu_duration(func, duration, *args):
804
+
805
+ @spaces.GPU(duration=duration)
806
+ def wrapped_func():
807
+ yield from func(*args)
808
+
809
+ return wrapped_func()
810
+
811
+
812
+ @spaces.GPU
813
+ def dummy_gpu():
814
+ return None
815
+
816
+
817
+ def sd_gen_generate_pipeline(*args):
818
 
819
+ gpu_duration_arg = int(args[-1]) if args[-1] else 59
820
+ verbose_arg = int(args[-2])
821
+ load_lora_cpu = args[-3]
822
+ generation_args = args[:-3]
823
+ lora_list = [
824
+ None if item == "None" else item
825
+ for item in [args[7], args[9], args[11], args[13], args[15]]
826
+ ]
827
+ lora_status = [None] * 5
828
+
829
+ msg_load_lora = "Updating LoRAs in GPU..."
830
+ if load_lora_cpu:
831
+ msg_load_lora = "Updating LoRAs in CPU (Slow but saves GPU usage)..."
832
+
833
+ if lora_list != sd_gen.model.lora_memory and lora_list != [None] * 5:
834
+ yield None, msg_load_lora
835
+
836
+ # Load lora in CPU
837
+ if load_lora_cpu:
838
+ lora_status = sd_gen.model.lora_merge(
839
+ lora_A=lora_list[0], lora_scale_A=args[8],
840
+ lora_B=lora_list[1], lora_scale_B=args[10],
841
+ lora_C=lora_list[2], lora_scale_C=args[12],
842
+ lora_D=lora_list[3], lora_scale_D=args[14],
843
+ lora_E=lora_list[4], lora_scale_E=args[16],
844
+ )
845
+ print(lora_status)
846
+
847
+ if verbose_arg:
848
+ for status, lora in zip(lora_status, lora_list):
849
+ if status:
850
+ gr.Info(f"LoRA loaded in CPU: {lora}")
851
+ elif status is not None:
852
+ gr.Warning(f"Failed to load LoRA: {lora}")
853
+
854
+ if lora_status == [None] * 5 and sd_gen.model.lora_memory != [None] * 5 and load_lora_cpu:
855
+ lora_cache_msg = ", ".join(
856
+ str(x) for x in sd_gen.model.lora_memory if x is not None
857
+ )
858
+ gr.Info(f"LoRAs in cache: {lora_cache_msg}")
859
+
860
+ msg_request = f"Requesting {gpu_duration_arg}s. of GPU time"
861
+ gr.Info(msg_request)
862
+ print(msg_request)
863
+
864
+ # yield from sd_gen.generate_pipeline(*generation_args)
865
 
866
+ start_time = time.time()
 
 
 
 
 
 
867
 
868
+ yield from dynamic_gpu_duration(
869
+ sd_gen.generate_pipeline,
870
+ gpu_duration_arg,
871
+ *generation_args,
872
+ )
873
+
874
+ end_time = time.time()
875
 
876
+ if verbose_arg:
877
+ execution_time = end_time - start_time
878
+ msg_task_complete = (
879
+ f"GPU task complete in: {round(execution_time, 0) + 1} seconds"
880
+ )
881
+ gr.Info(msg_task_complete)
882
+ print(msg_task_complete)
883
 
884
 
885
+ dynamic_gpu_duration.zerogpu = True
886
+ sd_gen_generate_pipeline.zerogpu = True
887
  sd_gen = GuiSD()
888
 
889
  with gr.Blocks(theme="NoCrypt/miku", css=CSS) as app:
 
927
 
928
  actual_task_info = gr.HTML()
929
 
930
+ with gr.Row(equal_height=False, variant="default"):
931
+ gpu_duration_gui = gr.Number(minimum=5, maximum=240, value=59, show_label=False, container=False, info="GPU time duration (seconds)")
932
+ with gr.Column():
933
+ verbose_info_gui = gr.Checkbox(value=False, container=False, label="Status info")
934
+ load_lora_cpu_gui = gr.Checkbox(value=False, container=False, label="Load LoRAs on CPU (Save GPU time)")
935
+
936
  with gr.Column(scale=1):
937
  steps_gui = gr.Slider(minimum=1, maximum=100, step=1, value=30, label="Steps")
938
  cfg_gui = gr.Slider(minimum=0, maximum=30, step=0.5, value=7.5, label="CFG")
 
1251
  1.0, # cn end
1252
  "Classic",
1253
  "Nearest",
1254
+ 45,
1255
+ False,
1256
  ],
1257
  [
1258
  "a digital illustration of a movie poster titled 'Finding Emo', finding nemo parody poster, featuring a depressed cartoon clownfish with black emo hair, eyeliner, and piercings, bored expression, swimming in a dark underwater scene, in the background, movie title in a dripping, grungy font, moody blue and purple color palette",
 
1275
  1.0, # cn end
1276
  "Classic",
1277
  None,
1278
+ 70,
1279
+ True,
1280
  ],
1281
  [
1282
  "((masterpiece)), best quality, blonde disco girl, detailed face, realistic face, realistic hair, dynamic pose, pink pvc, intergalactic disco background, pastel lights, dynamic contrast, airbrush, fine detail, 70s vibe, midriff",
 
1299
  1.0, # cn end
1300
  "Classic",
1301
  None,
1302
+ 44,
1303
+ False,
1304
  ],
1305
  [
1306
  "cinematic scenery old city ruins",
 
1323
  0.75, # cn end
1324
  "Classic",
1325
  None,
1326
+ 35,
1327
+ False,
1328
  ],
1329
  [
1330
  "black and white, line art, coloring drawing, clean line art, black strokes, no background, white, black, free lines, black scribbles, on paper, A blend of comic book art and lineart full of black and white color, masterpiece, high-resolution, trending on Pixiv fan box, palette knife, brush strokes, two-dimensional, planar vector, T-shirt design, stickers, and T-shirt design, vector art, fantasy art, Adobe Illustrator, hand-painted, digital painting, low polygon, soft lighting, aerial view, isometric style, retro aesthetics, 8K resolution, black sketch lines, monochrome, invert color",
 
1347
  1.0, # cn end
1348
  "Compel",
1349
  None,
1350
+ 35,
1351
+ False,
1352
  ],
1353
  [
1354
  "1girl,face,curly hair,red hair,white background,",
 
1371
  0.9, # cn end
1372
  "Compel",
1373
  "Latent (antialiased)",
1374
+ 46,
1375
+ False,
1376
  ],
1377
  ],
1378
  fn=sd_gen.generate_pipeline,
 
1397
  control_net_stop_threshold_gui,
1398
  prompt_syntax_gui,
1399
  upscaler_model_path_gui,
1400
+ gpu_duration_gui,
1401
+ load_lora_cpu_gui,
1402
  ],
1403
  outputs=[result_images, actual_task_info],
1404
  cache_examples=False,
 
1478
  queue=True,
1479
  show_progress="minimal",
1480
  ).success(
1481
+ fn=sd_gen_generate_pipeline, # fn=sd_gen.generate_pipeline,
1482
  inputs=[
1483
  prompt_gui,
1484
  neg_prompt_gui,
 
1582
  mode_ip2,
1583
  scale_ip2,
1584
  pag_scale_gui,
1585
+ load_lora_cpu_gui,
1586
+ verbose_info_gui,
1587
+ gpu_duration_gui,
1588
  ],
1589
  outputs=[result_images, actual_task_info],
1590
  queue=True,
 
1597
  show_error=True,
1598
  debug=True,
1599
  allowed_paths=["./images/"],
1600
+ )