| """ | |
| ChatBIA FastAPI Server | |
| 24/7 ํ๊ณ AI ์๋ฒ | |
| """ | |
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional, List, AsyncGenerator
import os
import json
from llama_cpp import Llama
app = FastAPI(
    title="ChatBIA API",
    description="Accounting-specialized AI server",
    version="1.0.0"
)
# CORS settings (allow access from the Android app / web clients)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model paths
MODEL_DIR = "models"
GENERAL_MODEL_PATH = os.path.join(MODEL_DIR, "Qwen2.5-3B-Instruct-Q4_K_M.gguf")
BSL_MODEL_PATH = os.path.join(MODEL_DIR, "ChatBIA-3B-v0.1-Q4_K_M.gguf")
# Global model handles, populated at startup
general_model = None
bsl_model = None
class ChatRequest(BaseModel):
    message: str
    mode: str = "bsl"  # "general" or "bsl"
    max_tokens: int = 1024
    temperature: float = 0.7
class ChatResponse(BaseModel):
    response: str
    mode: str
    tokens: int
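# Example payloads for these models (an illustrative sketch; the "/chat" route
# is the one assumed for the chat handler below, and the values are made up):
#
#   POST /chat
#   {"message": "Explain straight-line depreciation", "mode": "bsl",
#    "max_tokens": 512, "temperature": 0.7}
#
#   -> {"response": "<model output>", "mode": "bsl", "tokens": 42}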
@app.on_event("startup")  # startup hook (assumed) so the models are loaded when the server boots
async def load_models():
    """Load the models at server startup."""
    global general_model, bsl_model
    os.makedirs(MODEL_DIR, exist_ok=True)
    # Load the general-purpose model
    if os.path.exists(GENERAL_MODEL_PATH):
        print(f"Loading general-mode model: {GENERAL_MODEL_PATH}")
        try:
            general_model = Llama(
                model_path=GENERAL_MODEL_PATH,
                n_ctx=2048,
                n_threads=4,
                n_gpu_layers=0,  # the Oracle Cloud instance is CPU-only
                verbose=False
            )
            print("General-mode model loaded")
        except Exception as e:
            print(f"Failed to load general-mode model: {e}")
    # Load the BSL model
    if os.path.exists(BSL_MODEL_PATH):
        print(f"Loading BSL-mode model: {BSL_MODEL_PATH}")
        try:
            bsl_model = Llama(
                model_path=BSL_MODEL_PATH,
                n_ctx=2048,
                n_threads=4,
                n_gpu_layers=0,
                verbose=False
            )
            print("BSL-mode model loaded")
        except Exception as e:
            print(f"Failed to load BSL-mode model: {e}")
def build_prompt(message: str, mode: str) -> str:
    """Build the ChatML prompt for the selected mode."""
    if mode == "bsl":
        return f"""<|im_start|>system
You are a professional accounting AI assistant. Respond naturally in Korean.
Important: Only generate BSL DSL code when the user explicitly requests calculations (e.g., "계산해줘", "코드 작성해줘", "BSL로 작성해줘"). For general questions or greetings, respond conversationally without code.<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
    else:
        return f"""<|im_start|>system
You are a helpful AI assistant. Respond naturally in Korean.<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
@app.get("/")  # route path assumed; serves as the health check
async def root():
    """Health check."""
    return {
        "status": "online",
        "service": "ChatBIA API",
        "version": "1.0.0",
        "models": {
            "general": general_model is not None,
            "bsl": bsl_model is not None
        }
    }
@app.post("/chat")  # route path assumed
async def chat(request: ChatRequest):
    """Chat endpoint (non-streaming)."""
    # Select the model for the requested mode
    if request.mode == "general":
        model = general_model
        model_name = "General"
    else:
        model = bsl_model
        model_name = "BSL"
    # Fail fast if the model is not loaded
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"The {model_name} model is not loaded."
        )
    try:
        # Build the prompt
        prompt = build_prompt(request.message, request.mode)
        # Run inference
        response = model(
            prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=0.9,
            top_k=40,
            repeat_penalty=1.1,
            stop=["<|im_end|>", "###", "\n\n\n"]
        )
        text = response["choices"][0]["text"].strip()
        # Rough token count (whitespace split), not the model's exact token usage
        tokens = len(response["choices"][0]["text"].split())
        return ChatResponse(
            response=text,
            mode=request.mode,
            tokens=tokens
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error while running the AI model: {str(e)}"
        )
@app.get("/models")  # route path assumed
async def get_models():
    """List the available models."""
    return {
        "general": {
            "loaded": general_model is not None,
            "path": GENERAL_MODEL_PATH if os.path.exists(GENERAL_MODEL_PATH) else None
        },
        "bsl": {
            "loaded": bsl_model is not None,
            "path": BSL_MODEL_PATH if os.path.exists(BSL_MODEL_PATH) else None
        }
    }
@app.post("/chat/stream")  # route path assumed
async def chat_stream(request: ChatRequest):
    """Streaming chat endpoint (avoids client timeouts on Android/web)."""
    # Select the model for the requested mode
    if request.mode == "general":
        model = general_model
        model_name = "General"
    else:
        model = bsl_model
        model_name = "BSL"
    # Fail fast if the model is not loaded
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=f"The {model_name} model is not loaded."
        )
    async def generate_stream() -> AsyncGenerator[str, None]:
        """Token-by-token streaming generator (SSE)."""
        import asyncio
        try:
            # Build the prompt
            prompt = build_prompt(request.message, request.mode)
            # Streaming inference (note: llama_cpp generates synchronously,
            # so chunks are produced on the event loop thread)
            stream = model(
                prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=0.9,
                top_k=40,
                repeat_penalty=1.1,
                stop=["<|im_end|>", "###", "\n\n\n"],
                stream=True  # enable streaming
            )
            token_count = 0
            for chunk in stream:
                if "choices" in chunk and len(chunk["choices"]) > 0:
                    delta = chunk["choices"][0].get("text", "")
                    if delta:
                        token_count += 1
                        # SSE framing: data: {json}\n\n
                        data = {
                            "token": delta,
                            "done": False,
                            "token_count": token_count
                        }
                        yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
                        # Yield control briefly so each chunk is flushed immediately
                        await asyncio.sleep(0)
            # Completion signal
            final_data = {
                "token": "",
                "done": True,
                "token_count": token_count,
                "mode": request.mode
            }
            yield f"data: {json.dumps(final_data, ensure_ascii=False)}\n\n"
        except Exception as e:
            error_data = {
                "error": str(e),
                "done": True
            }
            yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # disable Nginx response buffering
        }
    )
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=False  # keep reload=False in production
    )
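# ----------------------------------------------------------------------------
# Example client usage (a minimal sketch; the "/chat" and "/chat/stream" routes
# are the ones assumed above, the host/port match the uvicorn call, and the
# `requests` package is a client-side dependency, not required by this server):
#
#   import json
#   import requests
#
#   r = requests.post(
#       "http://localhost:8000/chat",
#       json={"message": "Explain double-entry bookkeeping", "mode": "general"},
#   )
#   print(r.json()["response"])
#
#   with requests.post(
#       "http://localhost:8000/chat/stream",
#       json={"message": "Explain double-entry bookkeeping", "mode": "bsl"},
#       stream=True,
#   ) as r:
#       for line in r.iter_lines():
#           if line.startswith(b"data: "):
#               event = json.loads(line[len(b"data: "):])
#               if event.get("done"):
#                   break
#               print(event["token"], end="", flush=True)
# ----------------------------------------------------------------------------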