"""
ChatBIA FastAPI Server
24/7 accounting AI server
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional, List, AsyncGenerator
import os
import json
import asyncio
from llama_cpp import Llama
app = FastAPI(
title="ChatBIA API",
    description="Accounting-specialist AI server",
version="1.0.0"
)
# CORS settings (accessible from the Android app and web clients)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
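# Note: browsers reject credentialed requests when Access-Control-Allow-Origin is the
# wildcard "*", so if cookie/credential support is ever required, replace
# allow_origins=["*"] with an explicit list of origins.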
# Model paths
MODEL_DIR = "models"
GENERAL_MODEL_PATH = os.path.join(MODEL_DIR, "Qwen2.5-3B-Instruct-Q4_K_M.gguf")
BSL_MODEL_PATH = os.path.join(MODEL_DIR, "ChatBIA-3B-v0.1-Q4_K_M.gguf")
# Global model variables
general_model = None
bsl_model = None
class ChatRequest(BaseModel):
message: str
mode: str = "bsl" # "general" or "bsl"
max_tokens: int = 1024
temperature: float = 0.7
class ChatResponse(BaseModel):
response: str
mode: str
tokens: int
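# Illustrative payload shapes for the models above (the values are examples only):
#   ChatRequest:  {"message": "...", "mode": "bsl", "max_tokens": 1024, "temperature": 0.7}
#   ChatResponse: {"response": "...", "mode": "bsl", "tokens": 42}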
@app.on_event("startup")
async def load_models():
"""์„œ๋ฒ„ ์‹œ์ž‘ ์‹œ ๋ชจ๋ธ ๋กœ๋“œ"""
global general_model, bsl_model
os.makedirs(MODEL_DIR, exist_ok=True)
    # Load the general-purpose model
    if os.path.exists(GENERAL_MODEL_PATH):
        print(f"🔄 Loading general-mode model: {GENERAL_MODEL_PATH}")
try:
general_model = Llama(
model_path=GENERAL_MODEL_PATH,
n_ctx=2048,
n_threads=4,
                n_gpu_layers=0,  # Oracle Cloud instance is CPU-only
verbose=False
)
print("โœ… ์ผ๋ฐ˜ ๋ชจ๋“œ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ")
except Exception as e:
print(f"โŒ ์ผ๋ฐ˜ ๋ชจ๋“œ ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}")
    # Load the BSL model
    if os.path.exists(BSL_MODEL_PATH):
        print(f"🔄 Loading BSL-mode model: {BSL_MODEL_PATH}")
try:
bsl_model = Llama(
model_path=BSL_MODEL_PATH,
n_ctx=2048,
n_threads=4,
n_gpu_layers=0,
verbose=False
)
print("โœ… BSL ๋ชจ๋“œ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ")
except Exception as e:
print(f"โŒ BSL ๋ชจ๋“œ ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}")
def build_prompt(message: str, mode: str) -> str:
"""ํ”„๋กฌํ”„ํŠธ ๋นŒ๋“œ"""
if mode == "bsl":
return f"""<|im_start|>system
You are a professional accounting AI assistant. Respond naturally in Korean.
Important: Only generate BSL DSL code when the user explicitly requests calculations (e.g., "계산해줘", "코드 작성해줘", "BSL로 작성해줘"). For general questions or greetings, respond conversationally without code.<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
else:
return f"""<|im_start|>system
You are a helpful AI assistant. Respond naturally in Korean.<|im_end|>
<|im_start|>user
{message}<|im_end|>
<|im_start|>assistant
"""
@app.get("/")
async def root():
"""ํ—ฌ์Šค ์ฒดํฌ"""
return {
"status": "online",
"service": "ChatBIA API",
"version": "1.0.0",
"models": {
"general": general_model is not None,
"bsl": bsl_model is not None
}
}
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
"""์ฑ„ํŒ… ์—”๋“œํฌ์ธํŠธ"""
# ๋ชจ๋ธ ์„ ํƒ
if request.mode == "general":
model = general_model
model_name = "General"
else:
model = bsl_model
model_name = "BSL"
    # Error out if the requested model is not loaded
if model is None:
raise HTTPException(
status_code=503,
detail=f"{model_name} ๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค."
)
try:
        # Build the prompt
        prompt = build_prompt(request.message, request.mode)
        # Run inference
response = model(
prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=0.9,
top_k=40,
repeat_penalty=1.1,
stop=["<|im_end|>", "###", "\n\n\n"]
)
        text = response["choices"][0]["text"].strip()
        # Use the token count reported by llama.cpp; a whitespace split only approximates it
        tokens = response.get("usage", {}).get("completion_tokens", 0)
return ChatResponse(
response=text,
mode=request.mode,
tokens=tokens
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"AI ๋ชจ๋ธ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜: {str(e)}"
)
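# Example client call for /chat (a sketch; assumes the third-party `httpx` package and
# the host/port configured in the __main__ block below):
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:8000/chat",
#       json={"message": "What is double-entry bookkeeping?", "mode": "general"},
#       timeout=120.0,
#   )
#   print(r.json()["response"])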
@app.get("/models")
async def get_models():
"""์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋ธ ๋ชฉ๋ก"""
return {
"general": {
"loaded": general_model is not None,
"path": GENERAL_MODEL_PATH if os.path.exists(GENERAL_MODEL_PATH) else None
},
"bsl": {
"loaded": bsl_model is not None,
"path": BSL_MODEL_PATH if os.path.exists(BSL_MODEL_PATH) else None
}
}
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
"""์ŠคํŠธ๋ฆฌ๋ฐ ์ฑ„ํŒ… ์—”๋“œํฌ์ธํŠธ (์•ˆ๋“œ๋กœ์ด๋“œ/ํƒ€์ž„์•„์›ƒ ๋ฐฉ์ง€)"""
# ๋ชจ๋ธ ์„ ํƒ
if request.mode == "general":
model = general_model
model_name = "General"
else:
model = bsl_model
model_name = "BSL"
    # Error out if the requested model is not loaded
if model is None:
raise HTTPException(
status_code=503,
detail=f"{model_name} ๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค."
)
async def generate_stream() -> AsyncGenerator[str, None]:
"""ํ† ํฐ ๋‹จ์œ„ ์ŠคํŠธ๋ฆฌ๋ฐ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ"""
import asyncio
try:
            # Build the prompt
            prompt = build_prompt(request.message, request.mode)
            # Streaming inference
stream = model(
prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=0.9,
top_k=40,
repeat_penalty=1.1,
stop=["<|im_end|>", "###", "\n\n\n"],
                stream=True  # enable streaming
)
token_count = 0
for chunk in stream:
if "choices" in chunk and len(chunk["choices"]) > 0:
delta = chunk["choices"][0].get("text", "")
if delta:
token_count += 1
                        # SSE format: "data: {json}\n\n"
data = {
"token": delta,
"done": False,
"token_count": token_count
}
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
                        # Yield to the event loop so each chunk is flushed immediately
                        await asyncio.sleep(0)
            # Completion signal
final_data = {
"token": "",
"done": True,
"token_count": token_count,
"mode": request.mode
}
yield f"data: {json.dumps(final_data, ensure_ascii=False)}\n\n"
except Exception as e:
error_data = {
"error": str(e),
"done": True
}
yield f"data: {json.dumps(error_data, ensure_ascii=False)}\n\n"
return StreamingResponse(
generate_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # Nginx ๋ฒ„ํผ๋ง ๋น„ํ™œ์„ฑํ™”
}
)
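# Example of consuming the SSE stream (a sketch; assumes the third-party `httpx` package):
#
#   import httpx, json
#   payload = {"message": "...", "mode": "bsl"}
#   with httpx.stream("POST", "http://localhost:8000/chat/stream",
#                     json=payload, timeout=None) as r:
#       for line in r.iter_lines():
#           if not line.startswith("data: "):
#               continue
#           event = json.loads(line[len("data: "):])
#           if event.get("token"):
#               print(event["token"], end="", flush=True)
#           if event.get("done"):
#               break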
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"main:app",
host="0.0.0.0",
port=8000,
        reload=False  # keep auto-reload disabled in production
)