multimodalart (HF staff) committed
Commit 8ca8d03
1 Parent(s): 15183c0

Update app.py

Files changed (1): app.py (+59 -54)
app.py CHANGED
@@ -123,6 +123,8 @@ pipe.load_ip_adapter_instantid(face_adapter)
 pipe.set_ip_adapter_scale(0.8)
 zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
 zoe.to(device)
+
+original_pipe = copy.deepcopy(pipe)
 pipe.to(device)
 
 last_lora = ""
@@ -202,10 +204,58 @@ def merge_incompatible_lora(full_path_lora, lora_scale):
     )
     del weights_sd
     del lora_model
-@spaces.GPU
+@spaces.GPU
+def generate_image(prompt, negative, face_emb, face_image, image_strength, images, guidance_scale, face_strength, depth_control_scale, last_lora, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index):
+    if last_lora != repo_name:
+        if(last_fused):
+            pipe.unfuse_lora()
+            pipe.unload_lora_weights()
+        pipe.load_lora_weights(loaded_state_dict)
+        pipe.fuse_lora(lora_scale)
+        last_fused = True
+        is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
+        if(is_pivotal):
+            #Add the textual inversion embeddings from pivotal tuning models
+            text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
+            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
+            state_dict_embedding = load_file(embedding_path)
+            print(state_dict_embedding)
+            try:
+                pipe.unload_textual_inversion()
+                pipe.load_textual_inversion(state_dict_embedding["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+                pipe.load_textual_inversion(state_dict_embedding["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+            except:
+                pipe.unload_textual_inversion()
+                pipe.load_textual_inversion(state_dict_embedding["text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+                pipe.load_textual_inversion(state_dict_embedding["text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+
+    print("Processing prompt...")
+    conditioning, pooled = compel(prompt)
+    if(negative):
+        negative_conditioning, negative_pooled = compel(negative)
+    else:
+        negative_conditioning, negative_pooled = None, None
+    print("Processing image...")
+    image = pipe(
+        prompt_embeds=conditioning,
+        pooled_prompt_embeds=pooled,
+        negative_prompt_embeds=negative_conditioning,
+        negative_pooled_prompt_embeds=negative_pooled,
+        width=1024,
+        height=1024,
+        image_embeds=face_emb,
+        image=face_image,
+        strength=1-image_strength,
+        control_image=images,
+        num_inference_steps=20,
+        guidance_scale = guidance_scale,
+        controlnet_conditioning_scale=[face_strength, depth_control_scale],
+    ).images[0]
+    return image
+
 def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
     global last_lora, last_merged, last_fused, pipe
-
+    selected_state_index = selected_state.index
     face_image = center_crop_image_as_square(face_image)
     try:
         face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
@@ -216,7 +266,7 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
         raise gr.Error("No face found in your image. Only face images work here. Try again")
 
     for lora_list in lora_defaults:
-        if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
+        if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
             prompt_full = lora_list.get("prompt", None)
             if(prompt_full):
                 prompt = prompt_full.replace("<subject>", prompt)
@@ -224,7 +274,7 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
 
     print("Prompt:", prompt)
    if(prompt == ""):
-        prompt = "A person"
+        prompt = "a person"
     #prepare face zoe
     with torch.no_grad():
         image_zoe = zoe(face_image)
@@ -239,15 +289,15 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
     # else:
     #     selected_state.index *= -1
     #sdxl_loras = sdxl_loras_new
-    print("Selected State: ", selected_state.index)
-    print(sdxl_loras[selected_state.index]["repo"])
+    print("Selected State: ", selected_state_index)
+    print(sdxl_loras[selected_state_index]["repo"])
     if negative == "":
         negative = None
 
     if not selected_state:
         raise gr.Error("You must select a LoRA")
-    repo_name = sdxl_loras[selected_state.index]["repo"]
-    weight_name = sdxl_loras[selected_state.index]["weights"]
+    repo_name = sdxl_loras[selected_state_index]["repo"]
+    weight_name = sdxl_loras[selected_state_index]["weights"]
 
     full_path_lora = state_dicts[repo_name]["saved_name"]
     loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
@@ -255,53 +305,8 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
     print("Last LoRA: ", last_lora)
     print("Current LoRA: ", repo_name)
     print("Last fused: ", last_fused)
-    if last_lora != repo_name:
-        if(last_fused):
-            pipe.unfuse_lora()
-            pipe.unload_lora_weights()
-        pipe.load_lora_weights(loaded_state_dict)
-        pipe.fuse_lora(lora_scale)
-        last_fused = True
-        is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
-        if(is_pivotal):
-            #Add the textual inversion embeddings from pivotal tuning models
-            text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
-            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
-            state_dict_embedding = load_file(embedding_path)
-            print(state_dict_embedding)
-            try:
-                pipe.unload_textual_inversion()
-                pipe.load_textual_inversion(state_dict_embedding["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
-                pipe.load_textual_inversion(state_dict_embedding["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
-            except:
-                pipe.unload_textual_inversion()
-                pipe.load_textual_inversion(state_dict_embedding["text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
-                pipe.load_textual_inversion(state_dict_embedding["text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
-
-    print("Processing prompt...")
-    conditioning, pooled = compel(prompt)
-    if(negative):
-        negative_conditioning, negative_pooled = compel(negative)
-    else:
-        negative_conditioning, negative_pooled = None, None
-    print("Processing image...")
-
-    image = pipe(
-        prompt_embeds=conditioning,
-        pooled_prompt_embeds=pooled,
-        negative_prompt_embeds=negative_conditioning,
-        negative_pooled_prompt_embeds=negative_pooled,
-        width=1024,
-        height=1024,
-        image_embeds=face_emb,
-        image=face_image,
-        strength=1-image_strength,
-        control_image=images,
-        num_inference_steps=20,
-        guidance_scale = guidance_scale,
-        controlnet_conditioning_scale=[face_strength, depth_control_scale],
-    ).images[0]
 
+    image = generate_image(prompt, negative, face_emb, face_image, image_strength, images, guidance_scale, face_strength, depth_control_scale, last_lora, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index)
     last_lora = repo_name
     return image, gr.update(visible=True)
 
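
The diff moves the GPU-bound generation out of run_lora and into a standalone generate_image function decorated with @spaces.GPU, the hook that ZeroGPU Spaces use to attach a GPU only for the duration of the decorated call. A minimal sketch of that pattern, assuming a generic SDXL pipeline; the heavy_inference name and the model ID are illustrative, not taken from app.py:

# Minimal sketch of the ZeroGPU pattern this commit adopts (illustrative only).
# Assumptions: `heavy_inference` and the SDXL base model ID are placeholders;
# only @spaces.GPU and "decorate the GPU-bound function" come from the diff.
import spaces
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.to("cuda")  # module-level .to(device), mirroring app.py

@spaces.GPU  # ZeroGPU grants a GPU only while this function runs
def heavy_inference(prompt):
    # all CUDA work happens inside the decorated call
    return pipe(prompt, num_inference_steps=20).images[0]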