Update app.py
Browse files
app.py
CHANGED
@@ -328,6 +328,56 @@ def leffa_predict_pt(src_image_path, ref_image_path):
|
|
328 |
print(f"Error in leffa_predict_pt: {str(e)}")
|
329 |
raise
|
330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
# 초기 설정 실행
|
332 |
setup()
|
333 |
def create_interface():
|
@@ -561,7 +611,7 @@ def create_interface():
|
|
561 |
if __name__ == "__main__":
|
562 |
setup_environment()
|
563 |
demo = create_interface()
|
564 |
-
demo.queue()
|
565 |
demo.launch(
|
566 |
server_name="0.0.0.0",
|
567 |
server_port=7860,
|
|
|
328 |
print(f"Error in leffa_predict_pt: {str(e)}")
|
329 |
raise
|
330 |
|
331 |
+
|
332 |
+
@spaces.GPU()
def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
    """Generate a fashion image (model or clothing) from a text prompt.

    Args:
        prompt: Text prompt; Korean prompts are machine-translated to English first.
        mode: "Generate Model" selects the model LoRA, anything else the clothes LoRA.
        cfg_scale: Classifier-free guidance scale passed to the pipeline.
        steps: Number of inference steps.
        seed: Optional RNG seed; when None a random 32-bit seed is drawn.
        width, height: Output image dimensions in pixels.
        lora_scale: LoRA attention scale forwarded via joint_attention_kwargs.

    Returns:
        (image, seed): the generated PIL image and the seed actually used
        (bug fix: previously returned the caller's `seed` argument, which was
        None when a random seed had been drawn, making runs irreproducible).

    Raises:
        gr.Error: wrapping any underlying failure for display in the UI.
    """
    try:
        with torch_gc():
            # Korean handling: translate to English so the diffusion model
            # receives a prompt in the language it was trained on.
            if contains_korean(prompt):
                translator = get_translator()
                with torch.inference_mode():
                    translated = translator(prompt)[0]['translation_text']
                actual_prompt = translated
            else:
                actual_prompt = prompt

            # Pipeline initialization (loaded per call; presumably kept out of
            # module scope to free GPU memory between requests — TODO confirm).
            pipe = DiffusionPipeline.from_pretrained(
                BASE_MODEL,
                torch_dtype=torch.float16,
            )
            pipe = pipe.to("cuda")

            # LoRA setup: pick adapter weights and a matching trigger phrase.
            if mode == "Generate Model":
                pipe.load_lora_weights(MODEL_LORA_REPO)
                trigger_word = "fashion photography, professional model"
            else:
                pipe.load_lora_weights(CLOTHES_LORA_REPO)
                trigger_word = "upper clothing, fashion item"

            # Resolve the seed BEFORE generation so the value we return is the
            # one actually fed to the generator (fixes returning None).
            if seed is None:
                seed = torch.randint(0, 2**32 - 1, (1,)).item()

            # Image generation
            with torch.inference_mode():
                result = pipe(
                    prompt=f"{actual_prompt} {trigger_word}",
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    width=width,
                    height=height,
                    generator=torch.Generator("cuda").manual_seed(seed),
                    joint_attention_kwargs={"scale": lora_scale},
                ).images[0]

            # Memory cleanup: drop the pipeline reference so torch_gc() can
            # reclaim GPU memory on context exit.
            del pipe
            return result, seed

    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}")
|
380 |
+
|
381 |
# 초기 설정 실행
|
382 |
setup()
|
383 |
def create_interface():
|
|
|
611 |
if __name__ == "__main__":
|
612 |
setup_environment()
|
613 |
demo = create_interface()
|
614 |
+
demo.queue()
|
615 |
demo.launch(
|
616 |
server_name="0.0.0.0",
|
617 |
server_port=7860,
|