multimodalart HF staff commited on
Commit
8ae8508
·
verified ·
1 Parent(s): 5cd0360

Add custom

Browse files
Files changed (1) hide show
  1. app.py +165 -27
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  torch.jit.script = lambda f: f
4
  import timm
5
  import time
6
- from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
8
  from share_btn import community_icon_html, loading_icon_html, share_js
9
  from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
@@ -16,6 +16,8 @@ import random
16
  from urllib.parse import quote
17
  import gdown
18
  import os
 
 
19
 
20
  import diffusers
21
  from diffusers.utils import load_image
@@ -155,6 +157,8 @@ button.addEventListener('click', function() {
155
  element.classList.add('selected');
156
  });
157
  '''
 
 
158
  def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
159
  lora_repo = sdxl_loras[selected_state.index]["repo"]
160
  new_placeholder = "Type a prompt to use your selected LoRA"
@@ -197,8 +201,8 @@ def center_crop_image_as_square(img):
197
  img_cropped = img.crop((left, top, right, bottom))
198
  return img_cropped
199
 
200
- def check_selected(selected_state):
201
- if not selected_state:
202
  raise gr.Error("You must select a style")
203
 
204
  def merge_incompatible_lora(full_path_lora, lora_scale):
@@ -224,6 +228,7 @@ def merge_incompatible_lora(full_path_lora, lora_scale):
224
  del lora_model
225
  @spaces.GPU
226
  def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
 
227
  et = time.time()
228
  elapsed_time = et - st
229
  print('Getting into the decorated function took: ', elapsed_time, 'seconds')
@@ -303,8 +308,10 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
303
  last_lora = repo_name
304
  return image
305
 
306
- def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
307
- selected_state_index = selected_state.index
 
 
308
  st = time.time()
309
  face_image = center_crop_image_as_square(face_image)
310
  try:
@@ -319,28 +326,35 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
319
  print('Cropping and calculating face embeds took: ', elapsed_time, 'seconds')
320
 
321
  st = time.time()
322
- for lora_list in lora_defaults:
323
- if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
324
- prompt_full = lora_list.get("prompt", None)
325
- if(prompt_full):
326
- prompt = prompt_full.replace("<subject>", prompt)
327
 
328
-
 
 
 
 
 
 
 
 
329
  print("Prompt:", prompt)
330
  if(prompt == ""):
331
  prompt = "a person"
332
 
333
- print("Selected State: ", selected_state_index)
334
- print(sdxl_loras[selected_state_index]["repo"])
335
  if negative == "":
336
  negative = None
337
-
338
- if not selected_state:
339
- raise gr.Error("You must select a LoRA")
340
- repo_name = sdxl_loras[selected_state_index]["repo"]
341
- weight_name = sdxl_loras[selected_state_index]["weights"]
342
-
343
- full_path_lora = state_dicts[repo_name]["saved_name"]
 
 
 
 
344
  #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
345
  cross_attention_kwargs = None
346
  et = time.time()
@@ -368,6 +382,117 @@ def swap_gallery(order, sdxl_loras):
368
  def deselect():
369
  return gr.Gallery(selected_index=None)
370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  with gr.Blocks(css="custom.css") as demo:
372
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
373
  title = gr.HTML(
@@ -381,8 +506,9 @@ with gr.Blocks(css="custom.css") as demo:
381
  elem_id="title",
382
  )
383
  selected_state = gr.State()
 
384
  with gr.Row(elem_id="main_app"):
385
- with gr.Column(scale=4):
386
  with gr.Group(elem_id="gallery_box"):
387
  photo = gr.Image(label="Upload a picture of yourself", interactive=True, type="pil", height=300)
388
  selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected", )
@@ -394,14 +520,16 @@ with gr.Blocks(css="custom.css") as demo:
394
  # value=[(item["image"], item["title"]) for item in sdxl_loras_raw_new], allow_preview=False, show_share_button=False)
395
  gallery = gr.Gallery(
396
  #value=[(item["image"], item["title"]) for item in sdxl_loras],
397
- label="Style gallery",
398
  allow_preview=False,
399
  columns=4,
400
  elem_id="gallery",
401
  show_share_button=False,
402
  height=550
403
  )
404
- custom_model = gr.Textbox(label="Enter a custom Hugging Face or CivitAI SDXL LoRA", interactive=False, placeholder="Coming soon...")
 
 
405
  with gr.Column(scale=5):
406
  with gr.Row():
407
  prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
@@ -431,6 +559,16 @@ with gr.Blocks(css="custom.css") as demo:
431
  # outputs=[gallery, gr_sdxl_loras],
432
  # queue=False
433
  #)
 
 
 
 
 
 
 
 
 
 
434
  gallery.select(
435
  fn=update_selection,
436
  inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
@@ -447,22 +585,22 @@ with gr.Blocks(css="custom.css") as demo:
447
  #)
448
  prompt.submit(
449
  fn=check_selected,
450
- inputs=[selected_state],
451
  queue=False,
452
  show_progress=False
453
  ).success(
454
  fn=run_lora,
455
- inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras],
456
  outputs=[result, share_group],
457
  )
458
  button.click(
459
  fn=check_selected,
460
- inputs=[selected_state],
461
  queue=False,
462
  show_progress=False
463
  ).success(
464
  fn=run_lora,
465
- inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras],
466
  outputs=[result, share_group],
467
  )
468
  share_button.click(None, [], [], js=share_js)
 
3
  torch.jit.script = lambda f: f
4
  import timm
5
  import time
6
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
7
  from safetensors.torch import load_file
8
  from share_btn import community_icon_html, loading_icon_html, share_js
9
  from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
 
16
  from urllib.parse import quote
17
  import gdown
18
  import os
19
+ import re
20
+ import requests
21
 
22
  import diffusers
23
  from diffusers.utils import load_image
 
157
  element.classList.add('selected');
158
  });
159
  '''
