primerz committed on
Commit 2f93f32 · verified · 1 Parent(s): f16b104

Update app.py

Files changed (1)
  1. app.py +281 -398
app.py CHANGED
Old version of app.py (removed lines are prefixed with -):

@@ -38,7 +38,9 @@ from compel import Compel, ReturnedEmbeddingsType
 
 from gradio_imageslider import ImageSlider
 
-# Load LoRA configurations
 with open("sdxl_loras.json", "r") as file:
     data = json.load(file)
     sdxl_loras_raw = [
@@ -61,11 +63,12 @@ with open("sdxl_loras.json", "r") as file:
 
 with open("defaults_data.json", "r") as file:
     lora_defaults = json.load(file)
 
-device = "cuda"
 
-# Cache for LoRA state dicts
 state_dicts = {}
 for item in sdxl_loras_raw:
     saved_name = hf_hub_download(item["repo"], item["weights"])
 
@@ -80,8 +83,8 @@ for item in sdxl_loras_raw:
     }
 
 sdxl_loras_raw = [item for item in sdxl_loras_raw if item.get("new") != True]
-
-# Download models
 hf_hub_download(
     repo_id="InstantX/InstantID",
     filename="ControlNetModel/config.json",
@@ -100,158 +103,70 @@ hf_hub_download(
     filename="pytorch_lora_weights.safetensors",
     local_dir="/data/checkpoints",
 )
 
-# Download antelopev2
 antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
 print(antelope_download)
 app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
-app.prepare(ctx_id=0, det_size=(768, 768))
 
-# Prepare models
 face_adapter = f'/data/checkpoints/ip-adapter.bin'
 controlnet_path = f'/data/checkpoints/ControlNetModel'
 
 st = time.time()
 identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
-zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
 et = time.time()
-print('Loading ControlNet took: ', et - st, 'seconds')
-
 st = time.time()
 vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
 et = time.time()
-print('Loading VAE took: ', et - st, 'seconds')
-
 st = time.time()
-pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
-    "SG161222/RealVisXL_V5.0",
-    vae=vae,
-    controlnet=[identitynet, zoedepthnet],
-    torch_dtype=torch.float16
-)
 
 pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
 pipe.load_ip_adapter_instantid(face_adapter)
-pipe.set_ip_adapter_scale(0.9)
 et = time.time()
-print('Loading pipeline took: ', et - st, 'seconds')
-
 st = time.time()
-compel = Compel(
-    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
-    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
-    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
-    requires_pooled=[False, True]
-)
 et = time.time()
-print('Loading Compel took: ', et - st, 'seconds')
 
 st = time.time()
 zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
 et = time.time()
-print('Loading Zoe took: ', et - st, 'seconds')
 zoe.to(device)
 pipe.to(device)
 
 last_lora = ""
 last_fused = False
 lora_archive = "/data"
 
