Jacky2305 committed
Commit bb288e7 · 1 Parent(s): 44159b9

Support streaming responses (stream=True)

Files changed (1)
  1. main.py +24 -7
main.py CHANGED
@@ -1,9 +1,10 @@
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import JSONResponse
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, Field
 from typing import List, Optional
 import os
 import warnings
+import json
 
 # Suppress Pydantic deprecation warnings (optional, keeps logs clean)
 warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic")
@@ -34,22 +35,38 @@ class ChatRequest(BaseModel):
     model: str = Field(..., description="Model identifier (ignored, single model)")
     messages: List[Message] = Field(..., description="List of messages")
     max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
-    stream: Optional[bool] = Field(False, description="Stream response (not supported)")
+    stream: Optional[bool] = Field(False, description="Stream response (SSE)")
 
 @app.post("/v1/chat/completions")
 async def chat_completion(req: ChatRequest):
     """
     OpenAI-compatible Chat Completions endpoint.
-    Note: even with the context set to 32K, this 3B model's generation quality may degrade on long contexts
+    Supports stream=True (SSE) and stream=False (full JSON)
     """
     try:
         # Use model_dump() instead of the deprecated dict() to silence Pydantic warnings
         messages_list = [m.model_dump() for m in req.messages]
-
+
+        # Streaming response
+        if req.stream:
+            # llama.cpp generator (synchronous)
+            result_stream = llm.create_chat_completion(
+                messages=messages_list,
+                max_tokens=req.max_tokens,
+                stream=True,
+            )
+            async def sse_generator():
+                for chunk in result_stream:
+                    # Each chunk is already an OpenAI-format dict
+                    yield f"data: {json.dumps(chunk)}\n\n"
+                yield "data: [DONE]\n\n"
+            return StreamingResponse(sse_generator(), media_type="text/event-stream")
+
+        # Non-streaming response
         result = llm.create_chat_completion(
            messages=messages_list,
            max_tokens=req.max_tokens,
-           stream=req.stream,
+           stream=False,
         )
         return JSONResponse(content=result)
     except Exception as e:
@@ -63,4 +80,4 @@ async def healthz():
 if __name__ == "__main__":
     import uvicorn
     port = int(os.getenv("PORT", 7860))
-    uvicorn.run(app, host="0.0.0.0", port=port)
+    uvicorn.run(app, host="0.0.0.0", port=port)
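
A minimal client sketch for the new streaming path, not part of the commit: it assumes the server above is running locally on its default port 7860, that Message carries the usual role/content fields, and that each SSE chunk follows the OpenAI chat.completion.chunk shape that llama-cpp-python emits; the "model" value is an arbitrary placeholder, since the endpoint ignores it.

# Hypothetical client for the streaming endpoint added in this commit
import json
import requests

payload = {
    "model": "local-model",  # ignored by the server, required by the request schema
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 128,
    "stream": True,          # request SSE chunks instead of a single JSON body
}

with requests.post(
    "http://localhost:7860/v1/chat/completions",
    json=payload,
    stream=True,             # keep the HTTP connection open while chunks arrive
) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue         # skip blank separator lines between SSE events
        line = raw.decode("utf-8")
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":  # sentinel emitted by sse_generator()
            break
        chunk = json.loads(data)
        # Assumed OpenAI-style chunk: incremental text lives in choices[0].delta.content
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)

The same stream can also be inspected with curl -N by POSTing an identical JSON body, which prints the raw "data: ..." lines as they arrive.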