primerz committed on
Commit
b5dddcc
·
verified ·
1 Parent(s): 43ce680

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -24
app.py CHANGED
@@ -257,8 +257,37 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
257
  et = time.time()
258
  elapsed_time = et - st
259
  print('Zoe Depth calculations took: ', elapsed_time, 'seconds')
260
- if last_lora != repo_name:
261
- if(last_fused):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  st = time.time()
263
  pipe.unfuse_lora()
264
  pipe.unload_lora_weights()
@@ -266,21 +295,7 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
266
  et = time.time()
267
  elapsed_time = et - st
268
  print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
269
- st = time.time()
270
- pipe.load_lora_weights(loaded_state_dict)
271
- pipe.fuse_lora(lora_scale)
272
- et = time.time()
273
- elapsed_time = et - st
274
- print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
275
- last_fused = True
276
- is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
277
- if(is_pivotal):
278
- #Add the textual inversion embeddings from pivotal tuning models
279
- text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
280
- embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
281
- state_dict_embedding = load_file(embedding_path)
282
- pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
283
- pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
284
 
285
  print("Processing prompt...")
286
  st = time.time()
@@ -359,8 +374,12 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
359
  full_path_lora = custom_lora_path
360
  else:
361
  repo_name = sdxl_loras[selected_state_index]["repo"]
362
- weight_name = sdxl_loras[selected_state_index]["weights"]
363
- full_path_lora = state_dicts[repo_name]["saved_name"]
 
 
 
 
364
  print("Full path LoRA ", full_path_lora)
365
  #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
366
  cross_attention_kwargs = None
@@ -376,11 +395,11 @@ run_lora.zerogpu = True
376
 
377
  def shuffle_gallery(sdxl_loras):
378
  random.shuffle(sdxl_loras)
379
- return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
380
 
381
  def classify_gallery(sdxl_loras):
382
  sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
383
- return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
384
 
385
  def swap_gallery(order, sdxl_loras):
386
  if(order == "random"):
@@ -575,7 +594,7 @@ with gr.Blocks(css="custom.css") as demo:
575
  fn=update_selection,
576
  inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
577
  outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
578
- show_progress=False
579
  )
580
  #new_gallery.select(
581
  # fn=update_selection,
@@ -587,7 +606,7 @@ with gr.Blocks(css="custom.css") as demo:
587
  prompt.submit(
588
  fn=check_selected,
589
  inputs=[selected_state, custom_loaded_lora],
590
- show_progress=False
591
  ).success(
592
  fn=run_lora,
593
  inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
@@ -596,7 +615,7 @@ with gr.Blocks(css="custom.css") as demo:
596
  button.click(
597
  fn=check_selected,
598
  inputs=[selected_state, custom_loaded_lora],
599
- show_progress=False
600
  ).success(
601
  fn=run_lora,
602
  inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
 
257
  et = time.time()
258
  elapsed_time = et - st
259
  print('Zoe Depth calculations took: ', elapsed_time, 'seconds')
260
+
261
+ # Only handle lora if we have weights
262
+ if loaded_state_dict is not None:
263
+ if last_lora != repo_name:
264
+ if(last_fused):
265
+ st = time.time()
266
+ pipe.unfuse_lora()
267
+ pipe.unload_lora_weights()
268
+ pipe.unload_textual_inversion()
269
+ et = time.time()
270
+ elapsed_time = et - st
271
+ print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
272
+ st = time.time()
273
+ pipe.load_lora_weights(loaded_state_dict)
274
+ pipe.fuse_lora(lora_scale)
275
+ et = time.time()
276
+ elapsed_time = et - st
277
+ print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
278
+ last_fused = True
279
+ is_pivotal = sdxl_loras[selected_state_index].get("is_pivotal", False)
280
+ if(is_pivotal):
281
+ #Add the textual inversion embeddings from pivotal tuning models
282
+ text_embedding_name = sdxl_loras[selected_state_index].get("text_embedding_weights")
283
+ if text_embedding_name:
284
+ embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
285
+ state_dict_embedding = load_file(embedding_path)
286
+ pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
287
+ pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
288
+ else:
289
+ # No lora to load, unfuse any existing lora
290
+ if last_fused:
291
  st = time.time()
292
  pipe.unfuse_lora()
293
  pipe.unload_lora_weights()
 
295
  et = time.time()
296
  elapsed_time = et - st
297
  print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
298
+ last_fused = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  print("Processing prompt...")
301
  st = time.time()
 
374
  full_path_lora = custom_lora_path
375
  else:
376
  repo_name = sdxl_loras[selected_state_index]["repo"]
377
+ weight_name = sdxl_loras[selected_state_index].get("weights", "")
378
+ if weight_name and repo_name in state_dicts:
379
+ full_path_lora = state_dicts[repo_name]["saved_name"]
380
+ else:
381
+ # No weights available, use base model without lora
382
+ full_path_lora = None
383
  print("Full path LoRA ", full_path_lora)
384
  #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
385
  cross_attention_kwargs = None
 
395
 
396
  def shuffle_gallery(sdxl_loras):
397
  random.shuffle(sdxl_loras)
398
+ return [(item.get("image") or None, item["title"]) for item in sdxl_loras], sdxl_loras
399
 
400
  def classify_gallery(sdxl_loras):
401
  sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
402
+ return [(item.get("image") or None, item["title"]) for item in sorted_gallery], sorted_gallery
403
 
404
  def swap_gallery(order, sdxl_loras):
405
  if(order == "random"):
 
594
  fn=update_selection,
595
  inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
596
  outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
597
+ show_progress=True
598
  )
599
  #new_gallery.select(
600
  # fn=update_selection,
 
606
  prompt.submit(
607
  fn=check_selected,
608
  inputs=[selected_state, custom_loaded_lora],
609
+ show_progress=True
610
  ).success(
611
  fn=run_lora,
612
  inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
 
615
  button.click(
616
  fn=check_selected,
617
  inputs=[selected_state, custom_loaded_lora],
618
+ show_progress=True
619
  ).success(
620
  fn=run_lora,
621
  inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],