net-chatbot / app.py
shtyyg's picture
Update app.py
b849f0c verified
import json
import time
import uuid
import os
import asyncio
from typing import List, Optional, AsyncGenerator
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import httpx
# ASGI application exposing the OpenAI-compatible endpoints defined below.
app = FastAPI(title="OpenAI-Compatible GLM API")

# CORS policy: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination — confirm this is intended.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Base URL of the upstream Gradio API; the streaming helper builds its
# /run/predict, /queue/join and /queue/data URLs from it.
GRADIO_API_BASE = "https://shtyyg-zai-org-glm-5-1.hf.space/gradio_api"
class ChatMessage(BaseModel):
    """One message of the conversation sent by the client."""

    # Author of the message; the completion endpoint searches for "user".
    role: str
    # Text of the message.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``."""

    # Model name; it is echoed back in the response and defaults to "glm-5-1".
    model: str = "glm-5-1"
    # Conversation messages; the endpoint uses the most recent "user" one.
    messages: List[ChatMessage]
    # When true, the reply is delivered as a server-sent-event stream.
    stream: Optional[bool] = False
def extract_gradio_text(data_payload: dict) -> Optional[str]:
    """Pull the assistant text out of one decoded Gradio 5 SSE payload.

    Two payload shapes are understood (taken from ``output.data[1]``):

    * incremental update (``process_generating``)::

          output.data = [[], [["append", [1, "content", 0, "text"], "你好"]]]

      the appended value is returned as a string;
    * full chatbot state (``process_completed``)::

          output.data = [null, [{role: "user", ...},
                                {role: "assistant",
                                 content: [{type: "text", text: "你好!..."}]}]]

      the text of the last usable assistant message is returned.

    Args:
        data_payload: The JSON object carried by an SSE ``data:`` line.

    Returns:
        The extracted text, or ``None`` when the payload is a control/error
        message or carries no recognisable text.
    """
    # Control and error messages never carry chatbot text.
    textless_messages = (
        "send_hash", "send_data", "process_starts",
        "estimation", "queue_full", "close_stream",
        "unexpected_error", "error",
    )
    if data_payload.get("msg", "") in textless_messages:
        return None

    output = data_payload.get("output")
    if not (output and isinstance(output, dict)):
        return None

    slots = output.get("data")
    if not (slots and isinstance(slots, list)) or len(slots) < 2:
        return None

    # The second output slot holds the chatbot component.
    chatbot = slots[1]
    if not isinstance(chatbot, list):
        return None

    # Shape 1: an "append" diff operation -> return the appended value.
    if chatbot:
        head = chatbot[0]
        if isinstance(head, list) and len(head) >= 3 and head[0] == "append":
            return str(head[2])

    # Shape 2: full message list -> walk backwards to the latest assistant
    # message whose content is a string or a list of parts.
    for entry in chatbot[::-1]:
        if not isinstance(entry, dict) or entry.get("role") != "assistant":
            continue
        body = entry.get("content")
        if isinstance(body, str):
            return body
        if isinstance(body, list):
            pieces = [
                part.get("text", "")
                for part in body
                if isinstance(part, dict) and part.get("type") == "text"
            ]
            return "".join(pieces) if pieces else None
    return None
async def wake_up_space() -> bool:
    """Probe the Space's root URL and report whether it answered HTTP 200.

    A single GET (redirects followed, 120 s timeout) is sent to the Space
    URL; per the log message this is meant to check/wake the Space.

    Returns:
        ``True`` when the response status is 200; ``False`` for any other
        status or when the request raises.
    """
    space_url = "https://shtyyg-zai-org-glm-5-1.hf.space"
    print(f"[INFO] >>> 检查/唤醒 Space: {space_url}")
    async with httpx.AsyncClient(follow_redirects=True) as client:
        try:
            response = await client.get(space_url, timeout=120)
            is_up = response.status_code == 200
            if not is_up:
                print(f"[WARN] <<< Space 响应: {response.status_code}")
                return False
            print("[INFO] <<< Space 可达")
            return True
        except Exception as exc:
            # Best effort: any failure is only logged and reported as False.
            print(f"[WARN] <<< Space 异常: {exc}")
            return False
def _is_append_delta(payload: dict) -> bool:
    """Tell whether *payload* carries an ``append`` diff in the chatbot slot.

    Uses the same shape test as the text extractor's incremental case:
    ``output.data[1][0]`` is a list of at least three items whose first
    item is ``"append"``. Text extracted from such a payload is a fragment
    to add to the reply, not a snapshot of the whole reply.
    """
    output = payload.get("output")
    if not isinstance(output, dict):
        return False
    data = output.get("data")
    if not isinstance(data, list) or len(data) < 2:
        return False
    ops = data[1]
    if not isinstance(ops, list) or not ops:
        return False
    head = ops[0]
    return isinstance(head, list) and len(head) >= 3 and head[0] == "append"