160
+ lora_archive = "/data"
161
+
162
  def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
163
  lora_repo = sdxl_loras[selected_state.index]["repo"]
164
  new_placeholder = "Type a prompt to use your selected LoRA"
 
201
  img_cropped = img.crop((left, top, right, bottom))
202
  return img_cropped
203
 
204
+ def check_selected(selected_state, custom_lora):
205
+ if not selected_state and not custom_lora:
206
  raise gr.Error("You must select a style")
207
 
208
  def merge_incompatible_lora(full_path_lora, lora_scale):
 
228
  del lora_model
229
  @spaces.GPU
230
  def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
231
+ print(loaded_state_dict)
232
  et = time.time()
233
  elapsed_time = et - st
234
  print('Getting into the decorated function took: ', elapsed_time, 'seconds')
 
308
  last_lora = repo_name
309
  return image
310
 
311
+ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, custom_lora, progress=gr.Progress(track_tqdm=True)):
312
+ print("Custom LoRA: ", custom_lora)
313
+ custom_lora_path = custom_lora[0] if custom_lora else None
314
+ selected_state_index = selected_state.index if selected_state else -1
315
  st = time.time()
316
  face_image = center_crop_image_as_square(face_image)
317
  try:
 
326
  print('Cropping and calculating face embeds took: ', elapsed_time, 'seconds')
327
 
328
  st = time.time()
 
 
 
 
 
329
 
330
+ if(custom_lora_path):
331
+ prompt = f"{prompt} {custom_lora[1]}"
332
+ else:
333
+ for lora_list in lora_defaults:
334
+ if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
335
+ prompt_full = lora_list.get("prompt", None)
336
+ if(prompt_full):
337
+ prompt = prompt_full.replace("<subject>", prompt)
338
+
339
  print("Prompt:", prompt)
340
  if(prompt == ""):
341
  prompt = "a person"
342
 
343
+ #print("Selected State: ", selected_state_index)
344
+ #print(sdxl_loras[selected_state_index]["repo"])
345
  if negative == "":
346
  negative = None