-# Improved face detection with multi-face support
-def detect_faces(face_image, use_multiple_faces=False):
-    """
-    Detect faces in the image
-    Returns: list of face info dictionaries, or empty list if no faces
-    """
-    try:
-        face_info_list = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
-
-        if not face_info_list or len(face_info_list) == 0:
-            print("No faces detected")
-            return []
-
-        # Sort faces by size (largest first)
-        face_info_list = sorted(
-            face_info_list,
-            key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]),
-            reverse=True
-        )
-
-        if use_multiple_faces:
-            print(f"Detected {len(face_info_list)} faces")
-            return face_info_list
-        else:
-            print(f"Using largest face (detected {len(face_info_list)} total)")
-            return [face_info_list[0]]
-
-    except Exception as e:
-        print(f"Face detection error: {e}")
-        return []
-
-def process_face_embeddings(face_info_list):
-    """
-    Process face embeddings - average multiple faces or return single face
-    """
-    if not face_info_list:
-        return None
-
-    if len(face_info_list) == 1:
-        return face_info_list[0]['embedding']
-
-    # Average embeddings for multiple faces
-    embeddings = [face_info['embedding'] for face_info in face_info_list]
-    avg_embedding = np.mean(embeddings, axis=0)
-    return avg_embedding
-
-def create_face_kps_image(face_image, face_info_list):
-    """
-    Create keypoints image from face info
-    """
-    if not face_info_list:
-        return face_image
-
-    # For multiple faces, draw all keypoints
-    if len(face_info_list) > 1:
-        return draw_multiple_kps(face_image, [f['kps'] for f in face_info_list])
-    else:
-        return draw_kps(face_image, face_info_list[0]['kps'])
-
-def draw_multiple_kps(image_pil, kps_list, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
-    """
-    Draw keypoints for multiple faces
-    """
-    stickwidth = 4
-    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
-
-    w, h = image_pil.size
-    out_img = np.zeros([h, w, 3])
-
-    for kps in kps_list:
-        kps = np.array(kps)
-
-        for i in range(len(limbSeq)):
-            index = limbSeq[i]
-            color = color_list[index[0]]
-
-            x = kps[index][:, 0]
-            y = kps[index][:, 1]
-            length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
-            angle = np.degrees(np.arctan2(y[0] - y[1], x[0] - x[1]))
-            polygon = cv2.ellipse2Poly(
-                (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
-            )
-            out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
-
-        out_img = (out_img * 0.6).astype(np.uint8)
-
-        for idx_kp, kp in enumerate(kps):
-            color = color_list[idx_kp]
-            x, y = kp
-            out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
-
-    out_img_pil = Image.fromarray(out_img.astype(np.uint8))
-    return out_img_pil
-
 def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
     lora_repo = sdxl_loras[selected_state.index]["repo"]
     new_placeholder = "Type a prompt to use your selected LoRA"
@@ -260,9 +175,9 @@ def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, i
 
     for lora_list in lora_defaults:
         if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
-            face_strength = lora_list.get("face_strength", 0.9)
-            image_strength = lora_list.get("image_strength", 0.2)
-            weight = lora_list.get("weight", 0.95)
             depth_control_scale = lora_list.get("depth_control_scale", 0.8)
             negative = lora_list.get("negative", "")
 
@@ -283,82 +198,157 @@ def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, i
         selected_state
     )
 
 def check_selected(selected_state, custom_lora):
     if not selected_state and not custom_lora:
         raise gr.Error("You must select a style")
 
-def resize_image_aspect_ratio(img, max_dim=1280):
-    width, height = img.size
-    aspect_ratio = width / height
-
-    if aspect_ratio >= 1:  # Landscape or square
-        new_width = min(max_dim, width)
-        new_height = int(new_width / aspect_ratio)
-    else:  # Portrait
-        new_height = min(max_dim, height)
-        new_width = int(new_height * aspect_ratio)
-
-    new_width = (new_width // 8) * 8
-    new_height = (new_height // 8) * 8
-
-    return img.resize((new_width, new_height), Image.LANCZOS)
-
-
-def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength,
-             guidance_scale, depth_control_scale, sdxl_loras, custom_lora, use_multiple_faces=False,
-             progress=gr.Progress(track_tqdm=True)):
-    """
-    Enhanced run_lora with support for:
-    - No faces (landscape mode)
-    - Multiple faces
-    - Improved results
-    """
-    print("Custom LoRA:", custom_lora)
-    custom_lora_path = custom_lora[0] if custom_lora else None
-    selected_state_index = selected_state.index if selected_state else -1
 
     st = time.time()
-    face_image = resize_image_aspect_ratio(face_image)
-
-    # Enhanced face detection
-    face_info_list = detect_faces(face_image, use_multiple_faces)
-    face_detected = len(face_info_list) > 0
-
-    if face_detected:
-        face_emb = process_face_embeddings(face_info_list)
-        face_kps = create_face_kps_image(face_image, face_info_list)
-        print(f"Processing with {len(face_info_list)} face(s)")
     else:
-        face_emb = None
-        face_kps = face_image
-        print("No faces detected - using landscape/depth mode only")
-
     et = time.time()
-    print('Face processing took:', et - st, 'seconds')
 
     st = time.time()
 
-    # Enhanced prompt processing
-    if custom_lora_path and custom_lora[1]:
         prompt = f"{prompt} {custom_lora[1]}"
     else:
         for lora_list in lora_defaults:
             if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
                 prompt_full = lora_list.get("prompt", None)
-                if prompt_full:
                     prompt = prompt_full.replace("<subject>", prompt)
 
-    print("Prompt:", prompt)
-    if prompt == "":
-        prompt = "a beautiful scene" if not face_detected else "a person"
     print(f"Executing prompt: {prompt}")
-
     if negative == "":
-        # Enhanced negative prompt for better quality
-        negative = "worst quality, low quality, blurry, distorted, deformed" if not face_detected else None
-
-    print("Custom Loaded LoRA:", custom_lora_path)
-
     if not selected_state and not custom_lora_path:
         raise gr.Error("You must select a style")
     elif custom_lora_path:
@@ -366,130 +356,21 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
         full_path_lora = custom_lora_path
     else:
         repo_name = sdxl_loras[selected_state_index]["repo"]
         full_path_lora = state_dicts[repo_name]["saved_name"]
-
-        repo_name = repo_name.rstrip("/").lower()
-
-    print("Full path LoRA", full_path_lora)
-
     et = time.time()
-    print('Prompt processing took:', et - st, 'seconds')
-
-    # Adjust parameters based on face detection
-    if not face_detected:
-        # For landscape/no face mode, reduce face strength and increase depth control
-        face_strength = 0.0
-        depth_control_scale = max(depth_control_scale, 0.9)
-        image_strength = min(image_strength, 0.4)
-        print("Adjusted parameters for no-face mode")
 
     st = time.time()
-    image = generate_image(
-        prompt, negative, face_emb, face_image, face_kps, image_strength,
-        guidance_scale, face_strength, depth_control_scale, repo_name,
-        full_path_lora, lora_scale, sdxl_loras, selected_state_index, face_detected, st
-    )
-
     return (face_image, image), gr.update(visible=True)
 
 run_lora.zerogpu = True
 
-
-@spaces.GPU(duration=75)
-def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale,
-                   face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale,
-                   sdxl_loras, selected_state_index, face_detected, st):
-    global last_fused, last_lora
-
-    print("Loaded state dict:", loaded_state_dict)
-    print("Last LoRA:", last_lora, "| Current LoRA:", repo_name)
-
-    # Prepare control images and scales based on face detection
-    if face_detected:
-        control_images = [face_kps, zoe(face_image)]
-        control_scales = [face_strength, depth_control_scale]
-    else:
-        # Only use depth control for landscapes
-        control_images = [zoe(face_image)]
-        control_scales = [depth_control_scale]
-
-    # Handle custom LoRA from HuggingFace
-    if repo_name.startswith("https://huggingface.co"):
-        repo_id = repo_name.split("huggingface.co/")[-1]
-        fs = HfFileSystem()
-        files = fs.ls(repo_id, detail=False)
-        safetensors_files = [f for f in files if f.endswith(".safetensors")]
-
-        if not safetensors_files:
-            raise gr.Error("No .safetensors file found in this Hugging Face repository.")
-
-        weight_file = safetensors_files[0]
-        full_path_lora = hf_hub_download(repo_id=repo_id, filename=weight_file, repo_type="model")
-    else:
-        full_path_lora = loaded_state_dict
-
-    # Improved LoRA loading and caching
-    if last_lora != repo_name:
-        if last_fused:
-            pipe.unfuse_lora()
-            pipe.unload_lora_weights()
-            pipe.unload_textual_inversion()
-
-        # Load LoRA with better error handling
-        try:
-            pipe.load_lora_weights(full_path_lora)
-            pipe.fuse_lora(lora_scale)
-            last_fused = True
-
-            # Handle pivotal tuning embeddings
-            is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
-            if is_pivotal:
-                text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
-                embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
-                state_dict_embedding = load_file(embedding_path)
-                pipe.load_textual_inversion(
-                    state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"],
-                    token=["<s0>", "<s1>"],
-                    text_encoder=pipe.text_encoder,
-                    tokenizer=pipe.tokenizer
-                )
-                pipe.load_textual_inversion(
-                    state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"],
-                    token=["<s0>", "<s1>"],
-                    text_encoder=pipe.text_encoder_2,
-                    tokenizer=pipe.tokenizer_2
-                )
-        except Exception as e:
-            print(f"Error loading LoRA: {e}")
-            raise gr.Error(f"Failed to load LoRA: {str(e)}")
-
-    print("Processing prompt...")
-    conditioning, pooled = compel(prompt)
-    negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)
-
-    # Enhanced generation parameters
-    num_inference_steps = 40  # Increased for better quality
-
-    print("Generating image...")
-    image = pipe(
-        prompt_embeds=conditioning,
-        pooled_prompt_embeds=pooled,
-        negative_prompt_embeds=negative_conditioning,
-        negative_pooled_prompt_embeds=negative_pooled,
-        width=face_image.width,
-        height=face_image.height,
-        image_embeds=face_emb if face_detected else None,
-        image=face_image,
-        strength=1-image_strength,
-        control_image=control_images,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
-        controlnet_conditioning_scale=control_scales,
-    ).images[0]
-
-    last_lora = repo_name
-    return image
-
 def shuffle_gallery(sdxl_loras):
     random.shuffle(sdxl_loras)
     return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
