openfree commited on
Commit
8799b68
·
verified ·
1 Parent(s): 633485a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -24
app.py CHANGED
@@ -364,8 +364,12 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
364
  ).images[0]
365
  return final_image
366
 
367
- def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
 
 
368
  try:
 
 
369
  # 한글 감지 및 번역 (이 부분은 그대로 유지)
370
  if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
371
  translated = translator(prompt, max_length=512)[0]['translation_text']
@@ -402,26 +406,35 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
402
  lora_names = []
403
  lora_weights = []
404
  with calculateDuration("Loading LoRA weights"):
 
 
405
  for idx, lora in enumerate(selected_loras):
406
  try:
407
  lora_name = f"lora_{idx}"
408
  lora_path = lora['repo']
 
 
 
 
 
409
  weight_name = lora.get("weights")
410
  print(f"Loading LoRA {lora_name} from {lora_path}")
 
411
  if image_input is not None:
412
  if weight_name:
413
- pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=lora_name)
 
414
  else:
415
- pipe_i2i.load_lora_weights(lora_path, adapter_name=lora_name)
 
416
  else:
417
  if weight_name:
418
- pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=lora_name)
 
419
  else:
420
- pipe.load_lora_weights(lora_path, adapter_name=lora_name)
421
- lora_names.append(lora_name)
422
- lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2 if idx == 1 else lora_scale_3)
423
- except Exception as e:
424
- print(f"Failed to load LoRA {lora_name}: {str(e)}")
425
 
426
  print("Loaded LoRAs:", lora_names)
427
  print("Adapter weights:", lora_weights)
@@ -525,44 +538,107 @@ def update_history(new_image, history):
525
 
526
  def refresh_models(huggingface_token):
527
  try:
528
- # HuggingFace API를 통해 사용자의 모델 검색
529
- headers = {"Authorization": f"Bearer {huggingface_token}"}
530
- api_url = "https://huggingface.co/api/models"
531
- params = {
532
- "author": huggingface_token, # 사용자의 모델만 검색
533
- "filter": "base_model:black-forest-labs/FLUX.1-dev" # FLUX 모델만 필터링
534
  }
535
 
536
- response = requests.get(api_url, headers=headers, params=params)
 
 
 
 
 
 
 
 
 
 
537
  if response.status_code != 200:
538
  raise Exception("Failed to fetch models from HuggingFace")
539
 
540
- user_models = response.json()
 
 
 
 
 
 
541
 
542
  # 새로운 모델 정보 생성
543
  new_models = []
544
  for model in user_models:
545
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
  model_info = {
547
- "image": f"https://huggingface.co/{model['id']}/resolve/main/preview.png",
548
- "title": model['id'].split('/')[-1],
549
- "repo": model['id'],
550
  "weights": "pytorch_lora_weights.safetensors",
551
- "trigger_word": "" # 필요한 경우 모델 카드에서 추출
 
552
  }
553
  new_models.append(model_info)
 
554
  except Exception as e:
555
  print(f"Error processing model {model['id']}: {str(e)}")
556
  continue
557
-
558
- # 기존 LoRA 목록과 병합
559
- updated_loras = new_models + loras
560
 
561
  return updated_loras
562
  except Exception as e:
563
  print(f"Error refreshing models: {str(e)}")
564
  return loras
565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  custom_theme = gr.themes.Base(
567
  primary_hue="blue",
568
  secondary_hue="purple",
 
364
  ).images[0]
365
  return final_image
366
 
367
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices,
368
+ lora_scale_1, lora_scale_2, lora_scale_3, randomize_seed, seed,
369
+ width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
370
  try:
371
+
372
+
373
  # 한글 감지 및 번역 (이 부분은 그대로 유지)
374
  if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
375
  translated = translator(prompt, max_length=512)[0]['translation_text']
 
406
  lora_names = []
407
  lora_weights = []
408
  with calculateDuration("Loading LoRA weights"):
409
+
410
+ # LoRA 로딩 부분 수정
411
  for idx, lora in enumerate(selected_loras):
412
  try:
413
  lora_name = f"lora_{idx}"
414
  lora_path = lora['repo']
415
+
416
+ # Private 모델인 경우 특별 처리
417
+ if lora.get('private', False):
418
+ lora_path = load_private_model(lora_path, huggingface_token)
419
+
420
  weight_name = lora.get("weights")
421
  print(f"Loading LoRA {lora_name} from {lora_path}")
422
+
423
  if image_input is not None:
424
  if weight_name:
425
+ pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name,
426
+ adapter_name=lora_name, token=huggingface_token)
427
  else:
428
+ pipe_i2i.load_lora_weights(lora_path, adapter_name=lora_name,
429
+ token=huggingface_token)
430
  else:
