primerz committed · Commit 88a08ab · verified · 1 Parent(s): b5dddcc

Update app.py

Files changed (1)
  1. app.py +42 -64
app.py CHANGED
@@ -45,12 +45,12 @@ with open("sdxl_loras.json", "r") as file:
     data = json.load(file)
     sdxl_loras_raw = [
         {
-            "image": item.get("image", ""),
-            "title": item.get("nickname", item.get("title", "")),
-            "repo": item.get("model", item.get("repo", "")),
-            "trigger_word": item.get("prompt", item.get("trigger_word", "")),
-            "weights": item.get("weights", ""),
-            "is_compatible": item.get("is_compatible", True),
+            "image": item["image"],
+            "title": item["title"],
+            "repo": item["repo"],
+            "trigger_word": item["trigger_word"],
+            "weights": item["weights"],
+            "is_compatible": item["is_compatible"],
             "is_pivotal": item.get("is_pivotal", False),
             "text_embedding_weights": item.get("text_embedding_weights", None),
             "likes": item.get("likes", 0),
@@ -70,9 +70,6 @@ device = "cuda"
 state_dicts = {}
 
 for item in sdxl_loras_raw:
-    if not item["weights"]:
-        continue
-
     saved_name = hf_hub_download(item["repo"], item["weights"])
 
     if not saved_name.endswith('.safetensors'):
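
Note: with the `if not item["weights"]: continue` guard gone, the download loop assumes every entry ships a non-empty `weights` filename; an entry without one would now fail at startup. If optional entries were ever reintroduced, filtering before the loop (a sketch, not in this commit) would keep the strict loop unchanged:

```python
# Sketch: drop entries with no downloadable weights before the strict loop runs.
sdxl_loras_raw = [item for item in sdxl_loras_raw if item.get("weights")]
```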
@@ -134,8 +131,7 @@ elapsed_time = et - st
 print('Loading VAE took: ', elapsed_time, 'seconds')
 st = time.time()
 
-#pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("stablediffusionapi/albedobase-xl-v21",
-pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("frankjoshua/albedobaseXL_v21",
+pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("rubbrband/albedobaseXL_v21",
     vae=vae,
     controlnet=[identitynet, zoedepthnet],
     torch_dtype=torch.float16)
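
Note: the base checkpoint swaps from `frankjoshua/albedobaseXL_v21` to `rubbrband/albedobaseXL_v21` (both mirrors of AlbedoBase XL v2.1); the VAE, the two-ControlNet list, and the fp16 dtype are untouched. For orientation, the `vae` passed in is typically built like this (a sketch; the fp16-fix repo id is an assumption, not shown in this diff):

```python
# Sketch of the assumed VAE setup feeding the from_pretrained call above.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",  # assumption: an fp16-safe SDXL VAE
    torch_dtype=torch.float16,
)
```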
@@ -174,7 +170,7 @@ def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, i
     lora_repo = sdxl_loras[selected_state.index]["repo"]
     new_placeholder = "Type a prompt to use your selected LoRA"
     weight_name = sdxl_loras[selected_state.index]["weights"]
-    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[selected_state.index]['is_nc'] else '' }"
+    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[selected_state.index]['is_nc'] else '' }"
 
     for lora_list in lora_defaults:
         if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
@@ -238,7 +234,7 @@ def merge_incompatible_lora(full_path_lora, lora_scale):
     del weights_sd
     del lora_model
 
-@spaces.GPU(duration=100)
+@spaces.GPU(duration=80)
 def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
     print(loaded_state_dict)
     et = time.time()
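
Note: `@spaces.GPU` is the ZeroGPU decorator on Hugging Face Spaces; a GPU is attached only for the duration of the decorated call, and `duration` caps the per-call budget in seconds. Cutting it from 100 to 80 matches the cheaper 20-step run later in this diff. Minimal usage sketch:

```python
# Minimal ZeroGPU sketch (assumes the `spaces` package available on HF Spaces).
import spaces

@spaces.GPU(duration=80)  # GPU is held for at most ~80 s per invocation
def infer(prompt):
    ...  # GPU-bound work goes here
```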
@@ -257,37 +253,8 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
     et = time.time()
     elapsed_time = et - st
     print('Zoe Depth calculations took: ', elapsed_time, 'seconds')
-
-    # Only handle lora if we have weights
-    if loaded_state_dict is not None:
-        if last_lora != repo_name:
-            if(last_fused):
-                st = time.time()
-                pipe.unfuse_lora()
-                pipe.unload_lora_weights()
-                pipe.unload_textual_inversion()
-                et = time.time()
-                elapsed_time = et - st
-                print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
-            st = time.time()
-            pipe.load_lora_weights(loaded_state_dict)
-            pipe.fuse_lora(lora_scale)
-            et = time.time()
-            elapsed_time = et - st
-            print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
-            last_fused = True
-            is_pivotal = sdxl_loras[selected_state_index].get("is_pivotal", False)
-            if(is_pivotal):
-                #Add the textual inversion embeddings from pivotal tuning models
-                text_embedding_name = sdxl_loras[selected_state_index].get("text_embedding_weights")
-                if text_embedding_name:
-                    embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
-                    state_dict_embedding = load_file(embedding_path)
-                    pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
-                    pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
-    else:
-        # No lora to load, unfuse any existing lora
-        if last_fused:
+    if last_lora != repo_name:
+        if(last_fused):
             st = time.time()
             pipe.unfuse_lora()
             pipe.unload_lora_weights()
@@ -295,7 +262,21 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
             et = time.time()
             elapsed_time = et - st
             print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
-            last_fused = False
+        st = time.time()
+        pipe.load_lora_weights(loaded_state_dict)
+        pipe.fuse_lora(lora_scale)
+        et = time.time()
+        elapsed_time = et - st
+        print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
+        last_fused = True
+        is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
+        if(is_pivotal):
+            #Add the textual inversion embeddings from pivotal tuning models
+            text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
+            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
+            state_dict_embedding = load_file(embedding_path)
+            pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+            pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
 
     print("Processing prompt...")
     st = time.time()
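
Note: the restored branch also drops the `.get(...)` fallbacks and the `if text_embedding_name:` guard, so pivotal-tuned entries must declare both `is_pivotal` and `text_embedding_weights`. The embeddings file carries two tensors that get registered as the `<s0>`/`<s1>` tokens on both SDXL text encoders; the inline conditionals cover the two key conventions seen in such files (`clip_l`/`clip_g` vs. `text_encoders_0`/`text_encoders_1`). A sketch for checking which convention a given file uses (path hypothetical):

```python
# Sketch: inspect a pivotal-tuning embeddings file's key convention.
from safetensors.torch import load_file

sd = load_file("embeddings.safetensors")  # hypothetical local path
key_l = "clip_l" if "clip_l" in sd else "text_encoders_0"
key_g = "clip_g" if "clip_g" in sd else "text_encoders_1"
print(key_l, tuple(sd[key_l].shape), key_g, tuple(sd[key_g].shape))
```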
@@ -320,7 +301,7 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
         image=face_image,
         strength=1-image_strength,
         control_image=images,
-        num_inference_steps=36,
+        num_inference_steps=20,
         guidance_scale = guidance_scale,
         controlnet_conditioning_scale=[face_strength, depth_control_scale],
     ).images[0]
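
Note: steps drop from 36 to 20, in line with the tighter 80 s GPU budget. In diffusers-style img2img the scheduler skips early timesteps according to `strength`, so the denoising steps actually executed are roughly `int(num_inference_steps * strength)` (a sketch of the stock timestep logic; the InstantID subclass is assumed to inherit it, and the slider value is hypothetical):

```python
# Effective denoising steps under diffusers' stock img2img timestep logic.
num_inference_steps = 20
image_strength = 0.15                   # hypothetical UI slider value
strength = 1 - image_strength           # as passed in the call above
effective_steps = min(int(num_inference_steps * strength), num_inference_steps)
print(effective_steps)                  # -> 17
```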
@@ -374,12 +355,8 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
         full_path_lora = custom_lora_path
     else:
         repo_name = sdxl_loras[selected_state_index]["repo"]
-        weight_name = sdxl_loras[selected_state_index].get("weights", "")
-        if weight_name and repo_name in state_dicts:
-            full_path_lora = state_dicts[repo_name]["saved_name"]
-        else:
-            # No weights available, use base model without lora
-            full_path_lora = None
+        weight_name = sdxl_loras[selected_state_index]["weights"]
+        full_path_lora = state_dicts[repo_name]["saved_name"]
     print("Full path LoRA ", full_path_lora)
     #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
     cross_attention_kwargs = None
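
Note: the reverted lookup assumes the startup loop registered every gallery repo in `state_dicts`; the base-model-only fallback (`full_path_lora = None`) is gone. Sketch of the mapping shape this line relies on (values illustrative; the `state_dict` field is suggested by the commented-out deepcopy above):

```python
# Assumed shape of the startup-populated registry (illustrative values).
state_dicts = {
    "some-user/some-lora": {                       # hypothetical repo id
        "saved_name": "/cache/lora.safetensors",   # hf_hub_download result
        "state_dict": {},                          # tensors, per the deepcopy comment
    },
}
```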
@@ -395,11 +372,11 @@ run_lora.zerogpu = True
 
 def shuffle_gallery(sdxl_loras):
     random.shuffle(sdxl_loras)
-    return [(item.get("image") or None, item["title"]) for item in sdxl_loras], sdxl_loras
+    return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
 
 def classify_gallery(sdxl_loras):
     sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
-    return [(item.get("image") or None, item["title"]) for item in sorted_gallery], sorted_gallery
+    return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
 
 def swap_gallery(order, sdxl_loras):
     if(order == "random"):
@@ -447,10 +424,10 @@ def get_civitai_safetensors(link):
     if(x.status_code != 200):
         raise Exception("Invalid CivitAI URL")
     model_data = x.json()
-    #if(model_data["nsfw"] == True or model_data["nsfwLevel"] > 20):
-    #    gr.Warning("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
-    #    raise Exception("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
-    if(model_data["type"] != "LORA"):
+    if(model_data["nsfw"] == True or model_data["nsfwLevel"] > 20):
+        gr.Warning("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
+        raise Exception("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
+    elif(model_data["type"] != "LORA"):
         gr.Warning("The model isn't tagged at CivitAI as a LoRA")
         raise Exception("The model isn't tagged at CivitAI as a LoRA")
     model_link_download = None
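
Note: this hunk re-enables the previously commented-out adult-content gate and chains the LoRA-type check with `elif`, so a flagged model is rejected before the type check runs. The `nsfw`, `nsfwLevel`, and `type` fields come from CivitAI model metadata; a lookup sketch (the endpoint and model id are assumptions based on CivitAI's public API, not shown in this diff):

```python
# Sketch of the metadata lookup the checks above operate on.
import requests

model_id = "123456"  # hypothetical
x = requests.get(f"https://civitai.com/api/v1/models/{model_id}")  # assumed endpoint
model_data = x.json()
print(model_data["type"], model_data["nsfw"], model_data.get("nsfwLevel"))
```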
@@ -519,12 +496,12 @@ with gr.Blocks(css="custom.css") as demo:
     gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
     title = gr.HTML(
         """<h1><img src="https://i.imgur.com/DVoGw04.png">
-        <span>Face to All<br><small style="
+        <span>Face to All SDXL<br><small style="
         font-size: 13px;
         display: block;
         font-weight: normal;
         opacity: 0.75;
-        ">🧨 diffusers InstantID + ControlNet<br> inspired by fofr's <a href="https://github.com/fofr/cog-face-to-many" target="_blank">face-to-many</a></small></span></h1>""",
+        ">🧨 diffusers InstantID + ControlNet<br> inspired by fofr's <a href="https://github.com/fofr/cog-face-to-many" target="_blank">face-to-many</a></small></span></h1>""",
         elem_id="title",
     )
     selected_state = gr.State()
@@ -594,7 +571,7 @@ with gr.Blocks(css="custom.css") as demo:
         fn=update_selection,
         inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
         outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
-        show_progress=True
+        show_progress=False
     )
     #new_gallery.select(
     #    fn=update_selection,
@@ -606,7 +583,7 @@ with gr.Blocks(css="custom.css") as demo:
     prompt.submit(
         fn=check_selected,
         inputs=[selected_state, custom_loaded_lora],
-        show_progress=True
+        show_progress=False
     ).success(
         fn=run_lora,
         inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
@@ -615,7 +592,7 @@ with gr.Blocks(css="custom.css") as demo:
     button.click(
         fn=check_selected,
         inputs=[selected_state, custom_loaded_lora],
-        show_progress=True
+        show_progress=False
     ).success(
         fn=run_lora,
         inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
@@ -625,4 +602,5 @@ with gr.Blocks(css="custom.css") as demo:
     demo.load(fn=classify_gallery, inputs=[gr_sdxl_loras], outputs=[gallery, gr_sdxl_loras], js=js)
 
     demo.queue(default_concurrency_limit=None, api_open=True)
-    demo.launch(share=True)
+    demo.launch(share=True)
+
 
 