@@ -505,75 +386,75 @@ def swap_gallery(order, sdxl_loras):
     return classify_gallery(sdxl_loras)
 
 def deselect():
-    return gr.Gallery(selected_index=None)
 
 def get_huggingface_safetensors(link):
-    split_link = link.split("/")
-    if(len(split_link) == 2):
-        model_card = ModelCard.load(link)
-        image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
-        trigger_word = model_card.data.get("instance_prompt", "")
-        image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
-        fs = HfFileSystem()
-        try:
-            list_of_files = fs.ls(link, detail=False)
-            for file in list_of_files:
-                if(file.endswith(".safetensors")):
-                    safetensors_name = file.replace("/", "_")
-                    if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
-                        fs.get_file(file, lpath=f"{lora_archive}/{safetensors_name}")
-                if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
-                    image_elements = file.split("/")
-                    image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
-        except:
-            gr.Warning("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
-            raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
-        return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
 
 def get_civitai_safetensors(link):
-    link_split = link.split("civitai.com/")
-    pattern = re.compile(r'models\/(\d+)')
-    regex_match = pattern.search(link_split[1])
-    if(regex_match):
-        civitai_model_id = regex_match.group(1)
-    else:
-        gr.Warning("No CivitAI model id found in your URL")
-        raise Exception("No CivitAI model id found in your URL")
-    model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}?token={os.getenv('CIVITAI_TOKEN')}"
-    x = requests.get(model_request_url)
-    if(x.status_code != 200):
-        raise Exception("Invalid CivitAI URL")
-    model_data = x.json()
-
-    if(model_data["type"] != "LORA"):
-        gr.Warning("The model isn't tagged at CivitAI as a LoRA")
-        raise Exception("The model isn't tagged at CivitAI as a LoRA")
-
-    model_link_download = None
-    image_url = None
-    trigger_word = ""
-    for model in model_data["modelVersions"]:
-        if(model["baseModel"] == "SDXL 1.0"):
-            model_link_download = f"{model['downloadUrl']}/?token={os.getenv('CIVITAI_TOKEN')}"
-            safetensors_name = model["files"][0]["name"]
-            if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
-                safetensors_file_request = requests.get(model_link_download)
-                if(safetensors_file_request.status_code != 200):
-                    raise Exception("Invalid CivitAI download link")
-                with open(f"{lora_archive}/{safetensors_name}", 'wb') as file:
-                    file.write(safetensors_file_request.content)
-            trigger_word = model.get("trainedWords", [""])[0]
-            for image in model["images"]:
-                if(image["nsfwLevel"] == 1):
-                    image_url = image["url"]
-                    break
-            break
-
-    if(not model_link_download):
-        gr.Warning("We couldn't find a SDXL LoRA on the model you've sent")
-        raise Exception("We couldn't find a SDXL LoRA on the model you've sent")
-    return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
-
 def check_custom_model(link):
     if(link.startswith("https://")):
         if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