431
  if weight_name:
432
+ pipe.load_lora_weights(lora_path, weight_name=weight_name,
433
+ adapter_name=lora_name, token=huggingface_token)
434
  else:
435
+ pipe.load_lora_weights(lora_path, adapter_name=lora_name,
436
+ token=huggingface_token)
437
+
 
 
438
 
439
  print("Loaded LoRAs:", lora_names)
440
  print("Adapter weights:", lora_weights)
 
538
 
539
  def refresh_models(huggingface_token):
540
  try:
541
+ # HuggingFace API를 통해 사용자의 모델 검색 (private 포함)
542
+ headers = {
543
+ "Authorization": f"Bearer {huggingface_token}",
544
+ "Accept": "application/json"
 
 
545
  }
546
 
547
+ # 사용자 이름 가져오기
548
+ user_info_url = "https://huggingface.co/api/whoami"
549
+ user_response = requests.get(user_info_url, headers=headers)
550
+ if user_response.status_code != 200:
551
+ raise Exception("Failed to get user information")
552
+
553
+ username = user_response.json().get('name')
554
+
555
+ # 사용자의 모든 모델 검색 (private 포함)
556
+ api_url = f"https://huggingface.co/api/models?author={username}"
557
+ response = requests.get(api_url, headers=headers)
558
  if response.status_code != 200:
559
  raise Exception("Failed to fetch models from HuggingFace")
560
 
561
+ all_models = response.json()
562
+
563
+ # FLUX 기반 모델 필터링
564
+ user_models = [
565
+ model for model in all_models
566
+ if model.get('tags') and 'flux' in [tag.lower() for tag in model.get('tags', [])]
567
+ ]
568
 
569
  # 새로운 모델 정보 생성
570
  new_models = []
571
  for model in user_models:
572
  try:
573
+ # 모델 카드 정보 가져오기
574
+ model_id = model['id']
575
+ model_card_url = f"https://huggingface.co/api/models/{model_id}"
576
+ model_info_response = requests.get(model_card_url, headers=headers)
577
+ model_info = model_info_response.json()
578
+
579
+ # 프리뷰 이미지 URL 구성
580
+ preview_images = [
581
+ f"https://huggingface.co/{model_id}/resolve/main/preview.png",
582
+ f"https://huggingface.co/{model_id}/resolve/main/sample.png",
583
+ f"https://huggingface.co/{model_id}/resolve/main/example.png"
584
+ ]
585
+
586
+ # 이미지 존재 확인
587
+ image_url = None
588
+ for preview_url in preview_images:
589
+ img_response = requests.head(preview_url, headers=headers)
590
+ if img_response.status_code == 200:
591
+ image_url = preview_url
592
+ break
593
+
594
+ if not image_url:
595
+ image_url = "path/to/default/image.png" # 기본 이미지
596
+
597
+ # 트리거 워드 추출 시도
598
+ trigger_word = ""
599
+ if 'instance_prompt' in model_info:
600
+ trigger_word = model_info['instance_prompt']
601
+
602
  model_info = {
603
+ "image": image_url,
604
+ "title": f"[Private] {model_id.split('/')[-1]}" if model.get('private') else model_id.split('/')[-1],
605
+ "repo": model_id,
606
  "weights": "pytorch_lora_weights.safetensors",
607
+ "trigger_word": trigger_word,
608
+ "private": model.get('private', False)
609
  }
610
  new_models.append(model_info)
611
+
612
  except Exception as e:
613
  print(f"Error processing model {model['id']}: {str(e)}")
614
  continue
615
+
616
+ # 사용자의 모델을 최상단에 배치
617
+ updated_loras = new_models + [lora for lora in loras if lora['repo'] not in [m['repo'] for m in new_models]]
618
 
619
  return updated_loras
620
  except Exception as e:
621
  print(f"Error refreshing models: {str(e)}")
622
  return loras
623
 
624
+ def load_private_model(model_id, huggingface_token):
625
+ """Private 모델을 로드하는 함수"""
626
+ try:
627
+ headers = {"Authorization": f"Bearer {huggingface_token}"}
628
+
629
+ # 모델 다운로드
630
+ local_dir = snapshot_download(
631
+ repo_id=model_id,
632
+ token=huggingface_token,
633
+ local_dir=f"models/{model_id.replace('/', '_')}",
634
+ local_dir_use_symlinks=False
635
+ )
636
+
637
+ return local_dir
638
+ except Exception as e:
639
+ print(f"Error loading private model {model_id}: {str(e)}")
640
+ raise e
641
+
642
  custom_theme = gr.themes.Base(
643
  primary_hue="blue",
644
  secondary_hue="purple",