"""FitGen: Gradio app for FLUX-based fashion generation and Leffa virtual try-on.

The module logs in to the Hugging Face Hub and downloads the Leffa checkpoints
at import time, then exposes a two-tab Gradio interface:

* "Fashion Generation" - text-to-image with ``BASE_MODEL`` plus a LoRA.
* "Virtual Try-on" - person + garment image fed through the Leffa model.
"""

import functools
import gc
import os
import random
from contextlib import contextmanager

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import login, snapshot_download
from PIL import Image
from transformers import pipeline

from leffa.inference import LeffaInference
from leffa.model import LeffaModel
from leffa.transform import LeffaTransform
from utils.densepose_predictor import DensePosePredictor
from utils.garment_agnostic_mask_predictor import AutoMasker
from utils.utils import resize_and_center

# Constants
MAX_SEED = 2**32 - 1
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"

# Lazily-populated caches used by the module-level predictor getters.
# They must exist before the getters run, otherwise `global x` followed by
# `if x is None` raises NameError on the first call.
mask_predictor = None
densepose_predictor = None


def safe_model_call(func):
    """Decorate *func* so garbage is collected before and after every call.

    Exceptions are logged with the wrapped function's name and re-raised
    unchanged after a final collection pass.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            clear_memory()
            result = func(*args, **kwargs)
            clear_memory()
            return result
        except Exception as e:
            clear_memory()
            print(f"Error in {func.__name__}: {str(e)}")
            raise

    return wrapper


@contextmanager
def torch_gc():
    """Context manager that frees Python and CUDA caches when the body exits.

    Cleanup runs in ``finally`` so it happens on both success and failure.
    """
    try:
        yield
    finally:
        gc.collect()
        if torch.cuda.is_available() and torch.cuda.current_device() >= 0:
            with torch.cuda.device('cuda'):
                torch.cuda.empty_cache()


def clear_memory() -> None:
    """Run a Python garbage-collection pass."""
    gc.collect()


def setup_environment() -> str:
    """Configure the CUDA allocator and log in to the Hugging Face Hub.

    Returns:
        The token read from the ``HF_TOKEN`` environment variable.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset or empty.
    """
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

    HF_TOKEN = os.getenv("HF_TOKEN")
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN not found in environment variables")
    login(token=HF_TOKEN)
    return HF_TOKEN


def contains_korean(text: str) -> bool:
    """Return True if *text* has any character in the Hangul range '가'..'힣'."""
    return any('가' <= char <= '힣' for char in text)


# Run environment setup (Hub login) before anything touches the Hub.
setup_environment()


def setup() -> None:
    """Download the Leffa checkpoints into ``./ckpts`` (download only)."""
    snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")


@spaces.GPU()
def get_translator():
    """Create a Korean-to-English translation pipeline on CUDA."""
    with torch_gc():
        return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda")


@safe_model_call
def get_densepose_predictor():
    """Return the module-level DensePose predictor, creating it on first use."""
    global densepose_predictor
    if densepose_predictor is None:
        densepose_predictor = DensePosePredictor(
            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path="./ckpts/densepose/model_final_162be9.pkl",
        )
    return densepose_predictor


@spaces.GPU()
def get_vt_model():
    """Build the virtual try-on model and its inference wrapper.

    Returns:
        A ``(model, inference)`` tuple: the half-precision ``LeffaModel``
        moved to CUDA, and a ``LeffaInference`` constructed from it.
    """
    with torch_gc():
        model = LeffaModel(
            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
            pretrained_model="./ckpts/virtual_tryon.pth"
        )
        model = model.half()
        return model.to("cuda"), LeffaInference(model=model)


def load_lora(pipe, lora_path):
    """Replace any loaded LoRA on *pipe* with the weights at *lora_path*.

    Both steps are best-effort: a failure to unload is ignored, and a failure
    to load is logged as a warning. *pipe* is returned in either case.
    """
    try:
        pipe.unload_lora_weights()
    except Exception:
        # Nothing to unload (or unloading is unsupported); carry on.
        pass

    try:
        pipe.load_lora_weights(lora_path)
        return pipe
    except Exception as e:
        print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
        return pipe


@spaces.GPU()
def get_mask_predictor():
    """Return the module-level garment mask predictor, creating it on first use."""
    global mask_predictor
    if mask_predictor is None:
        mask_predictor = AutoMasker(
            densepose_path="./ckpts/densepose",
            schp_path="./ckpts/schp",
        )
    return mask_predictor


@spaces.GPU()
def initialize_fashion_pipe():
    """Load ``BASE_MODEL`` as a float16 diffusion pipeline with CPU offload.

    Raises:
        Exception: Any loading error is logged and re-raised unchanged.
    """
    try:
        pipe = DiffusionPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False
        ).to("cuda")
        # NOTE(review): moving to CUDA and then enabling model CPU offload
        # looks redundant - confirm which placement strategy is intended.
        pipe.enable_model_cpu_offload()
        return pipe
    except Exception as e:
        print(f"Error initializing fashion pipe: {e}")
        raise


@spaces.GPU()
def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one fashion image with size and step limits applied.

    Args:
        prompt: Description text; translated to English if it contains Korean.
        mode: ``"Generate Model"`` selects the model LoRA; anything else
            selects the clothes LoRA.
        cfg_scale: Guidance scale passed to the pipeline.
        steps: Inference steps, capped at 30.
        randomize_seed: When true, *seed* is replaced by a random value.
        seed: Seed for the CUDA generator.
        width: Output width, capped at 768.
        height: Output height, capped at 768.
        lora_scale: LoRA strength passed via ``joint_attention_kwargs``.
        progress: Gradio progress tracker.

    Returns:
        ``(image, seed)`` - the first generated image and the seed used.

    Raises:
        gr.Error: Wraps any failure so the UI can display it.
    """
    try:
        # Korean prompt handling
        if contains_korean(prompt):
            with torch.inference_mode():
                translator = get_translator()
                translated = translator(prompt)[0]['translation_text']
                actual_prompt = translated
        else:
            actual_prompt = prompt

        # Pipeline initialisation
        pipe = initialize_fashion_pipe()

        # LoRA selection
        if mode == "Generate Model":
            pipe.load_lora_weights(MODEL_LORA_REPO)
            trigger_word = "fashion photography, professional model"
        else:
            pipe.load_lora_weights(CLOTHES_LORA_REPO)
            trigger_word = "upper clothing, fashion item"

        # Parameter limits
        width = min(width, 768)
        height = min(height, 768)
        steps = min(steps, 30)

        # Seed selection
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        seed = int(seed)  # manual_seed() needs an int; sliders may yield floats

        generator = torch.Generator("cuda").manual_seed(seed)

        # Image generation
        with torch.inference_mode():
            output = pipe(
                prompt=f"{actual_prompt} {trigger_word}",
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                # Same keyword generate_image uses for the LoRA scale.
                joint_attention_kwargs={"scale": lora_scale},
            )

        image = output.images[0]

        # Memory cleanup
        del pipe
        torch.cuda.empty_cache()
        gc.collect()

        return image, seed

    except Exception as e:
        print(f"Error in generate_fashion: {str(e)}")
        raise gr.Error(f"Generation failed: {str(e)}") from e