@@ -584,6 +465,9 @@ def check_custom_model(link):
     else:
         return get_huggingface_safetensors(link)
 
 def load_custom_lora(link):
     if(link):
         try:
@@ -609,29 +493,33 @@ def load_custom_lora(link):
 
 def remove_custom_lora():
     return "", gr.update(visible=False), gr.update(visible=False), None
-
-# Build Gradio interface
 with gr.Blocks(css="custom.css") as demo:
     gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
     title = gr.HTML(
         """<h1><img src="https://i.imgur.com/DVoGw04.png">
-        <span>Face to All - Enhanced<br><small style="
         font-size: 13px;
         display: block;
         font-weight: normal;
         opacity: 0.75;
-        ">🔥 Supports: No faces (landscape), Multiple faces, Improved quality, Custom LoRAs<br> diffusers InstantID + ControlNet</small></span></h1>""",
         elem_id="title",
     )
    selected_state = gr.State()
    custom_loaded_lora = gr.State()
-
    with gr.Row(elem_id="main_app"):
        with gr.Column(scale=4, elem_id="box_column"):
            with gr.Group(elem_id="gallery_box"):
-                photo = gr.Image(label="Upload a picture (with or without faces)", interactive=True, type="pil", height=300)
-                selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected")
                gallery = gr.Gallery(
                    label="Pick a style from the gallery",
                    allow_preview=False,
                    columns=4,
@@ -642,82 +530,77 @@ with gr.Blocks(css="custom.css") as demo:
                custom_model = gr.Textbox(label="or enter a custom Hugging Face or CivitAI SDXL LoRA", placeholder="Paste Hugging Face or CivitAI model path...")
                custom_model_card = gr.HTML(visible=False)
                custom_model_button = gr.Button("Remove custom LoRA", visible=False)
-
        with gr.Column(scale=5):
            with gr.Row():
-                prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1,
-                                    info="Describe your subject or scene", value="a person", elem_id="prompt")
                button = gr.Button("Run", elem_id="run_button")
-
            result = ImageSlider(
                interactive=False, label="Generated Image", elem_id="result-image", position=0.1
            )
-
            with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")
-
            with gr.Accordion("Advanced options", open=False):
-                use_multiple_faces = gr.Checkbox(label="Use multiple faces (if detected)", value=False)
                negative = gr.Textbox(label="Negative Prompt")
                weight = gr.Slider(0, 10, value=0.9, step=0.1, label="LoRA weight")
