# Hugging Face Space status at capture time: Sleeping
| import json | |
| import time | |
| import uuid | |
| import os | |
| import asyncio | |
| from typing import List, Optional, AsyncGenerator | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import httpx | |
# Base URL of the upstream Gradio API that every proxied request targets.
GRADIO_API_BASE = "https://shtyyg-zai-org-glm-5-1.hf.space/gradio_api"

app = FastAPI(title="OpenAI-Compatible GLM API")

# CORS is left wide open: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# worth re-checking against the CORS middleware documentation.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class ChatMessage(BaseModel):
    """One turn of an OpenAI-style ``messages`` array."""

    # Speaker of the turn; the chat handler in this file only looks for "user".
    role: str
    # Plain-text body of the turn.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat-completion handler."""

    # Echoed back in the response; it does not select a backend in this file.
    model: str = "glm-5-1"
    # Conversation history; only the most recent "user" entry is sent upstream.
    messages: List[ChatMessage]
    # Truthy -> reply as server-sent events; falsy/None -> single JSON body.
    stream: Optional[bool] = False
# Queue/status messages that never carry chatbot text, plus error messages
# (errors are reported by the caller, not turned into text here).
_TEXTLESS_GRADIO_MESSAGES = (
    "send_hash", "send_data", "process_starts",
    "estimation", "queue_full", "close_stream",
    "unexpected_error", "error",
)


