Update app.py
app.py
CHANGED
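This commit widens the `run_lora` signature to take the per-LoRA scales, seed controls, output size, and a Gradio progress tracker explicitly; threads the Hugging Face token through every `load_lora_weights` call so private adapters resolve; rewrites `refresh_models` to discover the user's FLUX models (public and private) through the Hub API; and adds a `load_private_model` helper that snapshots private repos locally.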
@@ -364,8 +364,12 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
     ).images[0]
     return final_image
 
-def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices,
+def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices,
+             lora_scale_1, lora_scale_2, lora_scale_3, randomize_seed, seed,
+             width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
     try:
+
+
         # Detect Korean and translate (this part is unchanged)
         if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
             translated = translator(prompt, max_length=512)[0]['translation_text']
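The Blocks wiring for this handler sits outside the diff; below is a minimal sketch of how a function with this signature is typically bound in Gradio, with component names that are illustrative only, not taken from app.py. Note that the `progress=gr.Progress(track_tqdm=True)` default is detected and injected by Gradio itself and is not listed among the inputs.

# Hypothetical wiring; every component name here is illustrative.
generate_btn.click(
    run_lora,
    inputs=[prompt, image_input, image_strength, cfg_scale, steps,
            selected_indices, lora_scale_1, lora_scale_2, lora_scale_3,
            randomize_seed, seed, width, height, loras_state],
    outputs=[result_image, seed],
)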
@@ -402,26 +406,35 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices,
     lora_names = []
     lora_weights = []
     with calculateDuration("Loading LoRA weights"):
+
+        # Revised LoRA loading section
         for idx, lora in enumerate(selected_loras):
             try:
                 lora_name = f"lora_{idx}"
                 lora_path = lora['repo']
+
+                # Special handling for private models
+                if lora.get('private', False):
+                    lora_path = load_private_model(lora_path, huggingface_token)
+
                 weight_name = lora.get("weights")
                 print(f"Loading LoRA {lora_name} from {lora_path}")
+
                 if image_input is not None:
                     if weight_name:
-                        pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name,
+                        pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name,
+                                                   adapter_name=lora_name, token=huggingface_token)
                     else:
-                        pipe_i2i.load_lora_weights(lora_path, adapter_name=lora_name
+                        pipe_i2i.load_lora_weights(lora_path, adapter_name=lora_name,
+                                                   token=huggingface_token)
                 else:
                     if weight_name:
-                        pipe.load_lora_weights(lora_path, weight_name=weight_name,
+                        pipe.load_lora_weights(lora_path, weight_name=weight_name,
+                                               adapter_name=lora_name, token=huggingface_token)
                     else:
-                        pipe.load_lora_weights(lora_path, adapter_name=lora_name
+                        pipe.load_lora_weights(lora_path, adapter_name=lora_name,
+                                               token=huggingface_token)
-
-
             except Exception as e:
                 print(f"Failed to load LoRA {lora_name}: {str(e)}")
 
     print("Loaded LoRAs:", lora_names)
     print("Adapter weights:", lora_weights)
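The `token=huggingface_token` keyword added to each `load_lora_weights` call is forwarded by diffusers to the Hub download, which is what lets private or gated adapter weights resolve. The hunk also prints `lora_names` and `lora_weights` afterwards; adapters loaded under an `adapter_name` are normally activated together with `set_adapters`. A minimal sketch of that step, assuming both lists are populated inside the loop (those lines fall outside this hunk):

# Hedged sketch: activate every adapter loaded above with its per-adapter scale.
target_pipe = pipe_i2i if image_input is not None else pipe
target_pipe.set_adapters(lora_names, adapter_weights=lora_weights)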
@@ -525,44 +538,107 @@ def update_history(new_image, history):
 
 def refresh_models(huggingface_token):
     try:
-        # Search the user's models via the HuggingFace API
-        headers = {
-
-
-            "author": huggingface_token,  # search only the user's models
-            "filter": "base_model:black-forest-labs/FLUX.1-dev"  # filter FLUX models only
+        # Search the user's models via the HuggingFace API (including private)
+        headers = {
+            "Authorization": f"Bearer {huggingface_token}",
+            "Accept": "application/json"
         }
 
-
+        # Get the username
+        user_info_url = "https://huggingface.co/api/whoami"
+        user_response = requests.get(user_info_url, headers=headers)
+        if user_response.status_code != 200:
+            raise Exception("Failed to get user information")
+
+        username = user_response.json().get('name')
+
+        # Fetch all of the user's models (including private)
+        api_url = f"https://huggingface.co/api/models?author={username}"
+        response = requests.get(api_url, headers=headers)
         if response.status_code != 200:
             raise Exception("Failed to fetch models from HuggingFace")
 
-
+        all_models = response.json()
+
+        # Filter to FLUX-based models
+        user_models = [
+            model for model in all_models
+            if model.get('tags') and 'flux' in [tag.lower() for tag in model.get('tags', [])]
+        ]
 
         # Build the new model entries
        new_models = []
         for model in user_models:
             try:
+                # Fetch the model card info
+                model_id = model['id']
+                model_card_url = f"https://huggingface.co/api/models/{model_id}"
+                model_info_response = requests.get(model_card_url, headers=headers)
+                model_info = model_info_response.json()
+
+                # Build candidate preview image URLs
+                preview_images = [
+                    f"https://huggingface.co/{model_id}/resolve/main/preview.png",
+                    f"https://huggingface.co/{model_id}/resolve/main/sample.png",
+                    f"https://huggingface.co/{model_id}/resolve/main/example.png"
+                ]
+
+                # Check which preview image actually exists
+                image_url = None
+                for preview_url in preview_images:
+                    img_response = requests.head(preview_url, headers=headers)
+                    if img_response.status_code == 200:
+                        image_url = preview_url
+                        break
+
+                if not image_url:
+                    image_url = "path/to/default/image.png"  # default image
+
+                # Try to extract the trigger word
+                trigger_word = ""
+                if 'instance_prompt' in model_info:
+                    trigger_word = model_info['instance_prompt']
+
                 model_info = {
-                    "image":
-                    "title":
-                    "repo":
+                    "image": image_url,
+                    "title": f"[Private] {model_id.split('/')[-1]}" if model.get('private') else model_id.split('/')[-1],
+                    "repo": model_id,
                     "weights": "pytorch_lora_weights.safetensors",
-                    "trigger_word":
+                    "trigger_word": trigger_word,
+                    "private": model.get('private', False)
                 }
                 new_models.append(model_info)
+
             except Exception as e:
                 print(f"Error processing model {model['id']}: {str(e)}")
                 continue
-
-        #
-        updated_loras = new_models + loras
+
+        # Put the user's models at the top of the list
+        updated_loras = new_models + [lora for lora in loras if lora['repo'] not in [m['repo'] for m in new_models]]
 
         return updated_loras
     except Exception as e:
         print(f"Error refreshing models: {str(e)}")
         return loras
 
+def load_private_model(model_id, huggingface_token):
+    """Load a private model."""
+    try:
+        headers = {"Authorization": f"Bearer {huggingface_token}"}
+
+        # Download the model snapshot
+        local_dir = snapshot_download(
+            repo_id=model_id,
+            token=huggingface_token,
+            local_dir=f"models/{model_id.replace('/', '_')}",
+            local_dir_use_symlinks=False
+        )
+
+        return local_dir
+    except Exception as e:
+        print(f"Error loading private model {model_id}: {str(e)}")
+        raise e
+
 custom_theme = gr.themes.Base(
     primary_hue="blue",
     secondary_hue="purple",
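The raw `requests` calls against `/api/whoami` and `/api/models` can also be written with `huggingface_hub`, which handles authentication and pagination. A minimal equivalent sketch under that assumption (this is not what app.py does, just the same lookup expressed through the client library):

from huggingface_hub import HfApi

# Hypothetical equivalent of the whoami + author search above.
api = HfApi(token=huggingface_token)
username = api.whoami()["name"]
flux_models = [
    m for m in api.list_models(author=username)
    if any("flux" in t.lower() for t in (m.tags or []))
]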
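`load_private_model` depends on `snapshot_download` from `huggingface_hub`, so that import must already exist near the top of app.py (it is not part of this diff); note also that the `headers` dict it builds is unused, since `snapshot_download` takes the token directly. A minimal usage sketch with a placeholder repo id:

from huggingface_hub import snapshot_download

# Hypothetical call; "user/private-flux-lora" is a placeholder repo id.
local_dir = load_private_model("user/private-flux-lora", huggingface_token)
pipe.load_lora_weights(local_dir, adapter_name="lora_0")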
|