-                face_strength = gr.Slider(0, 2, value=0.9, step=0.01, label="Face strength",
-                                          info="Higher values increase face likeness (auto-adjusted for no-face images)")
-                image_strength = gr.Slider(0, 1, value=0.20, step=0.01, label="Image strength",
-                                           info="Higher values increase similarity with original structure/colors")
-                guidance_scale = gr.Slider(0, 50, value=8, step=0.1, label="Guidance Scale")
-                depth_control_scale = gr.Slider(0, 1, value=0.8, step=0.01, label="Zoe Depth ControlNet strength")
-
            prompt_title = gr.Markdown(
                value="### Click on a LoRA in the gallery to select it",
                visible=True,
                elem_id="selected_lora",
            )
-
-    # Event handlers
    custom_model.input(
        fn=load_custom_lora,
        inputs=[custom_model],
        outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title],
    )
-
    custom_model_button.click(
        fn=remove_custom_lora,
        outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora]
    )
-
    gallery.select(
        fn=update_selection,
        inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
        outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
        show_progress=False
    )
-
    prompt.submit(
        fn=check_selected,
        inputs=[selected_state, custom_loaded_lora],
        show_progress=False
    ).success(
        fn=run_lora,
-        inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength,
-                guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora, use_multiple_faces],
        outputs=[result, share_group],
    )
-
    button.click(
        fn=check_selected,
        inputs=[selected_state, custom_loaded_lora],
        show_progress=False
    ).success(
        fn=run_lora,
-        inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength,
-                guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora, use_multiple_faces],
        outputs=[result, share_group],
    )
-
    share_button.click(None, [], [], js=share_js)
-    demo.load(fn=classify_gallery, inputs=[gr_sdxl_loras], outputs=[gallery, gr_sdxl_loras])
 
 demo.queue(default_concurrency_limit=None, api_open=True)
 demo.launch(share=True)
 
New version of app.py (added lines are prefixed with +):

 
 from gradio_imageslider import ImageSlider
 
+
+#from gradio_imageslider import ImageSlider
+
 with open("sdxl_loras.json", "r") as file:
     data = json.load(file)
     sdxl_loras_raw = [

 
 with open("defaults_data.json", "r") as file:
     lora_defaults = json.load(file)
+
 
+device = "cuda"
 
 state_dicts = {}
+
 for item in sdxl_loras_raw:
     saved_name = hf_hub_download(item["repo"], item["weights"])
 

     }
 
 sdxl_loras_raw = [item for item in sdxl_loras_raw if item.get("new") != True]
+
+# download models
 hf_hub_download(
     repo_id="InstantX/InstantID",
     filename="ControlNetModel/config.json",

     filename="pytorch_lora_weights.safetensors",
     local_dir="/data/checkpoints",
 )
+# download antelopev2
+#if not os.path.exists("/data/antelopev2.zip"):
+#    gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output="/data/", quiet=False, fuzzy=True)
+#    os.system("unzip /data/antelopev2.zip -d /data/models/")
 
 antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
 print(antelope_download)
 app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
+app.prepare(ctx_id=0, det_size=(640, 640))
 
+# prepare models under ./checkpoints
 face_adapter = f'/data/checkpoints/ip-adapter.bin'
 controlnet_path = f'/data/checkpoints/ControlNetModel'
 
+# load IdentityNet
 st = time.time()
 identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0",torch_dtype=torch.float16)
 et = time.time()
+elapsed_time = et - st
+print('Loading ControlNet took: ', elapsed_time, 'seconds')
 st = time.time()
 vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
 et = time.time()
+elapsed_time = et - st
+print('Loading VAE took: ', elapsed_time, 'seconds')
 st = time.time()
 
+#pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("stablediffusionapi/albedobase-xl-v21",
+pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("frankjoshua/albedobaseXL_v21",
+                                                                 vae=vae,
+                                                                 controlnet=[identitynet, zoedepthnet],
+                                                                 torch_dtype=torch.float16)
 pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
 pipe.load_ip_adapter_instantid(face_adapter)
+pipe.set_ip_adapter_scale(0.8)
 et = time.time()
+elapsed_time = et - st
+print('Loading pipeline took: ', elapsed_time, 'seconds')
 st = time.time()
+compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2] , text_encoder=[pipe.text_encoder, pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
 et = time.time()
+elapsed_time = et - st
+print('Loading Compel took: ', elapsed_time, 'seconds')
 
 st = time.time()
 zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
 et = time.time()
+elapsed_time = et - st
+print('Loading Zoe took: ', elapsed_time, 'seconds')
 zoe.to(device)
 pipe.to(device)
 
 last_lora = ""
 last_fused = False
+js = '''
+var button = document.getElementById('button');
+// Add a click event listener to the button
+button.addEventListener('click', function() {
+    element.classList.add('selected');
+});
+'''
 lora_archive = "/data"
 
 def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
     lora_repo = sdxl_loras[selected_state.index]["repo"]
     new_placeholder = "Type a prompt to use your selected LoRA"

 
     for lora_list in lora_defaults:
         if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
+            face_strength = lora_list.get("face_strength", 0.85)
+            image_strength = lora_list.get("image_strength", 0.15)
+            weight = lora_list.get("weight", 0.9)
             depth_control_scale = lora_list.get("depth_control_scale", 0.8)
             negative = lora_list.get("negative", "")
 

         selected_state
     )
 