def extract_gradio_text(data_payload: dict) -> Optional[str]:
    """Extract assistant text from one decoded Gradio SSE payload.

    Two payload shapes are understood, both read from ``output.data[1]``
    (the chatbot slot):

    * Incremental diff (``process_generating``)::

          output.data = [[], [["append", [1, "content", 0, "text"], "Hi"]]]

      The values of *all* ``append`` operations in the event are
      concatenated in order, so an event that batches several operations
      (or starts with a non-append operation) does not lose text.

    * Full snapshot (``process_completed``)::

          output.data = [None, [{"role": "user", ...},
                                {"role": "assistant",
                                 "content": [{"type": "text", "text": "Hi!"}]}]]

      The text of the last assistant message is returned; ``content`` may be
      a plain string or a list of ``{"type": "text", "text": ...}`` parts.

    Args:
        data_payload: The JSON object decoded from an SSE ``data:`` line.

    Returns:
        The appended text (diff payload), the complete assistant text
        (snapshot payload), or ``None`` when the payload carries no text.
    """
    if data_payload.get("msg", "") in _TEXTLESS_GRADIO_MESSAGES:
        return None

    output = data_payload.get("output")
    if not output or not isinstance(output, dict):
        return None

    output_data = output.get("data")
    if (
        not output_data
        or not isinstance(output_data, list)
        or len(output_data) < 2
    ):
        return None

    chatbot_data = output_data[1]
    if not isinstance(chatbot_data, list):
        return None

    # Case 1: diff operations shaped as [action, path, value].
    appended = [
        str(op[2])
        for op in chatbot_data
        if isinstance(op, list) and len(op) >= 3 and op[0] == "append"
    ]
    if appended:
        return "".join(appended)

    # Case 2: full chatbot history; take the most recent assistant message.
    for message in reversed(chatbot_data):
        if not isinstance(message, dict) or message.get("role") != "assistant":
            continue
        content = message.get("content")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = [
                item.get("text", "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            ]
            return "".join(texts) if texts else None
    return None
async def wake_up_space() -> bool:
    """Probe the Space's landing page so a sleeping Space gets a request.

    This is a best-effort check: any failure is logged and reported as
    ``False`` instead of being raised.

    Returns:
        ``True`` when the page answers with HTTP 200, ``False`` for any other
        status code or when the request raises.
    """
    target = "https://shtyyg-zai-org-glm-5-1.hf.space"
    print(f"[INFO] >>> 检查/唤醒 Space: {target}")
    async with httpx.AsyncClient(follow_redirects=True) as http:
        try:
            # Generous timeout: the request is allowed up to two minutes.
            response = await http.get(target, timeout=120)
            is_up = response.status_code == 200
            if is_up:
                print("[INFO] <<< Space 可达")
            else:
                print(f"[WARN] <<< Space 响应: {response.status_code}")
            return is_up
        except Exception as exc:
            print(f"[WARN] <<< Space 异常: {exc}")
            return False
def _has_append_ops(payload: dict) -> bool:
    """Return True when the chatbot slot of *payload* holds ``append`` diff ops."""
    output = payload.get("output")
    if not isinstance(output, dict):
        return False
    data = output.get("data")
    if not isinstance(data, list) or len(data) < 2:
        return False
    chatbot = data[1]
    if not isinstance(chatbot, list):
        return False
    return any(
        isinstance(op, list) and len(op) >= 3 and op[0] == "append"
        for op in chatbot
    )


def _next_delta(text: str, emitted: str, is_incremental: bool) -> str:
    """Return the not-yet-emitted part of *text*.

    Incremental text is new by definition. A snapshot contributes only what
    extends the text already emitted; a snapshot that diverges from it yields
    nothing, because already-streamed text cannot be retracted.
    """
    if is_incremental:
        return text
    if text.startswith(emitted):
        return text[len(emitted):]
    return ""


async def _post_gradio_step(client, step: str, route: str, url: str, body: dict) -> None:
    """POST one setup request and raise a 502 ``HTTPException`` unless it returns 200."""
    print(f"[DEBUG] >>> {step} POST {route}")
    resp = await client.post(url, json=body, timeout=60)
    print(f"[DEBUG] <<< {step}: {resp.status_code}, 内容: {resp.text[:200]}")
    if resp.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail=f"{step} 失败: HTTP {resp.status_code}",
        )


async def stream_gradio_response(prompt: str) -> AsyncGenerator[str, None]:
    """Send *prompt* to the Gradio Space and yield the reply as text deltas.

    The three requests are issued in order on a single client:

    1. ``POST /run/predict`` (fn_index 2) submits the user input.
    2. ``POST /queue/join``  (fn_index 4) enqueues the generation.
    3. ``GET  /queue/data?session_hash=...`` streams the result as SSE.

    Text arrives either as ``append`` diffs (already incremental) or as full
    snapshots of the assistant message. Everything emitted so far is tracked,
    so a snapshot only contributes the part not yet emitted; in particular
    the final ``process_completed`` snapshot does not repeat the whole answer
    after it has been streamed as diffs.

    Args:
        prompt: The user message to forward.

    Yields:
        Non-empty text fragments, in order.

    Raises:
        HTTPException: 502 when a step returns a non-200 status, when the
            upstream reports an error event/message, or when the HTTP
            transport itself fails.
    """
    session_hash = uuid.uuid4().hex
    trigger_id = 9
    run_url = f"{GRADIO_API_BASE}/run/predict"
    join_url = f"{GRADIO_API_BASE}/queue/join"
    stream_url = f"{GRADIO_API_BASE}/queue/data?session_hash={session_hash}"
    emitted = ""  # all text yielded so far, used to de-duplicate snapshots
    print(f"\n[DEBUG] >>> session_hash: {session_hash}")

    try:
        async with httpx.AsyncClient() as client:
            # Step 1: submit the user input.
            await _post_gradio_step(
                client,
                "Step1",
                "/run/predict fn=2",
                run_url,
                {
                    "data": [prompt],
                    "fn_index": 2,
                    "trigger_id": trigger_id,
                    "session_hash": session_hash,
                },
            )
            # Step 2: join the queue so generation starts.
            await _post_gradio_step(
                client,
                "Step2",
                "/queue/join fn=4",
                join_url,
                {
                    "data": [None, None],
                    "event_data": None,
                    "fn_index": 4,
                    "trigger_id": trigger_id,
                    "session_hash": session_hash,
                },
            )

            # Step 3: read the SSE stream for this session.
            print("[DEBUG] >>> Step3 GET /queue/data SSE")
            async with client.stream(
                "GET",
                stream_url,
                timeout=httpx.Timeout(300.0, read=300.0),
            ) as resp:
                print(f"[DEBUG] <<< SSE 连接状态: {resp.status_code}")
                if resp.status_code != 200:
                    body = await resp.aread()
                    # errors="replace": a non-UTF-8 error body must not mask
                    # the real failure with a UnicodeDecodeError.
                    snippet = body.decode(errors="replace")[:200]
                    raise HTTPException(
                        status_code=502,
                        detail=(
                            f"SSE 连接失败: HTTP {resp.status_code} "
                            f"- {snippet}"
                        ),
                    )

                current_event = None
                async for line in resp.aiter_lines():
                    print(f"[DEBUG] <<< 原始行: {repr(line)[:300]}")

                    # A blank line terminates the current SSE event.
                    if not line.strip():
                        current_event = None
                        continue

                    if line.startswith("event:"):
                        current_event = line.split(":", 1)[1].strip()
                        if current_event == "error":
                            print("[DEBUG] ⚠️ error 事件")
                        continue

                    if not line.startswith("data:") or current_event == "heartbeat":
                        continue

                    data_str = line.split(":", 1)[1].strip()

                    if current_event == "error":
                        print(f"[DEBUG] ⚠️ SSE error: {data_str}")
                        raise HTTPException(
                            status_code=502,
                            detail=f"Gradio SSE error: {data_str}",
                        )

                    if data_str == "[DONE]":
                        print("[DEBUG] <<< [DONE]")
                        return

                    try:
                        payload = json.loads(data_str)
                    except json.JSONDecodeError:
                        print(f"[DEBUG] ⚠️ JSON 解析失败: {data_str[:100]}")
                        continue
                    if not isinstance(payload, dict):
                        # Valid JSON but not an event object; nothing to read.
                        continue

                    msg = payload.get("msg", "")

                    # Error reported inside the JSON body.
                    if msg in ("unexpected_error", "error"):
                        error_msg = payload.get("message", data_str)
                        print(f"[DEBUG] ⚠️ Gradio 错误: {error_msg}")
                        raise HTTPException(
                            status_code=502,
                            detail=f"Gradio error: {error_msg}",
                        )

                    if msg == "close_stream":
                        print("[DEBUG] <<< close_stream")
                        return

                    text = extract_gradio_text(payload)
                    if text is None:
                        if msg == "process_completed":
                            print("[DEBUG] <<< process_completed(无新增量文本,结束)")
                            return
                        continue

                    delta = _next_delta(text, emitted, _has_append_ops(payload))
                    if not delta:
                        if not text.startswith(emitted):
                            print(
                                "[DEBUG] <<< snapshot differs from streamed "
                                "text; not re-emitting"
                            )
                        continue
                    emitted += delta
                    print(f"[DEBUG] <<< 增量文本: {repr(delta)[:80]}")
                    yield delta
    except httpx.HTTPError as exc:
        # Connection failures and timeouts are upstream problems: report them
        # as 502 like every other upstream failure in this function.
        raise HTTPException(
            status_code=502,
            detail=f"Gradio request failed: {exc}",
        ) from exc
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app` (presumably as the chat-completions POST).
async def create_chat_completion(request: ChatCompletionRequest):
    """Answer an OpenAI-style chat-completion request via the Gradio Space.

    Only the most recent ``user`` message is forwarded upstream; earlier
    turns in ``request.messages`` are not sent.

    Args:
        request: The parsed request body.

    Returns:
        A ``StreamingResponse`` of ``chat.completion.chunk`` events when
        ``request.stream`` is truthy, otherwise a ``JSONResponse`` holding a
        single ``chat.completion`` object. Token usage is always reported as
        zero.

    Raises:
        HTTPException: 400 when no non-empty user message is present. In
            non-streaming mode, errors raised by the upstream call propagate;
            in streaming mode they are sent as an ``error`` event instead.
    """
    last_user_msg = next(
        (m.content for m in reversed(request.messages) if m.role == "user"),
        "",
    )
    if not last_user_msg:
        raise HTTPException(
            status_code=400, detail="No user message found"
        )

    # Best effort: the result is deliberately ignored.
    await wake_up_space()
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"

    if not request.stream:
        parts = [
            piece async for piece in stream_gradio_response(last_user_msg)
        ]
        full_text = "".join(parts)
        print(
            f"[DEBUG] <<< 最终非流式输出: "
            f"{repr(full_text[:200])}"
        )
        response_data = {
            "id": completion_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": full_text,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            },
        }
        return JSONResponse(content=response_data)

    def sse_chunk(delta: dict, finish_reason: Optional[str] = None) -> str:
        """Serialize one ``chat.completion.chunk`` as an SSE ``data:`` event."""
        chunk_data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "delta": delta,
                    "finish_reason": finish_reason,
                }
            ],
        }
        return f"data: {json.dumps(chunk_data)}\n\n"

    async def generate_sse():
        is_first = True
        try:
            async for text_chunk in stream_gradio_response(last_user_msg):
                delta = {"content": text_chunk}
                if is_first:
                    # The role is announced once, on the first chunk only.
                    delta = {"role": "assistant", "content": text_chunk}
                    is_first = False
                yield sse_chunk(delta)
            # Terminal chunk: clients use finish_reason to detect a normal
            # end of generation.
            yield sse_chunk({}, "stop")
        except Exception as exc:
            # The response has already started, so the failure is reported
            # in-band rather than as an HTTP status.
            message = exc.detail if isinstance(exc, HTTPException) else str(exc)
            error_data = {
                "error": {
                    "message": message,
                    "type": "server_error",
                }
            }
            yield f"data: {json.dumps(error_data)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_sse(), media_type="text/event-stream"
    )
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def list_models():
    """Return the model catalogue: a single fixed ``glm-5-1`` entry."""
    glm_entry = {
        "id": "glm-5-1",
        "object": "model",
        "owned_by": "zai",
    }
    catalogue = {"object": "list", "data": [glm_entry]}
    return catalogue
if __name__ == "__main__":
    import uvicorn

    # Listen on every interface; the port comes from PORT, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)