#!/usr/bin/env python3
"""
Lily LLM API server
Serves the fine-tuned Mistral-7B model as a RESTful API
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
import logging
import time
import torch
from typing import Optional

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create the FastAPI app
app = FastAPI(
    title="Lily LLM API",
    description="Fine-tuned Mistral-7B model API for Hearth Chat",
    version="1.0.0"
)

# CORS setup
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # development only; restrict to specific domains in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models
class GenerateRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    do_sample: Optional[bool] = True


class GenerateResponse(BaseModel):
    generated_text: str
    processing_time: float
    model_name: str = "Lily LLM (Mistral-7B)"


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    model_name: str
# Global state
model = None
tokenizer = None
model_loaded = False


@app.on_event("startup")
async def startup_event():
    """Load the model when the server starts."""
    global model, tokenizer, model_loaded

    logger.info("Starting Lily LLM API server...")
    logger.info("API docs: http://localhost:8001/docs")
    logger.info("Health check: http://localhost:8001/health")

    try:
        # Load the model asynchronously so server startup stays responsive
        await load_model_async()
        model_loaded = True
        logger.info("Model loading complete!")
    except Exception as e:
        logger.error(f"Model loading failed: {e}")
        model_loaded = False
async def load_model_async():
    """Load the model asynchronously."""
    global model, tokenizer

    # Run the blocking model load in a separate thread
    import asyncio
    import concurrent.futures

    def load_model_sync():
        from transformers import AutoTokenizer, AutoModelForCausalLM

        logger.info("Loading model...")

        # Use the local model path
        local_model_path = "./lily_llm_core/models/polyglot-ko-1.3b"

        try:
            # Load the local model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(local_model_path, use_fast=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Load the model (on CPU)
            model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                torch_dtype=torch.float32,
                device_map="cpu",
                low_cpu_mem_usage=True
            )
            logger.info("polyglot-ko-1.3b model loaded successfully!")
            return model, tokenizer
        except Exception as e:
            logger.error(f"Local model load failed: {e}")
            logger.info("Loading a small test model instead...")
            # Fall back to DialoGPT-medium (a smaller model)
            test_model_name = "microsoft/DialoGPT-medium"
            tokenizer = AutoTokenizer.from_pretrained(test_model_name)
            model = AutoModelForCausalLM.from_pretrained(test_model_name)
            return model, tokenizer

    # Run the loader in a worker thread and store the results in the globals
    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor() as executor:
        model, tokenizer = await loop.run_in_executor(executor, load_model_sync)
@app.get("/")
async def root():
    """Root endpoint."""
    return {
        "message": "Lily LLM API server",
        "version": "1.0.0",
        "model": "Mistral-7B-Instruct-v0.2 (Fine-tuned)",
        "docs": "/docs"
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        model_loaded=model_loaded,
        model_name="Lily LLM (Mistral-7B)"
    )
@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Text generation endpoint."""
    global model, tokenizer

    if not model_loaded or model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    start_time = time.time()

    try:
        logger.info(f"Starting text generation: '{request.prompt}'")

        # Wrap the prompt in the question/answer format the polyglot model expects
        formatted_prompt = f"질문: {request.prompt}\n답변:"
        logger.info(f"Formatted prompt: '{formatted_prompt}'")

        # Tokenize the input; make sure a padding token is set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True)
        logger.info(f"Input token count: {inputs['input_ids'].shape[1]}")

        # Generate text with deliberately aggressive sampling settings
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=request.max_length,
                do_sample=True,
                temperature=0.9,         # higher temperature
                top_k=50,                # add top_k
                top_p=0.95,              # add top_p
                repetition_penalty=1.2,  # discourage repetition
                no_repeat_ngram_size=2,  # block repeated n-grams
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        logger.info(f"Generated token count: {outputs.shape[1]}")

        # Decode the result
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Full decoded text: '{generated_text}'")

        # Extract only the answer portion of the polyglot output
        if "답변:" in generated_text:
            response = generated_text.split("답변:")[-1].strip()
            logger.info(f"Extracted answer: '{response}'")
        else:
            # Fall back to stripping the prompt from the output
            if formatted_prompt in generated_text:
                response = generated_text.replace(formatted_prompt, "").strip()
            else:
                response = generated_text.strip()
            logger.info(f"After prompt removal: '{response}'")

        # Handle an empty response
        if not response.strip():
            logger.warning("Generated text is empty; using the default response")
            response = "안녕하세요! 무엇을 도와드릴까요?"

        processing_time = time.time() - start_time
        logger.info(f"Generation complete: {processing_time:.2f}s, text length: {len(response)}")

        return GenerateResponse(
            generated_text=response,
            processing_time=processing_time
        )

    except Exception as e:
        logger.error(f"Text generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")
@app.get("/models")
async def list_models():
    """List the available models."""
    return {
        "models": [
            {
                "id": "lily-llm",
                "name": "Lily LLM",
                "description": "Fine-tuned Mistral-7B model for Hearth Chat",
                "base_model": "mistralai/Mistral-7B-Instruct-v0.2",
                "fine_tuned": True
            }
        ]
    }
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8001,
        reload=False,
        log_level="info"
    )
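
# Example client call (a minimal sketch): the "/generate" route is the one
# registered above, the host/port come from the uvicorn settings, and the
# payload fields mirror GenerateRequest. The prompt value is illustrative.
#
#   curl -X POST http://localhost:8001/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "안녕하세요", "max_length": 100, "temperature": 0.7}'
#
# The response body follows GenerateResponse: generated_text, processing_time,
# and model_name.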