async def stream_gradio_response(prompt: str) -> AsyncGenerator[str, None]:
    """Send *prompt* to the Gradio app and yield the reply as text deltas.

    The three calls are made strictly in the order that was verified by
    testing, all on one HTTP client and one ``session_hash``:

    1. ``POST /run/predict`` (fn_index 2)  -> submit the user input
    2. ``POST /queue/join`` (fn_index 4)   -> enqueue the generation
    3. ``GET /queue/data?session_hash=``   -> read the SSE result stream

    The upstream sends two kinds of text: "append" fragments and full
    snapshots of the assistant message. The generator keeps the text it
    has emitted so far, so a fragment is forwarded as-is while a snapshot
    only contributes the part that has not been emitted yet. This prevents
    the final snapshot from repeating the whole reply.

    Args:
        prompt: The user text to submit.

    Yields:
        Newly generated pieces of the assistant reply, in order.

    Raises:
        HTTPException: Status 502 when a step returns a non-200 status or
            the stream reports an error.
    """
    session_hash = uuid.uuid4().hex
    trigger_id = 9
    run_url = f"{GRADIO_API_BASE}/run/predict"
    join_url = f"{GRADIO_API_BASE}/queue/join"
    stream_url = f"{GRADIO_API_BASE}/queue/data?session_hash={session_hash}"
    # Everything yielded so far; snapshots are diffed against this.
    emitted = ""
    print(f"\n[DEBUG] >>> session_hash: {session_hash}")

    # One client for all three steps, executed in order 1 -> 2 -> 3.
    async with httpx.AsyncClient() as client:
        # -- Step 1: POST /run/predict (fn=2) --
        r1_payload = {
            "data": [prompt],
            "fn_index": 2,
            "trigger_id": trigger_id,
            "session_hash": session_hash,
        }
        print("[DEBUG] >>> Step1 POST /run/predict fn=2")
        r1 = await client.post(run_url, json=r1_payload, timeout=60)
        print(
            f"[DEBUG] <<< Step1: {r1.status_code}, "
            f"内容: {r1.text[:200]}"
        )
        if r1.status_code != 200:
            raise HTTPException(
                status_code=502,
                detail=f"Step1 失败: HTTP {r1.status_code}",
            )

        # -- Step 2: POST /queue/join (fn=4) --
        r2_payload = {
            "data": [None, None],
            "event_data": None,
            "fn_index": 4,
            "trigger_id": trigger_id,
            "session_hash": session_hash,
        }
        print("[DEBUG] >>> Step2 POST /queue/join fn=4")
        r2 = await client.post(join_url, json=r2_payload, timeout=60)
        print(
            f"[DEBUG] <<< Step2: {r2.status_code}, "
            f"内容: {r2.text[:200]}"
        )
        if r2.status_code != 200:
            raise HTTPException(
                status_code=502,
                detail=f"Step2 失败: HTTP {r2.status_code}",
            )

        # -- Step 3: GET /queue/data (SSE) --
        print("[DEBUG] >>> Step3 GET /queue/data SSE")
        async with client.stream(
            "GET",
            stream_url,
            timeout=httpx.Timeout(300.0, read=300.0),
        ) as resp:
            print(f"[DEBUG] <<< SSE 连接状态: {resp.status_code}")
            if resp.status_code != 200:
                body = await resp.aread()
                # errors="replace": a non-UTF-8 error body must not mask
                # the 502 with a UnicodeDecodeError.
                snippet = body.decode(errors="replace")[:200]
                raise HTTPException(
                    status_code=502,
                    detail=(
                        f"SSE 连接失败: HTTP {resp.status_code} "
                        f"- {snippet}"
                    ),
                )

            # -- Read results from the SSE stream --
            current_event = None
            async for line in resp.aiter_lines():
                print(f"[DEBUG] <<< 原始行: {repr(line)[:300]}")

                # A blank line terminates the current SSE event.
                if not line.strip():
                    current_event = None
                    continue

                if line.startswith("event:"):
                    current_event = line.split(":", 1)[1].strip()
                    if current_event == "error":
                        print("[DEBUG] ⚠️ error 事件")
                    continue

                if not line.startswith("data:"):
                    # Other SSE fields are ignored.
                    continue

                data_str = line.split(":", 1)[1].strip()
                if current_event == "heartbeat":
                    continue
                if current_event == "error":
                    print(f"[DEBUG] ⚠️ SSE error: {data_str}")
                    raise HTTPException(
                        status_code=502,
                        detail=f"Gradio SSE error: {data_str}",
                    )
                if data_str == "[DONE]":
                    print("[DEBUG] <<< [DONE]")
                    return

                try:
                    payload = json.loads(data_str)
                except json.JSONDecodeError:
                    print(
                        f"[DEBUG] ⚠️ JSON 解析失败: "
                        f"{data_str[:100]}"
                    )
                    continue
                if not isinstance(payload, dict):
                    # Valid JSON that is not an object cannot be a message.
                    continue

                msg = payload.get("msg", "")

                # Error reported inside the JSON message.
                if msg in ("unexpected_error", "error"):
                    error_msg = payload.get("message", data_str)
                    print(f"[DEBUG] ⚠️ Gradio 错误: {error_msg}")
                    raise HTTPException(
                        status_code=502,
                        detail=f"Gradio error: {error_msg}",
                    )

                # Stream closed by the server.
                if msg == "close_stream":
                    print("[DEBUG] <<< close_stream")
                    return

                # A completed event explicitly flagged as failed carries no
                # reply; surface it instead of ending with an empty answer.
                # Only an explicit ``success: false`` triggers this.
                if msg == "process_completed" and payload.get("success") is False:
                    failed_output = payload.get("output")
                    error_msg = None
                    if isinstance(failed_output, dict):
                        error_msg = failed_output.get("error")
                    error_msg = error_msg or data_str[:200]
                    print(f"[DEBUG] ⚠️ Gradio 错误: {error_msg}")
                    raise HTTPException(
                        status_code=502,
                        detail=f"Gradio error: {error_msg}",
                    )

                text = extract_gradio_text(payload)
                if text is None:
                    if msg == "process_completed":
                        print(
                            "[DEBUG] <<< process_completed"
                            "(无新增量文本,结束)"
                        )
                        return
                    continue

                if _is_append_delta(payload):
                    # Fragment: forward it and extend the running total.
                    emitted += text
                    fresh = text
                elif text.startswith(emitted):
                    # Snapshot that extends (or equals) what was already
                    # sent: forward only the unseen tail.
                    fresh = text[len(emitted):]
                    emitted = text
                else:
                    # Snapshot that does not continue the emitted text:
                    # fall back to sending it whole and restart tracking.
                    fresh = text
                    emitted = text

                if fresh:
                    print(
                        f"[DEBUG] <<< 增量文本: "
                        f"{repr(fresh)[:80]}"
                    )
                    yield fresh