+def center_crop_image_as_square(img):
+    square_size = min(img.size)
+
+    left = (img.width - square_size) / 2
+    top = (img.height - square_size) / 2
+    right = (img.width + square_size) / 2
+    bottom = (img.height + square_size) / 2
+
+    img_cropped = img.crop((left, top, right, bottom))
+    return img_cropped
+
 def check_selected(selected_state, custom_lora):
     if not selected_state and not custom_lora:
         raise gr.Error("You must select a style")
 
+def merge_incompatible_lora(full_path_lora, lora_scale):
+    for weights_file in [full_path_lora]:
+        if ";" in weights_file:
+            weights_file, multiplier = weights_file.split(";")
+            multiplier = float(multiplier)
+        else:
+            multiplier = lora_scale
+
+        lora_model, weights_sd = lora.create_network_from_weights(
+            multiplier,
+            full_path_lora,
+            pipe.vae,
+            pipe.text_encoder,
+            pipe.unet,
+            for_inference=True,
+        )
+        lora_model.merge_to(
+            pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
+        )
+        del weights_sd
+        del lora_model
 
+@spaces.GPU(duration=100)
+def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
+    print(loaded_state_dict)
+    et = time.time()
+    elapsed_time = et - st
+    print('Getting into the decorated function took: ', elapsed_time, 'seconds')
+    global last_fused, last_lora
+    print("Last LoRA: ", last_lora)
+    print("Current LoRA: ", repo_name)
+    print("Last fused: ", last_fused)
+    #prepare face zoe
     st = time.time()
+    with torch.no_grad():
+        image_zoe = zoe(face_image)
+    width, height = face_kps.size
+    images = [face_kps, image_zoe.resize((height, width))]
+    et = time.time()
+    elapsed_time = et - st
+    print('Zoe Depth calculations took: ', elapsed_time, 'seconds')
+    if last_lora != repo_name:
+        if(last_fused):
+            st = time.time()
+            pipe.unfuse_lora()
+            pipe.unload_lora_weights()
+            pipe.unload_textual_inversion()
+            et = time.time()
+            elapsed_time = et - st
+            print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
+        st = time.time()
+        pipe.load_lora_weights(loaded_state_dict)
+        pipe.fuse_lora(lora_scale)
+        et = time.time()
+        elapsed_time = et - st
+        print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
+        last_fused = True
+        is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
+        if(is_pivotal):
+            #Add the textual inversion embeddings from pivotal tuning models
+            text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
+            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
+            state_dict_embedding = load_file(embedding_path)
+            pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+            pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+
+    print("Processing prompt...")
+    st = time.time()
+    conditioning, pooled = compel(prompt)
+    if(negative):
+        negative_conditioning, negative_pooled = compel(negative)
     else:
+        negative_conditioning, negative_pooled = None, None
+    et = time.time()
+    elapsed_time = et - st
+    print('Prompt processing took: ', elapsed_time, 'seconds')
+    print("Processing image...")
+    st = time.time()
+    image = pipe(
+        prompt_embeds=conditioning,
+        pooled_prompt_embeds=pooled,
+        negative_prompt_embeds=negative_conditioning,
+        negative_pooled_prompt_embeds=negative_pooled,
+        width=1024,
+        height=1024,
+        image_embeds=face_emb,
+        image=face_image,
+        strength=1-image_strength,
+        control_image=images,
+        num_inference_steps=36,
+        guidance_scale = guidance_scale,
+        controlnet_conditioning_scale=[face_strength, depth_control_scale],
+    ).images[0]
+    et = time.time()
+    elapsed_time = et - st
+    print('Image processing took: ', elapsed_time, 'seconds')
+    last_lora = repo_name
+    return image
+
+def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, custom_lora, progress=gr.Progress(track_tqdm=True)):
+    print("Custom LoRA: ", custom_lora)
+    custom_lora_path = custom_lora[0] if custom_lora else None
+    selected_state_index = selected_state.index if selected_state else -1
     st = time.time()
+    face_image = center_crop_image_as_square(face_image)
+    try:
+        face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
+        face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1]  # only use the maximum face
+        face_emb = face_info['embedding']
+        face_kps = draw_kps(face_image, face_info['kps'])
+    except:
+        raise gr.Error("No face found in your image. Only face images work here. Try again")
     et = time.time()