class ModelManager:
    """Holds lazily-created predictor and translator instances."""

    def __init__(self):
        # Each slot is filled on the first call to the matching getter.
        self.mask_predictor = None
        self.densepose_predictor = None
        self.translator = None

    @spaces.GPU()
    def get_mask_predictor(self):
        """Return the cached ``AutoMasker``, creating it on first use."""
        if self.mask_predictor is None:
            self.mask_predictor = AutoMasker(
                densepose_path="./ckpts/densepose",
                schp_path="./ckpts/schp",
            )
        return self.mask_predictor

    @spaces.GPU()
    def get_densepose_predictor(self):
        """Return the cached ``DensePosePredictor``, creating it on first use."""
        if self.densepose_predictor is None:
            self.densepose_predictor = DensePosePredictor(
                config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
                weights_path="./ckpts/densepose/model_final_162be9.pkl",
            )
        return self.densepose_predictor

    @spaces.GPU()
    def get_translator(self):
        """Return the cached Korean-to-English pipeline, creating it on first use."""
        if self.translator is None:
            self.translator = pipeline("translation",
                                       model="Helsinki-NLP/opus-mt-ko-en",
                                       device="cuda")
        return self.translator


# Shared model manager instance
model_manager = ModelManager()


@spaces.GPU()
def leffa_predict(src_image_path, ref_image_path, control_type):
    """Run Leffa virtual try-on for a person image and a garment image.

    Args:
        src_image_path: Path to the person image.
        ref_image_path: Path to the garment image.
        control_type: Accepted for interface compatibility; not read here.

    Returns:
        The first generated image as a NumPy array.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    try:
        with torch_gc():
            # Model initialisation
            model, inference = get_vt_model()

            # Image preparation
            src_image = Image.open(src_image_path)
            ref_image = Image.open(ref_image_path)
            src_image = resize_and_center(src_image, 768, 1024)
            ref_image = resize_and_center(ref_image, 768, 1024)

            # Captured before the RGB conversion below, as in the original flow.
            src_image_array = np.array(src_image)

            # Mask and DensePose
            with torch.inference_mode():
                src_image = src_image.convert("RGB")
                mask_pred = model_manager.get_mask_predictor()
                mask = mask_pred(src_image, "upper")["mask"]

                dense_pred = model_manager.get_densepose_predictor()
                src_image_seg_array = dense_pred.predict_seg(src_image_array)
                densepose = Image.fromarray(src_image_seg_array)

            # Leffa transform and inference
            transform = LeffaTransform()
            data = {
                "src_image": [src_image],
                "ref_image": [ref_image],
                "mask": [mask],
                "densepose": [densepose],
            }
            data = transform(data)

            with torch.inference_mode():
                output = inference(data)

            # Memory cleanup
            del model
            del inference
            torch.cuda.empty_cache()
            gc.collect()

            return np.array(output["generated_image"][0])

    except Exception as e:
        print(f"Error in leffa_predict: {str(e)}")
        raise


@spaces.GPU()
def leffa_predict_vt(src_image_path, ref_image_path):
    """Run ``leffa_predict`` with the ``"virtual_tryon"`` control type."""
    try:
        return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
    except Exception as e:
        print(f"Error in leffa_predict_vt: {str(e)}")
        raise


@spaces.GPU()
def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85, randomize_seed=False):
    """Generate one fashion image; this is the handler wired to the UI button.

    Args:
        prompt: Description text; translated to English if it contains Korean.
        mode: ``"Generate Model"`` selects the model LoRA; anything else
            selects the clothes LoRA.
        cfg_scale: Guidance scale passed to the pipeline.
        steps: Number of inference steps.
        seed: Seed for the CUDA generator; ``None`` means pick one at random.
        width: Output width.
        height: Output height.
        lora_scale: LoRA strength passed via ``joint_attention_kwargs``.
        randomize_seed: When true, ignore *seed* and pick one at random.

    Returns:
        ``(image, seed)`` - the generated image and the seed actually used,
        so the UI can display a reproducible value.

    Raises:
        gr.Error: Wraps any failure so the UI can display it.
    """
    try:
        with torch_gc():
            # Korean prompt handling
            if contains_korean(prompt):
                translator = model_manager.get_translator()
                with torch.inference_mode():
                    translated = translator(prompt)[0]['translation_text']
                actual_prompt = translated
            else:
                actual_prompt = prompt

            # Resolve the seed once so the value returned is the value used.
            if randomize_seed or seed is None:
                seed = random.randint(0, MAX_SEED)
            seed = int(seed)  # manual_seed() needs an int; sliders may yield floats

            # Pipeline initialisation
            pipe = DiffusionPipeline.from_pretrained(
                BASE_MODEL,
                torch_dtype=torch.float16,
            )
            pipe = pipe.to("cuda")

            # LoRA selection
            if mode == "Generate Model":
                pipe.load_lora_weights(MODEL_LORA_REPO)
                trigger_word = "fashion photography, professional model"
            else:
                pipe.load_lora_weights(CLOTHES_LORA_REPO)
                trigger_word = "upper clothing, fashion item"

            # Image generation
            with torch.inference_mode():
                result = pipe(
                    prompt=f"{actual_prompt} {trigger_word}",
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    width=width,
                    height=height,
                    generator=torch.Generator("cuda").manual_seed(seed),
                    joint_attention_kwargs={"scale": lora_scale},
                ).images[0]

            # Memory cleanup
            del pipe

            return result, seed

    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}") from e


# Run initial setup (checkpoint download)
setup()


def create_interface():
    """Build and return the two-tab Gradio ``Blocks`` application.

    NOTE(review): the container nesting below was reconstructed from a
    whitespace-collapsed source - confirm the layout against the running app.
    """
    with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
        gr.Markdown("# 🎭 FitGen:Fashion Studio & Virtual Try-on")

        with gr.Tabs():
            # Fashion generation tab
            with gr.Tab("Fashion Generation"):
                with gr.Column():
                    mode = gr.Radio(
                        choices=["Generate Model", "Generate Clothes"],
                        label="Generation Mode",
                        value="Generate Model"
                    )

                    # Example prompts
                    example_model_prompts = [
                        "professional fashion model, full body shot, standing pose, natural lighting, studio background, high fashion, elegant pose",
                        "fashion model portrait, upper body, confident pose, fashion photography, neutral background, professional lighting",
                        "stylish fashion model, three-quarter view, editorial pose, high-end fashion magazine style, minimal background"
                    ]

                    example_clothes_prompts = [
                        "luxury designer sweater, cashmere material, cream color, cable knit pattern, high-end fashion, product photography",
                        "elegant business blazer, tailored fit, charcoal grey, premium wool fabric, professional wear",
                        "modern streetwear hoodie, oversized fit, minimalist design, premium cotton, urban style"
                    ]

                    prompt = gr.TextArea(
                        label="Fashion Description (한글 또는 영어)",
                        placeholder="패션 모델이나 의류를 설명하세요..."
                    )

                    # Examples section
                    gr.Examples(
                        examples=example_model_prompts + example_clothes_prompts,
                        inputs=prompt,
                        label="Example Prompts"
                    )

                with gr.Row():
                    with gr.Column():
                        result = gr.Image(label="Generated Result")
                        generate_button = gr.Button("Generate Fashion")

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Group():
                        with gr.Row():
                            with gr.Column():
                                cfg_scale = gr.Slider(
                                    label="CFG Scale",
                                    minimum=1,
                                    maximum=20,
                                    step=0.5,
                                    value=7.0
                                )
                                steps = gr.Slider(
                                    label="Steps",
                                    minimum=1,
                                    maximum=30,
                                    step=1,
                                    value=30
                                )
                                lora_scale = gr.Slider(
                                    label="LoRA Scale",
                                    minimum=0,
                                    maximum=1,
                                    step=0.01,
                                    value=0.85
                                )

                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=768,
                                step=64,
                                value=512
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=768,
                                step=64,
                                value=768
                            )

                        with gr.Row():
                            randomize_seed = gr.Checkbox(
                                True,
                                label="Randomize seed"
                            )
                            seed = gr.Slider(
                                label="Seed",
                                minimum=0,
                                maximum=2**32-1,
                                step=1,
                                value=42
                            )

            # Virtual try-on tab
            with gr.Tab("Virtual Try-on"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Person Image")
                        vt_src_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Person Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_src_image,
                            examples_per_page=5,
                            examples=["a1.webp", "a2.webp", "a3.webp", "a4.webp", "a5.webp"]
                        )

                    with gr.Column():
                        gr.Markdown("#### Garment Image")
                        vt_ref_image = gr.Image(
                            sources=["upload"],
                            type="filepath",
                            label="Garment Image",
                            width=512,
                            height=512,
                        )
                        gr.Examples(
                            inputs=vt_ref_image,
                            examples_per_page=5,
                            examples=["b1.webp", "b2.webp", "b3.webp", "b4.webp", "c1.png", "c2.png", "c3.png", "c4.png", "c5.png", "c6.png", "c7.png", "c8.png", "c9.png", "c10.png", "c11.png", "c12.png", "c13.png", "c14.png", "c15.png", "c16.png", "b5.webp"]
                        )

                    with gr.Column():
                        gr.Markdown("#### Generated Image")
                        vt_gen_image = gr.Image(
                            label="Generated Image",
                            width=512,
                            height=512,
                        )
                        vt_gen_button = gr.Button("Try-on")

                vt_gen_button.click(
                    fn=leffa_predict_vt,
                    inputs=[vt_src_image, vt_ref_image],
                    outputs=[vt_gen_image]
                )

        # `randomize_seed` is passed last to match generate_image's trailing
        # parameter, so the checkbox actually controls seeding.
        generate_button.click(
            fn=generate_image,
            inputs=[prompt, mode, cfg_scale, steps, seed, width, height, lora_scale, randomize_seed],
            outputs=[result, seed]
        ).success(
            fn=clear_memory,  # memory cleanup after a successful run; returns None
            inputs=None,
            outputs=None
        )

    return demo


if __name__ == "__main__":
    # setup_environment() already ran at import time above, so it is not
    # repeated here.
    demo = create_interface()
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )