from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional
import os
import warnings
import json
# Suppress Pydantic deprecation warnings (optional; keeps the logs clean)
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic")
from llama_cpp import Llama
# Load model
MODEL_PATH = "/app/models/Qwen2.5-3B-Instruct-Q4_K_M.gguf"
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found at {MODEL_PATH}")
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=32768,  # 32K context window
    n_threads=4,
    chat_format="chatml",
    hf_pretrained_model_name_or_path="Qwen/Qwen2.5-3B-Instruct",
    verbose=False,
)
app = FastAPI(title="Qwen2.5-3B API (32K)", version="0.2.0")
class Message(BaseModel):
    role: str = Field(..., description="Role: 'system', 'user', or 'assistant'")
    content: str = Field(..., description="Message content")
class ChatRequest(BaseModel):
    model: str = Field(..., description="Model identifier (ignored, single model)")
    messages: List[Message] = Field(..., description="List of messages")
    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
    stream: Optional[bool] = Field(False, description="Stream response (SSE)")
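# Illustrative request body matching ChatRequest (an example sketch, not executed;
# the field values below are hypothetical):
# {
#   "model": "qwen2.5-3b-instruct",
#   "messages": [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "Summarize llama.cpp in one sentence."}
#   ],
#   "max_tokens": 256,
#   "stream": false
# }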
@app.post("/v1/chat/completions")
async def chat_completion(req: ChatRequest):
"""
兼容 OpenAI 格式的 Chat Completions 端点。
支持 stream=True (SSE) 和 stream=False (完整 JSON)。
"""
    try:
        # Use model_dump() instead of the deprecated dict() to avoid Pydantic warnings
        messages_list = [m.model_dump() for m in req.messages]
        # Streaming response
        if req.stream:
            # llama.cpp returns a synchronous generator of OpenAI-style chunks
            result_stream = llm.create_chat_completion(
                messages=messages_list,
                max_tokens=req.max_tokens,
                stream=True,
            )
            async def sse_generator():
                for chunk in result_stream:
                    # Each chunk is already an OpenAI-format dict
                    yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"
            return StreamingResponse(sse_generator(), media_type="text/event-stream")
        # Non-streaming response
        result = llm.create_chat_completion(
            messages=messages_list,
            max_tokens=req.max_tokens,
            stream=False,
        )
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
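# Example calls (an illustrative sketch, assuming the server is reachable at
# http://localhost:7860; the "model" field is accepted but ignored):
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "qwen", "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 64}'
#
#   # Streaming (SSE): add "stream": true and read "data: ..." lines until [DONE]
#   curl -N http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "qwen", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'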
@app.get("/healthz")
async def healthz():
    return {"status": "ok", "model": "Qwen2.5-3B-Instruct", "n_ctx": 32768}
# For HF Spaces compatibility
if __name__ == "__main__":
    import uvicorn
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
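# Minimal client sketch (an assumption, not part of this app): consuming the SSE
# stream with the `requests` library, assuming the server runs on localhost:7860.
#
#   import json, requests
#
#   payload = {
#       "model": "qwen",
#       "messages": [{"role": "user", "content": "Tell me a joke."}],
#       "stream": True,
#   }
#   with requests.post("http://localhost:7860/v1/chat/completions",
#                      json=payload, stream=True) as resp:
#       for line in resp.iter_lines(decode_unicode=True):
#           if not line or not line.startswith("data: "):
#               continue
#           data = line[len("data: "):]
#           if data == "[DONE]":
#               break
#           # Each streamed chunk carries an incremental "delta" for choice 0
#           delta = json.loads(data)["choices"][0].get("delta", {})
#           print(delta.get("content", ""), end="", flush=True)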