+    elapsed_time = et - st
+    print('Cropping and calculating face embeds took: ', elapsed_time, 'seconds')
 
     st = time.time()
 
+    if(custom_lora_path and custom_lora[1]):
         prompt = f"{prompt} {custom_lora[1]}"
     else:
         for lora_list in lora_defaults:
             if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
                 prompt_full = lora_list.get("prompt", None)
+                if(prompt_full):
                     prompt = prompt_full.replace("<subject>", prompt)
 
+    print("Prompt:", prompt)
+    if(prompt == ""):
+        prompt = "a person"
     print(f"Executing prompt: {prompt}")
+    #print("Selected State: ", selected_state_index)
+    #print(sdxl_loras[selected_state_index]["repo"])
     if negative == "":
+        negative = None
+    print("Custom Loaded LoRA: ", custom_lora_path)
     if not selected_state and not custom_lora_path:
         raise gr.Error("You must select a style")
     elif custom_lora_path:

         full_path_lora = custom_lora_path
     else:
         repo_name = sdxl_loras[selected_state_index]["repo"]
+        weight_name = sdxl_loras[selected_state_index]["weights"]
         full_path_lora = state_dicts[repo_name]["saved_name"]
+    print("Full path LoRA ", full_path_lora)
+    #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
+    cross_attention_kwargs = None
     et = time.time()
+    elapsed_time = et - st
+    print('Small content processing took: ', elapsed_time, 'seconds')
 
     st = time.time()
+    image = generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, full_path_lora, lora_scale, sdxl_loras, selected_state_index, st)
     return (face_image, image), gr.update(visible=True)
 
 run_lora.zerogpu = True
 
 def shuffle_gallery(sdxl_loras):
     random.shuffle(sdxl_loras)
     return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras

     return classify_gallery(sdxl_loras)
 
 def deselect():
+    return gr.Gallery(selected_index=None)
 
 def get_huggingface_safetensors(link):
+    split_link = link.split("/")
+    if(len(split_link) == 2):
+        model_card = ModelCard.load(link)
+        image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+        trigger_word = model_card.data.get("instance_prompt", "")
+        image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
+        fs = HfFileSystem()
+        try:
+            list_of_files = fs.ls(link, detail=False)
+            for file in list_of_files:
+                if(file.endswith(".safetensors")):
+                    safetensors_name = file.replace("/", "_")
+                    if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
+                        fs.get_file(file, lpath=f"{lora_archive}/{safetensors_name}")
+                if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
+                    image_elements = file.split("/")
+                    image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
+        except:
+            gr.Warning("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
+            raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
+        return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
 
 def get_civitai_safetensors(link):
+    link_split = link.split("civitai.com/")
+    pattern = re.compile(r'models\/(\d+)')
+    regex_match = pattern.search(link_split[1])
+    if(regex_match):
+        civitai_model_id = regex_match.group(1)
+    else:
+        gr.Warning("No CivitAI model id found in your URL")
+        raise Exception("No CivitAI model id found in your URL")
+    model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}?token={os.getenv('CIVITAI_TOKEN')}"
+    x = requests.get(model_request_url)
+    if(x.status_code != 200):
+        raise Exception("Invalid CivitAI URL")
+    model_data = x.json()
+    #if(model_data["nsfw"] == True or model_data["nsfwLevel"] > 20):
+    #    gr.Warning("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
+    #    raise Exception("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
+    if(model_data["type"] != "LORA"):
+        gr.Warning("The model isn't tagged at CivitAI as a LoRA")
+        raise Exception("The model isn't tagged at CivitAI as a LoRA")
+    model_link_download = None
+    image_url = None
+    trigger_word = ""
+    for model in model_data["modelVersions"]:
+        if(model["baseModel"] == "SDXL 1.0"):
+            model_link_download = f"{model['downloadUrl']}/?token={os.getenv('CIVITAI_TOKEN')}"
+            safetensors_name = model["files"][0]["name"]
+            if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
+                safetensors_file_request = requests.get(model_link_download)
+                if(safetensors_file_request.status_code != 200):
+                    raise Exception("Invalid CivitAI download link")
+                with open(f"{lora_archive}/{safetensors_name}", 'wb') as file:
+                    file.write(safetensors_file_request.content)
+            trigger_word = model.get("trainedWords", [""])[0]
+            for image in model["images"]:
+                if(image["nsfwLevel"] == 1):
+                    image_url = image["url"]
+                    break
+            break
+    if(not model_link_download):
+        gr.Warning("We couldn't find a SDXL LoRA on the model you've sent")
+        raise Exception("We couldn't find a SDXL LoRA on the model you've sent")
+    return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
+
 def check_custom_model(link):
     if(link.startswith("https://")):
         if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):

     else:
         return get_huggingface_safetensors(link)
 