347
+ print("Custom Loaded LoRA: ", custom_lora_path)
348
+ if not selected_state and not custom_lora_path:
349
+ raise gr.Error("You must select a style")
350
+ elif custom_lora_path:
351
+ repo_name = custom_lora_path
352
+ full_path_lora = custom_lora_path
353
+ else:
354
+ repo_name = sdxl_loras[selected_state_index]["repo"]
355
+ weight_name = sdxl_loras[selected_state_index]["weights"]
356
+ full_path_lora = state_dicts[repo_name]["saved_name"]
357
+ print("Full path LoRA ", full_path_lora)
358
  #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
359
  cross_attention_kwargs = None
360
  et = time.time()
 
382
  def deselect():
383
  return gr.Gallery(selected_index=None)
384
 
385
+ def get_huggingface_safetensors(link):
386
+ split_link = link.split("/")
387
+ if(len(split_link) == 2):
388
+ model_card = ModelCard.load(link)
389
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
390
+ trigger_word = model_card.data.get("instance_prompt", "")
391
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
392
+ fs = HfFileSystem()
393
+ try:
394
+ list_of_files = fs.ls(link, detail=False)
395
+ for file in list_of_files:
396
+ if(file.endswith(".safetensors")):
397
+ safetensors_name = file.replace("/", "_")
398
+ if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
399
+ fs.get_file(file, lpath=f"{lora_archive}/{safetensors_name}")
400
+ if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
401
+ image_elements = file.split("/")
402
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
403
+ except:
404
+ gr.Warning("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
405
+ raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
406
+ return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
407
+
408
+ def get_civitai_safetensors(link):
409
+ link_split = link.split("civitai.com/")
410
+ pattern = re.compile(r'models\/(\d+)')
411
+ regex_match = pattern.search(link_split[1])
412
+ if(regex_match):
413
+ civitai_model_id = regex_match.group(1)
414
+ else:
415
+ gr.Warning("No CivitAI model id found in your URL")
416
+ raise Exception("No CivitAI model id found in your URL")
417
+ model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}?token={os.getenv('CIVITAI_TOKEN')}"
418
+ x = requests.get(model_request_url)
419
+ if(x.status_code != 200):
420
+ raise Exception("Invalid CivitAI URL")
421
+ model_data = x.json()
422
+ if(model_data["nsfw"] == True or model_data["nsfwLevel"] > 2):
423
+ gr.Warning("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
424
+ raise Exception("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
425
+ elif(model_data["type"] != "LORA"):
426
+ gr.Warning("The model isn't tagged at CivitAI as a LoRA")
427
+ raise Exception("The model isn't tagged at CivitAI as a LoRA")
428
+ model_link_download = None
429
+ image_url = None
430
+ trigger_word = ""
431
+ for model in model_data["modelVersions"]:
432
+ if(model["baseModel"] == "SDXL 1.0"):
433
+ model_link_download = f"{model['downloadUrl']}/?token={os.getenv('CIVITAI_TOKEN')}"
434
+ safetensors_name = model["files"][0]["name"]
435
+ if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
436
+ safetensors_file_request = requests.get(model_link_download)
437
+ if(safetensors_file_request.status_code != 200):
438
+ raise Exception("Invalid CivitAI download link")
439
+ with open(f"{lora_archive}/{safetensors_name}", 'wb') as file:
440
+ file.write(safetensors_file_request.content)
441
+ trigger_word = model.get("trainedWords", [""])[0]
442
+ for image in model["images"]:
443
+ if(image["nsfwLevel"] == 1):
444
+ image_url = image["url"]
445
+ break
446
+ break
447
+ if(not model_link_download):
448
+ gr.Warning("We couldn't find a SDXL LoRA on the model you've sent")
449
+ raise Exception("We couldn't find a SDXL LoRA on the model you've sent")
450
+ return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
451
+
452
+ def check_custom_model(link):
453
+ try:
454
+ if(link.startswith("https://")):
455
+ if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
456
+ link_split = link.split("huggingface.co/")
457
+ return get_huggingface_safetensors(link_split[1])
458
+ elif(link.startswith("https://civitai.com") or link.startswith("https://www.civitai.com")):
459
+ return get_civitai_safetensors(link)
460
+ else:
461
+ return get_huggingface_safetensors(link)
462
+ except Exception as e:
463
+ print("Error: ", e)
464
+ return None, None, None, None
465
+
466
+ def show_loading_widget():
467
+ return gr.update(visible=True)
468
+
469
+ def load_custom_lora(link):
470
+ title, path, trigger_word, image = check_custom_model(link)
471
+ if(title):
472
+ card = f'''
473
+ <div class="custom_lora_card">
474
+ <span>Loaded custom LoRA:</span>
475
+ <div class="card_internal">
476
+ <img src="{image}" />
477
+ <div>
478
+ <h3>{title}</h3>
479
+ <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br>
480
+ If the style isn't applied properly, modify advanced settings, specially <code>Face strength</code> and <code>Image strength</code>
481
+ </small>
482
+ </div>
483
+ </div>
484
+ </div>
485
+ '''
486
+ return gr.update(visible=True), card, gr.update(visible=True), [path, trigger_word], gr.Gallery(selected_index=None), f"Custom: {path}"
487
+ else:
488
+ gr.Warning("Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content")
489
+ return gr.update(visible=True), "Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
490
+ #except Exception as e:
491
+ # gr.Info("Invalid custom LoRA")
492
+ # return gr.update(visible=True), "Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
493
+
494
+ def remove_custom_lora():
495
+ return "", gr.update(visible=False), gr.update(visible=False), None
496
  with gr.Blocks(css="custom.css") as demo:
497
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
498
  title = gr.HTML(
 
506
  elem_id="title",
507
  )
508
  selected_state = gr.State()
509
+ custom_loaded_lora = gr.State()
510
  with gr.Row(elem_id="main_app"):
511
+ with gr.Column(scale=4, elem_id="box_column"):
512
  with gr.Group(elem_id="gallery_box"):
513
  photo = gr.Image(label="Upload a picture of yourself", interactive=True, type="pil", height=300)
514
  selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected", )
 
520
  # value=[(item["image"], item["title"]) for item in sdxl_loras_raw_new], allow_preview=False, show_share_button=False)
521
  gallery = gr.Gallery(
522
  #value=[(item["image"], item["title"]) for item in sdxl_loras],
523
+ label="Pick a style from the gallery",
524
  allow_preview=False,
525
  columns=4,
526
  elem_id="gallery",
527
  show_share_button=False,
528
  height=550
529
  )
530
+ custom_model = gr.Textbox(label="or enter a custom Hugging Face or CivitAI SDXL LoRA", placeholder="Paste Hugging Face or CivitAI model path...")
531
+ custom_model_card = gr.HTML(visible=False)
532
+ custom_model_button = gr.Button("Remove custom LoRA", visible=False)
533
  with gr.Column(scale=5):
534
  with gr.Row():
535
  prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
 
559
  # outputs=[gallery, gr_sdxl_loras],
560
  # queue=False
561
  #)
562
+ custom_model.input(
563
+ fn=load_custom_lora,
564
+ inputs=[custom_model],
565
+ outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title],
566
+ queue=False
567
+ )
568
+ custom_model_button.click(
569
+ fn=remove_custom_lora,
570
+ outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora]
571
+ )
572
  gallery.select(
573
  fn=update_selection,
574
  inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
 
585
  #)
586
  prompt.submit(
587
  fn=check_selected,
588
+ inputs=[selected_state, custom_loaded_lora],
589
  queue=False,
590
  show_progress=False
591
  ).success(
592
  fn=run_lora,
593
+ inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
594
  outputs=[result, share_group],
595
  )
596
  button.click(
597
  fn=check_selected,
598
+ inputs=[selected_state, custom_loaded_lora],
599
  queue=False,
600
  show_progress=False
601
  ).success(
602
  fn=run_lora,
603
+ inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
604
  outputs=[result, share_group],
605
  )
606
  share_button.click(None, [], [], js=share_js)