#!/usr/bin/env python3
"""
Lily LLM API server
Serves the fine-tuned Mistral-7B model as a RESTful API
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import logging
import time
import torch
from typing import Optional, List

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create the FastAPI app
app = FastAPI(
    title="Lily LLM API",
    description="Fine-tuned Mistral-7B model API for Hearth Chat",
    version="1.0.0"
)

# CORS setup
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Development only; restrict to specific domains in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models
class GenerateRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    do_sample: Optional[bool] = True

class GenerateResponse(BaseModel):
    generated_text: str
    processing_time: float
    model_name: str = "Lily LLM (Mistral-7B)"

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    model_name: str
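# For reference, a request/response pair matching the models above might look like
# this (illustrative values only, not taken from the original code):
#
#   POST /generate
#   {"prompt": "안녕하세요", "max_length": 100, "temperature": 0.7, "top_p": 0.9, "do_sample": true}
#
#   200 OK
#   {"generated_text": "...", "processing_time": 1.23, "model_name": "Lily LLM (Mistral-7B)"}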
# Global state
model = None
tokenizer = None
model_loaded = False

@app.on_event("startup")
async def startup_event():
    """Load the model when the server starts"""
    global model, tokenizer, model_loaded
    logger.info("🚀 Starting Lily LLM API server...")
    logger.info("📝 API docs: http://localhost:8001/docs")
    logger.info("🔍 Health check: http://localhost:8001/health")
    try:
        # Load the model asynchronously so server startup is not blocked
        await load_model_async()
        model_loaded = True
        logger.info("✅ Model loading complete!")
    except Exception as e:
        logger.error(f"❌ Model loading failed: {e}")
        model_loaded = False
async def load_model_async():
    """Load the model without blocking the event loop"""
    global model, tokenizer
    # The actual loading runs in a separate thread
    import asyncio
    import concurrent.futures

    def load_model_sync():
        from transformers import AutoTokenizer, AutoModelForCausalLM
        from peft import PeftModel
        import torch

        logger.info("Loading model...")
        # Use the local model path
        local_model_path = "./lily_llm_core/models/polyglot-ko-1.3b"
        try:
            # Load the local model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(local_model_path, use_fast=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            # Load the model (on CPU)
            model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                torch_dtype=torch.float32,
                device_map="cpu",
                low_cpu_mem_usage=True
            )
            logger.info("✅ polyglot-ko-1.3b model loaded successfully!")
            return model, tokenizer
        except Exception as e:
            logger.error(f"Failed to load the local model: {e}")
            logger.info("Loading a small test model instead...")
            # Fall back to DialoGPT-medium (a smaller model)
            test_model_name = "microsoft/DialoGPT-medium"
            tokenizer = AutoTokenizer.from_pretrained(test_model_name)
            model = AutoModelForCausalLM.from_pretrained(test_model_name)
            return model, tokenizer

    # Run the loading function in a worker thread
    loop = asyncio.get_event_loop()
    with concurrent.futures.ThreadPoolExecutor() as executor:
        model, tokenizer = await loop.run_in_executor(executor, load_model_sync)
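# A minimal alternative sketch: on Python 3.9+ the executor boilerplate above could be
# replaced with asyncio.to_thread, which likewise runs the blocking loader off the
# event loop. The target Python version is an assumption, not part of the original code:
#
#     model, tokenizer = await asyncio.to_thread(load_model_sync)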
@app.get("/", response_model=dict)
async def root():
"""๋ฃจํŠธ ์—”๋“œํฌ์ธํŠธ"""
return {
"message": "Lily LLM API ์„œ๋ฒ„",
"version": "1.0.0",
"model": "Mistral-7B-Instruct-v0.2 (Fine-tuned)",
"docs": "/docs"
}
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""ํ—ฌ์Šค ์ฒดํฌ ์—”๋“œํฌ์ธํŠธ"""
return HealthResponse(
status="healthy",
model_loaded=model_loaded,
model_name="Lily LLM (Mistral-7B)"
)
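# Example /health response once the model has loaded (illustrative):
#   {"status": "healthy", "model_loaded": true, "model_name": "Lily LLM (Mistral-7B)"}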
@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Text generation endpoint"""
    global model, tokenizer
    if not model_loaded or model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    start_time = time.time()
    try:
        logger.info(f"Starting text generation: '{request.prompt}'")
        # Wrap the prompt in the question/answer format the polyglot model expects
        # ("질문" = question, "답변" = answer)
        formatted_prompt = f"질문: {request.prompt}\n답변:"
        logger.info(f"Formatted prompt: '{formatted_prompt}'")
        # Tokenize the input; make sure a padding token is set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True)
        logger.info(f"Input token count: {inputs['input_ids'].shape[1]}")
        # Generate text (sampling settings are fixed here rather than taken from the request)
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=request.max_length,
                do_sample=True,
                temperature=0.9,           # higher temperature
                top_k=50,                  # top-k sampling
                top_p=0.95,                # nucleus sampling
                repetition_penalty=1.2,    # discourage repetition
                no_repeat_ngram_size=2,    # block repeated n-grams
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        logger.info(f"Generated token count: {outputs.shape[1]}")
        # Decode the result
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Full decoded text: '{generated_text}'")
        # Keep only the answer part of the polyglot output
        if "답변:" in generated_text:
            response = generated_text.split("답변:")[-1].strip()
            logger.info(f"Extracted answer: '{response}'")
        else:
            # Fall back to stripping the prompt from the output
            if formatted_prompt in generated_text:
                response = generated_text.replace(formatted_prompt, "").strip()
            else:
                response = generated_text.strip()
            logger.info(f"After prompt removal: '{response}'")
        # Handle an empty response
        if not response.strip():
            logger.warning("Generated text is empty, falling back to the default response")
            response = "안녕하세요! 무엇을 도와드릴까요?"  # "Hello! How can I help you?"
        processing_time = time.time() - start_time
        logger.info(f"Generation complete: {processing_time:.2f}s, text length: {len(response)}")
        return GenerateResponse(
            generated_text=response,
            processing_time=processing_time
        )
    except Exception as e:
        logger.error(f"Text generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")
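# A minimal client sketch for the /generate endpoint (illustrative only; assumes the
# `requests` package is installed and the server is reachable on port 8001 as
# configured below):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8001/generate",
#         json={"prompt": "안녕하세요", "max_length": 100},
#     )
#     print(resp.json()["generated_text"])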
@app.get("/models")
async def list_models():
"""์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋ธ ๋ชฉ๋ก"""
return {
"models": [
{
"id": "lily-llm",
"name": "Lily LLM",
"description": "Hearth Chat์šฉ ํŒŒ์ธํŠœ๋‹๋œ Mistral-7B ๋ชจ๋ธ",
"base_model": "mistralai/Mistral-7B-Instruct-v0.2",
"fine_tuned": True
}
]
}
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8001,
        reload=False,
        log_level="info"
    )