+def show_loading_widget():
+    return gr.update(visible=True)
+
 def load_custom_lora(link):
     if(link):
         try:

 
 def remove_custom_lora():
     return "", gr.update(visible=False), gr.update(visible=False), None
 with gr.Blocks(css="custom.css") as demo:
     gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
     title = gr.HTML(
         """<h1><img src="https://i.imgur.com/DVoGw04.png">
+        <span>Face to All<br><small style="
         font-size: 13px;
         display: block;
         font-weight: normal;
         opacity: 0.75;
+        ">🧨 diffusers InstantID + ControlNet<br> inspired by fofr's <a href="https://github.com/fofr/cog-face-to-many" target="_blank">face-to-many</a></small></span></h1>""",
         elem_id="title",
     )
    selected_state = gr.State()
    custom_loaded_lora = gr.State()
    with gr.Row(elem_id="main_app"):
        with gr.Column(scale=4, elem_id="box_column"):
            with gr.Group(elem_id="gallery_box"):
+                photo = gr.Image(label="Upload a picture of yourself", interactive=True, type="pil", height=300)
+                selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected", )
+                #order_gallery = gr.Radio(choices=["random", "likes"], value="random", label="Order by", elem_id="order_radio")
+                #new_gallery = gr.Gallery(
+                #    label="New LoRAs",
+                #    elem_id="gallery_new",
+                #    columns=3,
+                #    value=[(item["image"], item["title"]) for item in sdxl_loras_raw_new], allow_preview=False, show_share_button=False)
                gallery = gr.Gallery(
+                    #value=[(item["image"], item["title"]) for item in sdxl_loras],
                    label="Pick a style from the gallery",
                    allow_preview=False,
                    columns=4,

                custom_model = gr.Textbox(label="or enter a custom Hugging Face or CivitAI SDXL LoRA", placeholder="Paste Hugging Face or CivitAI model path...")
                custom_model_card = gr.HTML(visible=False)
                custom_model_button = gr.Button("Remove custom LoRA", visible=False)
        with gr.Column(scale=5):
            with gr.Row():
+                prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
                button = gr.Button("Run", elem_id="run_button")
            result = ImageSlider(
                interactive=False, label="Generated Image", elem_id="result-image", position=0.1
            )
            with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")
            with gr.Accordion("Advanced options", open=False):
                negative = gr.Textbox(label="Negative Prompt")
                weight = gr.Slider(0, 10, value=0.9, step=0.1, label="LoRA weight")
+                face_strength = gr.Slider(0, 2, value=0.85, step=0.01, label="Face strength", info="Higher values increase the face likeness but reduce the creative liberty of the models")
+                image_strength = gr.Slider(0, 1, value=0.15, step=0.01, label="Image strength", info="Higher values increase the similarity with the structure/colors of the original photo")
+                guidance_scale = gr.Slider(0, 50, value=7, step=0.1, label="Guidance Scale")
+                depth_control_scale = gr.Slider(0, 1, value=0.8, step=0.01, label="Zoe Depth ControlNet strenght")
            prompt_title = gr.Markdown(
                value="### Click on a LoRA in the gallery to select it",
                visible=True,
                elem_id="selected_lora",
            )
+    #order_gallery.change(
+    #    fn=swap_gallery,
+    #    inputs=[order_gallery, gr_sdxl_loras],
+    #    outputs=[gallery, gr_sdxl_loras],
+    #    queue=False
+    #)
    custom_model.input(
        fn=load_custom_lora,
        inputs=[custom_model],
        outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title],
    )
    custom_model_button.click(
        fn=remove_custom_lora,
        outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora]
    )
    gallery.select(
        fn=update_selection,
        inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
        outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
        show_progress=False
    )
+    #new_gallery.select(
+    #    fn=update_selection,
+    #    inputs=[gr_sdxl_loras_new, gr.State(True)],
+    #    outputs=[prompt_title, prompt, prompt, selected_state, gallery],
+    #    queue=False,
+    #    show_progress=False
+    #)
    prompt.submit(
        fn=check_selected,
        inputs=[selected_state, custom_loaded_lora],
        show_progress=False
    ).success(
        fn=run_lora,
+        inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
        outputs=[result, share_group],
    )
    button.click(
        fn=check_selected,
        inputs=[selected_state, custom_loaded_lora],
        show_progress=False
    ).success(
        fn=run_lora,
+        inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
        outputs=[result, share_group],
    )
    share_button.click(None, [], [], js=share_js)
+    demo.load(fn=classify_gallery, inputs=[gr_sdxl_loras], outputs=[gallery, gr_sdxl_loras], js=js)
 
 demo.queue(default_concurrency_limit=None, api_open=True)
 demo.launch(share=True)