primerz committed
Commit 757dea6 · verified · 1 Parent(s): eb4d4f1

Update app.py

Files changed (1)
app.py +219 -188
app.py CHANGED
@@ -1,9 +1,12 @@
  import gradio as gr
  import torch
- import spaces
+ import spaces  # Make sure this is imported
+ import time
+ from typing import Optional, List
+ import numpy as np
+ from PIL import Image
  torch.jit.script = lambda f: f
  import timm
- import time
 
  from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
  from safetensors.torch import load_file
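Note on import order: `torch.jit.script` is stubbed out with a no-op lambda *before* `import timm`, so any `@torch.jit.script`-decorated helpers in downstream model code (timm, the ZoeDepth stack) are left as plain Python functions instead of being compiled at import time. A minimal sketch of the mechanism, where `double` is an illustrative function, not one from this Space:

```python
import torch

# Must run BEFORE importing libraries that script functions at import time.
torch.jit.script = lambda f: f

# With the stub in place, the decorator returns the function unchanged.
@torch.jit.script
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

print(double(torch.ones(2)))  # tensor([2., 2.]) -- ordinary Python call, no TorchScript
```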
@@ -28,7 +31,6 @@ from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditio
  import cv2
  import torch
  import numpy as np
- from PIL import Image
 
  from insightface.app import FaceAnalysis
  from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
@@ -159,50 +161,6 @@ last_lora = ""
  last_fused = False
  lora_archive = "/data"
 
- # Enhanced face detection with better face quality filtering
- def detect_faces(face_image, use_multiple_faces=False):
-     """
-     Detect faces in the image with quality filtering
-     Returns: list of face info dictionaries, or empty list if no faces
-     """
-     try:
-         face_info_list = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
-
-         if not face_info_list or len(face_info_list) == 0:
-             print("No faces detected")
-             return []
-
-         # Filter faces by quality score if available
-         filtered_faces = []
-         for face_info in face_info_list:
-             # Check if face has minimum quality
-             if 'det_score' in face_info and face_info['det_score'] > 0.5:
-                 filtered_faces.append(face_info)
-             elif 'det_score' not in face_info:
-                 filtered_faces.append(face_info)
-
-         if not filtered_faces:
-             print("No high-quality faces detected")
-             return []
-
-         # Sort faces by size (largest first)
-         filtered_faces = sorted(
-             filtered_faces,
-             key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]),
-             reverse=True
-         )
-
-         if use_multiple_faces:
-             print(f"Detected {len(filtered_faces)} high-quality faces")
-             return filtered_faces
-         else:
-             print(f"Using largest face (detected {len(filtered_faces)} total)")
-             return [filtered_faces[0]]
-
-     except Exception as e:
-         print(f"Face detection error: {e}")
-         return []
-
  def process_face_embeddings_separately(face_info_list):
      """
      Process face embeddings separately for multi-face generation
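The moved `detect_faces` (re-added further down as a CPU helper) drops low-confidence detections and ranks the rest by bounding-box area. A standalone sketch of that filter-and-sort logic with dummy detections; the dicts stand in for InsightFace results, whose `bbox` is `[x1, y1, x2, y2]`:

```python
# Minimal sketch of the ranking detect_faces relies on: filter on det_score,
# then sort by bbox area so the largest face wins in single-face mode.
dummy_faces = [
    {"bbox": [10, 10, 50, 60], "det_score": 0.9},   # area 40 * 50 = 2000
    {"bbox": [0, 0, 120, 130], "det_score": 0.8},   # area 120 * 130 = 15600
    {"bbox": [30, 30, 40, 45], "det_score": 0.3},   # filtered out (<= 0.5)
]

# Faces without a det_score are kept, matching the elif branch above.
kept = [f for f in dummy_faces if f.get("det_score", 1.0) > 0.5]
kept.sort(
    key=lambda f: (f["bbox"][2] - f["bbox"][0]) * (f["bbox"][3] - f["bbox"][1]),
    reverse=True,
)
print([f["bbox"] for f in kept])  # largest first: [[0, 0, 120, 130], [10, 10, 50, 60]]
```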
@@ -300,42 +258,152 @@ def check_selected(selected_state, custom_lora):
  if not selected_state and not custom_lora:
      raise gr.Error("You must select a style")
 
- def resize_image_aspect_ratio(img, max_dim=1280):
-     width, height = img.size
-     aspect_ratio = width / height
+ def shuffle_gallery(sdxl_loras):
+     random.shuffle(sdxl_loras)
+     return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
 
-     if aspect_ratio >= 1:  # Landscape or square
-         new_width = min(max_dim, width)
-         new_height = int(new_width / aspect_ratio)
-     else:  # Portrait
-         new_height = min(max_dim, height)
-         new_width = int(new_height * aspect_ratio)
+ def classify_gallery(sdxl_loras):
+     sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
+     return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
 
-     new_width = (new_width // 8) * 8
-     new_height = (new_height // 8) * 8
+ def swap_gallery(order, sdxl_loras):
+     if(order == "random"):
+         return shuffle_gallery(sdxl_loras)
+     else:
+         return classify_gallery(sdxl_loras)
+
+ def deselect():
+     return gr.Gallery(selected_index=None)
 
-     return img.resize((new_width, new_height), Image.LANCZOS)
+ def get_huggingface_safetensors(link):
+     split_link = link.split("/")
+     if(len(split_link) == 2):
+         model_card = ModelCard.load(link)
+         image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+         trigger_word = model_card.data.get("instance_prompt", "")
+         image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
+         fs = HfFileSystem()
+         try:
+             list_of_files = fs.ls(link, detail=False)
+             for file in list_of_files:
+                 if(file.endswith(".safetensors")):
+                     safetensors_name = file.replace("/", "_")
+                     if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
+                         fs.get_file(file, lpath=f"{lora_archive}/{safetensors_name}")
+                 if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
+                     image_elements = file.split("/")
+                     image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
+         except:
+             gr.Warning("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
+             raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
+     return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
 
+ def get_civitai_safetensors(link):
+     link_split = link.split("civitai.com/")
+     pattern = re.compile(r'models\/(\d+)')
+     regex_match = pattern.search(link_split[1])
+     if(regex_match):
+         civitai_model_id = regex_match.group(1)
+     else:
+         gr.Warning("No CivitAI model id found in your URL")
+         raise Exception("No CivitAI model id found in your URL")
+     model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}?token={os.getenv('CIVITAI_TOKEN')}"
+     x = requests.get(model_request_url)
+     if(x.status_code != 200):
+         raise Exception("Invalid CivitAI URL")
+     model_data = x.json()
+
+     if(model_data["type"] != "LORA"):
+         gr.Warning("The model isn't tagged at CivitAI as a LoRA")
+         raise Exception("The model isn't tagged at CivitAI as a LoRA")
+
+     model_link_download = None
+     image_url = None
+     trigger_word = ""
+     for model in model_data["modelVersions"]:
+         if(model["baseModel"] == "SDXL 1.0"):
+             model_link_download = f"{model['downloadUrl']}/?token={os.getenv('CIVITAI_TOKEN')}"
+             safetensors_name = model["files"][0]["name"]
+             if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
+                 safetensors_file_request = requests.get(model_link_download)
+                 if(safetensors_file_request.status_code != 200):
+                     raise Exception("Invalid CivitAI download link")
+                 with open(f"{lora_archive}/{safetensors_name}", 'wb') as file:
+                     file.write(safetensors_file_request.content)
+             trigger_word = model.get("trainedWords", [""])[0]
+             for image in model["images"]:
+                 if(image["nsfwLevel"] == 1):
+                     image_url = image["url"]
+                     break
+             break
+
+     if(not model_link_download):
+         gr.Warning("We couldn't find a SDXL LoRA on the model you've sent")
+         raise Exception("We couldn't find a SDXL LoRA on the model you've sent")
+     return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
+
+ def check_custom_model(link):
+     if(link.startswith("https://")):
+         if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
+             link_split = link.split("huggingface.co/")
+             return get_huggingface_safetensors(link_split[1])
+         elif(link.startswith("https://civitai.com") or link.startswith("https://www.civitai.com")):
+             return get_civitai_safetensors(link)
+     else:
+         return get_huggingface_safetensors(link)
+
+ def load_custom_lora(link):
+     if(link):
+         try:
+             title, path, trigger_word, image = check_custom_model(link)
+             card = f'''
+             <div class="custom_lora_card">
+               <span>Loaded custom LoRA:</span>
+               <div class="card_internal">
+                 <img src="{image}" />
+                 <div>
+                   <h3>{title}</h3>
+                   <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
+                 </div>
+               </div>
+             </div>
+             '''
+             return gr.update(visible=True), card, gr.update(visible=True), [path, trigger_word], gr.Gallery(selected_index=None), f"Custom: {path}"
+         except Exception as e:
+             gr.Warning("Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content")
+             return gr.update(visible=True), "Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
+     else:
+         return gr.update(visible=False), "", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
+
+ def remove_custom_lora():
+     return "", gr.update(visible=False), gr.update(visible=False), None
 
+ @spaces.GPU(duration=120)
  def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength,
               guidance_scale, depth_control_scale, sdxl_loras, custom_lora, use_multiple_faces=False,
               progress=gr.Progress(track_tqdm=True)):
      """
      Enhanced run_lora with improved face preservation and landscape mode
+     FIXED: Proper ZeroGPU decorator, no nested GPU calls
      """
      print("Custom LoRA:", custom_lora)
      custom_lora_path = custom_lora[0] if custom_lora else None
      selected_state_index = selected_state.index if selected_state else -1
 
      st = time.time()
+
+     # Ensure models are on GPU
+     pipe.to(device)
+     zoe.to(device)
+
      face_image = resize_image_aspect_ratio(face_image)
 
-     # Enhanced face detection
+     # Enhanced face detection (CPU operation - InsightFace uses CPU)
      face_info_list = detect_faces(face_image, use_multiple_faces)
      face_detected = len(face_info_list) > 0
 
      if face_detected:
-         # CHANGED: Process faces separately instead of averaging
+         # Process faces separately instead of averaging
          face_embeddings = process_face_embeddings_separately(face_info_list)
          face_kps = create_face_kps_image(face_image, face_info_list)
          print(f"Processing with {len(face_info_list)} face(s) separately")
@@ -396,7 +464,7 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
      et = time.time()
      print('Prompt processing took:', et - st, 'seconds')
 
-     # IMPROVED: Better parameter adjustment for face/landscape modes
+     # Better parameter adjustment for face/landscape modes
      if not face_detected:
          # Enhanced landscape mode parameters
          face_strength = 0.0
@@ -411,22 +479,32 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
 
      st = time.time()
 
-     # Generate single image with best face (or landscape)
-     image = generate_image(
-         prompt, negative, face_emb, face_image, face_kps, image_strength,
-         guidance_scale, face_strength, depth_control_scale, repo_name,
-         full_path_lora, lora_scale, sdxl_loras, selected_state_index, face_detected, st
-     )
-
+     # FIXED: Call non-decorated version (inline generation)
+     try:
+         image = generate_image_inline(
+             prompt, negative, face_emb, face_image, face_kps, image_strength,
+             guidance_scale, face_strength, depth_control_scale, repo_name,
+             full_path_lora, lora_scale, sdxl_loras, selected_state_index, face_detected, st
+         )
+     except Exception as e:
+         print(f"Generation error: {e}")
+         torch.cuda.empty_cache()
+         raise gr.Error(f"Image generation failed: {str(e)}")
+
+     # Cleanup GPU memory
+     torch.cuda.empty_cache()
+
      return (face_image, image), gr.update(visible=True)
 
- run_lora.zerogpu = True
-
 
- @spaces.GPU(duration=90)  # Increased duration for better quality
- def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale,
-                    face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale,
-                    sdxl_loras, selected_state_index, face_detected, st):
+ # FIXED: Removed @spaces.GPU decorator - this runs within GPU context
+ def generate_image_inline(prompt, negative, face_emb, face_image, face_kps, image_strength,
+                           guidance_scale, face_strength, depth_control_scale, repo_name,
+                           loaded_state_dict, lora_scale, sdxl_loras, selected_state_index,
+                           face_detected, st):
+     """
+     FIXED: No decorator - called from within GPU context
+     """
      global last_fused, last_lora
 
      print("Loaded state dict:", loaded_state_dict)
@@ -502,6 +580,8 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
      num_inference_steps = 50  # Increased for better quality
 
      print("Generating image...")
+     print(f"GPU Memory before generation: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
+
      image = pipe(
          prompt_embeds=conditioning,
          pooled_prompt_embeds=pooled,
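`torch.cuda.memory_allocated()` reports the bytes currently held by live tensors on the active device, so dividing by `1024**3` gives GiB, as in the added print. A small helper in the same spirit; the `log_vram` name is ours, not from app.py:

```python
import torch

def log_vram(tag: str) -> None:
    # memory_allocated() counts live tensor bytes on the current device;
    # memory_reserved() also includes blocks held by the caching allocator.
    if torch.cuda.is_available():
        alloc = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"[{tag}] allocated: {alloc:.2f} GiB, reserved: {reserved:.2f} GiB")

log_vram("before generation")
```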
@@ -518,128 +598,79 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
          controlnet_conditioning_scale=control_scales,
      ).images[0]
 
+     print(f"GPU Memory after generation: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
+
      last_lora = repo_name
      return image
 
- def shuffle_gallery(sdxl_loras):
-     random.shuffle(sdxl_loras)
-     return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
 
- def classify_gallery(sdxl_loras):
-     sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
-     return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
-
- def swap_gallery(order, sdxl_loras):
-     if(order == "random"):
-         return shuffle_gallery(sdxl_loras)
-     else:
-         return classify_gallery(sdxl_loras)
 
- def deselect():
-     return gr.Gallery(selected_index=None)
+ # CPU-bound helper functions (no decorators needed)
+ def detect_faces(face_image, use_multiple_faces=False):
+     """
+     Detect faces in the image with quality filtering
+     CPU operation - no GPU decorator needed
+     """
+     try:
+         face_info_list = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
+
+         if not face_info_list or len(face_info_list) == 0:
+             print("No faces detected")
+             return []
+
+         # Filter faces by quality score if available
+         filtered_faces = []
+         for face_info in face_info_list:
+             # Check if face has minimum quality
+             if 'det_score' in face_info and face_info['det_score'] > 0.5:
+                 filtered_faces.append(face_info)
+             elif 'det_score' not in face_info:
+                 filtered_faces.append(face_info)
+
+         if not filtered_faces:
+             print("No high-quality faces detected")
+             return []
+
+         # Sort faces by size (largest first)
+         filtered_faces = sorted(
+             filtered_faces,
+             key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]),
+             reverse=True
+         )
+
+         if use_multiple_faces:
+             print(f"Detected {len(filtered_faces)} high-quality faces")
+             return filtered_faces
+         else:
+             print(f"Using largest face (detected {len(filtered_faces)} total)")
+             return [filtered_faces[0]]
+
+     except Exception as e:
+         print(f"Face detection error: {e}")
+         return []
 
- def get_huggingface_safetensors(link):
-     split_link = link.split("/")
-     if(len(split_link) == 2):
-         model_card = ModelCard.load(link)
-         image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
-         trigger_word = model_card.data.get("instance_prompt", "")
-         image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
-         fs = HfFileSystem()
-         try:
-             list_of_files = fs.ls(link, detail=False)
-             for file in list_of_files:
-                 if(file.endswith(".safetensors")):
-                     safetensors_name = file.replace("/", "_")
-                     if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
-                         fs.get_file(file, lpath=f"{lora_archive}/{safetensors_name}")
-                 if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
-                     image_elements = file.split("/")
-                     image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
-         except:
-             gr.Warning("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
-             raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
-     return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
+ def resize_image_aspect_ratio(img, max_dim=1280):
+     """CPU operation"""
+     width, height = img.size
+     aspect_ratio = width / height
 
- def get_civitai_safetensors(link):
-     link_split = link.split("civitai.com/")
-     pattern = re.compile(r'models\/(\d+)')
-     regex_match = pattern.search(link_split[1])
-     if(regex_match):
-         civitai_model_id = regex_match.group(1)
-     else:
-         gr.Warning("No CivitAI model id found in your URL")
-         raise Exception("No CivitAI model id found in your URL")
-     model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}?token={os.getenv('CIVITAI_TOKEN')}"
-     x = requests.get(model_request_url)
-     if(x.status_code != 200):
-         raise Exception("Invalid CivitAI URL")
-     model_data = x.json()
-
-     if(model_data["type"] != "LORA"):
-         gr.Warning("The model isn't tagged at CivitAI as a LoRA")
-         raise Exception("The model isn't tagged at CivitAI as a LoRA")
-
-     model_link_download = None
-     image_url = None
-     trigger_word = ""
-     for model in model_data["modelVersions"]:
-         if(model["baseModel"] == "SDXL 1.0"):
-             model_link_download = f"{model['downloadUrl']}/?token={os.getenv('CIVITAI_TOKEN')}"
-             safetensors_name = model["files"][0]["name"]
-             if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
-                 safetensors_file_request = requests.get(model_link_download)
-                 if(safetensors_file_request.status_code != 200):
-                     raise Exception("Invalid CivitAI download link")
-                 with open(f"{lora_archive}/{safetensors_name}", 'wb') as file:
-                     file.write(safetensors_file_request.content)
-             trigger_word = model.get("trainedWords", [""])[0]
-             for image in model["images"]:
-                 if(image["nsfwLevel"] == 1):
-                     image_url = image["url"]
-                     break
-             break
-
-     if(not model_link_download):
-         gr.Warning("We couldn't find a SDXL LoRA on the model you've sent")
-         raise Exception("We couldn't find a SDXL LoRA on the model you've sent")
-     return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
+     if aspect_ratio >= 1:  # Landscape or square
+         new_width = min(max_dim, width)
+         new_height = int(new_width / aspect_ratio)
+     else:  # Portrait
+         new_height = min(max_dim, height)
+         new_width = int(new_height * aspect_ratio)
 
- def check_custom_model(link):
-     if(link.startswith("https://")):
-         if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
-             link_split = link.split("huggingface.co/")
-             return get_huggingface_safetensors(link_split[1])
-         elif(link.startswith("https://civitai.com") or link.startswith("https://www.civitai.com")):
-             return get_civitai_safetensors(link)
-     else:
-         return get_huggingface_safetensors(link)
+     new_width = (new_width // 8) * 8
+     new_height = (new_height // 8) * 8
+
+     return img.resize((new_width, new_height), Image.LANCZOS)
 
- def load_custom_lora(link):
-     if(link):
-         try:
-             title, path, trigger_word, image = check_custom_model(link)
-             card = f'''
-             <div class="custom_lora_card">
-               <span>Loaded custom LoRA:</span>
-               <div class="card_internal">
-                 <img src="{image}" />
-                 <div>
-                   <h3>{title}</h3>
-                   <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
-                 </div>
-               </div>
-             </div>
-             '''
-             return gr.update(visible=True), card, gr.update(visible=True), [path, trigger_word], gr.Gallery(selected_index=None), f"Custom: {path}"
-         except Exception as e:
-             gr.Warning("Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content")
-             return gr.update(visible=True), "Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
-     else:
-         return gr.update(visible=False), "", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
+
+ def check_selected(selected_state, custom_lora):
+     """CPU operation"""
+     if not selected_state and not custom_lora:
+         raise gr.Error("You must select a style")
 
- def remove_custom_lora():
-     return "", gr.update(visible=False), gr.update(visible=False), None
 
  # Build Gradio interface
  with gr.Blocks(css="custom.css") as demo:
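As a quick check of the re-added `resize_image_aspect_ratio`: the long side is capped at `max_dim`, the other side follows the aspect ratio, and both are snapped down to multiples of 8 (the SDXL VAE downsamples by a factor of 8, so dimensions must divide evenly). A worked example of the same arithmetic:

```python
from PIL import Image

# 1920x1080 landscape: width capped at 1280, height follows the ratio.
img = Image.new("RGB", (1920, 1080))
w, h = img.size
ar = w / h                       # 1.777...
new_w = min(1280, w)             # 1280
new_h = int(new_w / ar)          # 720
new_w, new_h = (new_w // 8) * 8, (new_h // 8) * 8
print(new_w, new_h)              # 1280 720 (already multiples of 8)

# 1000x1500 portrait: height capped, width follows, then both snapped down.
w, h = 1000, 1500
ar = w / h
new_h = min(1280, h)             # 1280
new_w = int(new_h * ar)          # 853
print((new_w // 8) * 8, (new_h // 8) * 8)  # 848 1280
```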
@@ -755,7 +786,7 @@ with gr.Blocks(css="custom.css") as demo:
          inputs=[selected_state, custom_loaded_lora],
          show_progress=False
      ).success(
-         fn=run_lora,
+         fn=run_lora,  # This now has proper @spaces.GPU decorator
          inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength,
                  guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora, use_multiple_faces],
          outputs=[result, share_group],
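For reference, `.success()` in Gradio runs its handler only if the previous step in the chain finished without raising, which is how `check_selected` gates `run_lora` above. A minimal sketch of the same chaining, with illustrative component names:

```python
import gradio as gr

def validate(choice):
    # Raising gr.Error aborts the chain, so the following .success() never fires.
    if not choice:
        raise gr.Error("You must select a style")

def generate(choice):
    return f"generated with {choice}"

with gr.Blocks() as demo:
    style = gr.Dropdown(["comic", "photo"], label="Style")
    out = gr.Textbox()
    btn = gr.Button("Run")
    btn.click(fn=validate, inputs=[style]).success(fn=generate, inputs=[style], outputs=[out])
```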
 