def _sse_frame(payload: dict) -> str:
    """Serialize *payload* as one server-sent-event ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Answer an OpenAI-style chat completion request via the Gradio app.

    Only the most recent ``user`` message is forwarded upstream. With
    ``stream`` set, the reply is sent as ``chat.completion.chunk`` SSE
    frames: the first delta carries the assistant role, later deltas carry
    content only, and a terminal chunk with ``finish_reason: "stop"``
    precedes ``[DONE]``. Otherwise one ``chat.completion`` JSON object is
    returned.

    Args:
        request: The parsed request body.

    Returns:
        A ``StreamingResponse`` (stream mode) or a ``JSONResponse``.

    Raises:
        HTTPException: 400 when the request contains no non-empty user
            message; in non-stream mode, errors raised by the upstream
            helper propagate unchanged.
    """
    # Newest non-positional lookup: scan from the end for a "user" message.
    last_user_msg = next(
        (m.content for m in reversed(request.messages) if m.role == "user"),
        "",
    )
    if not last_user_msg:
        raise HTTPException(
            status_code=400, detail="No user message found"
        )

    # Best-effort probe; the result is intentionally not enforced.
    await wake_up_space()
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"

    if request.stream:
        async def generate_sse():
            # One timestamp for every chunk of this completion.
            created = int(time.time())

            def make_chunk(delta, finish_reason=None):
                return _sse_frame({
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": request.model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": delta,
                            "finish_reason": finish_reason,
                        }
                    ],
                })

            role_sent = False
            try:
                async for text_chunk in stream_gradio_response(
                    last_user_msg
                ):
                    if role_sent:
                        delta = {"content": text_chunk}
                    else:
                        # The role is announced once, on the first delta.
                        delta = {
                            "role": "assistant",
                            "content": text_chunk,
                        }
                        role_sent = True
                    yield make_chunk(delta)
                # Terminal chunk so clients see why the stream ended.
                yield make_chunk({}, "stop")
                yield "data: [DONE]\n\n"
            except HTTPException as e:
                yield _sse_frame({
                    "error": {
                        "message": e.detail,
                        "type": "server_error",
                    }
                })
                yield "data: [DONE]\n\n"
            except Exception as e:
                # Headers are already sent, so the failure has to be
                # reported in-band instead of as an HTTP status.
                yield _sse_frame({
                    "error": {
                        "message": str(e),
                        "type": "server_error",
                    }
                })
                yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_sse(), media_type="text/event-stream"
        )

    # Non-stream mode: collect every delta, then answer once.
    parts = [
        text_chunk
        async for text_chunk in stream_gradio_response(last_user_msg)
    ]
    full_text = "".join(parts)
    print(
        f"[DEBUG] <<< 最终非流式输出: "
        f"{repr(full_text[:200])}"
    )
    response_data = {
        "id": completion_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": full_text,
                },
                "finish_reason": "stop",
            }
        ],
        # Token counts are not available from upstream; zeros are reported.
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }
    return JSONResponse(content=response_data)
@app.get("/v1/models")
async def list_models():
    """Return the model listing, which contains the single ``glm-5-1`` entry."""
    glm_entry = {
        "id": "glm-5-1",
        "object": "model",
        "owned_by": "zai",
    }
    return {"object": "list", "data": [glm_entry]}
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    import uvicorn

    # Port comes from the PORT environment variable, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)