from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import httpx
import json
import time
import uuid
from typing import Optional, List, Dict, Any
from pydantic import BaseModel
import asyncio

# Create the FastAPI application
app = FastAPI()

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Data models
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    model: str
    stream: Optional[bool] = True


class ChatResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Optional[Dict[str, int]] = None

# Map OpenAI-style model names to Abacus.ai LLM names; unknown names are
# passed through unchanged (see chat_completions below).
MODEL_MAPPING = {
    "gpt-4o-mini-abacus": "OPENAI_GPT4O_MINI",
    "claude-3.5-sonnet-abacus": "CLAUDE_V3_5_SONNET",
    "claude-3.7-sonnet-abacus": "CLAUDE_V3_7_SONNET",
    "claude-3.7-sonnet-thinking-abacus": "CLAUDE_V3_7_SONNET_THINKING",
    "o3-mini-abacus": "OPENAI_O3_MINI",
    "o3-mini-high-abacus": "OPENAI_O3_MINI_HIGH",
    "o1-mini-abacus": "OPENAI_O1_MINI",
    "deepseek-r1-abacus": "DEEPSEEK_R1",
    "gemini-2-pro-abacus": "GEMINI_2_PRO",
    "gemini-2-flash-thinking-abacus": "GEMINI_2_FLASH_THINKING",
    "gemini-2-flash-abacus": "GEMINI_2_FLASH",
    "gemini-1.5-pro-abacus": "GEMINI_1_5_PRO",
    "xai-grok-abacus": "XAI_GROK",
    "deepseek-v3-abacus": "DEEPSEEK_V3",
    "llama3-1-405b-abacus": "LLAMA3_1_405B",
    "gpt-4o-abacus": "OPENAI_GPT4O",
    "gpt-4o-2024-08-06-abacus": "OPENAI_GPT4O",
    "gpt-3.5-turbo-abacus": "OPENAI_O3_MINI",
    "gpt-3.5-turbo-16k-abacus": "OPENAI_O3_MINI_HIGH",
}

BASE_URL = "https://pa002.abacus.ai"
TIMEOUT = 30.0  # request timeout in seconds
MAX_RETRIES = 3  # maximum number of retry attempts
RETRY_DELAY = 1  # delay between retries in seconds

@app.get("/v1/models")
async def list_models():
    """Return the list of supported models."""
    models = [
        {
            "id": model_id,
            "object": "model",
            "created": 1677610602,
            "owned_by": "system",
        }
        for model_id in MODEL_MAPPING.keys()
    ]
    return {
        "object": "list",
        "data": models,
    }

# Helper: build the headers for requests to the Abacus.ai backend
def get_headers(auth_token: str) -> Dict[str, str]:
    """Build the upstream request headers; the caller's token is sent as the Cookie."""
    return {
        "sec-ch-ua-platform": "Windows",
        "sec-ch-ua": '"Not(A:Brand";v="99", "Microsoft Edge";v="133", "Chromium";v="133"',
        "sec-ch-ua-mobile": "?0",
        "X-Abacus-Org-Host": "apps",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36 Edg/133.0.0.0",
        "Sec-Fetch-Site": "same-site",
        "Sec-Fetch-Mode": "cors",
        "Sec-Fetch-Dest": "empty",
        "host": "pa002.abacus.ai",
        "Cookie": auth_token,
        "Accept": "text/event-stream",
        "Content-Type": "text/plain;charset=UTF-8",
    }

def process_messages(messages: List[Message]) -> str:
    """Flatten the message list into a single prompt string."""
    system_message = next((msg.content for msg in messages if msg.role == "system"), None)
    context_messages = [msg for msg in messages if msg.role != "system"][:-1]
    current_message = messages[-1].content
    full_message = current_message
    if system_message:
        full_message = f"System: {system_message}\n\n{full_message}"
    if context_messages:
        context_str = "\n".join(f"{msg.role}: {msg.content}" for msg in context_messages)
        full_message = f"Previous conversation:\n{context_str}\nCurrent message: {full_message}"
    return full_message

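# For illustration (hypothetical message contents): a request with a system
# prompt, one prior user/assistant exchange, and a new user message is
# flattened by process_messages() into the prompt below. Note that the
# "System:" prefix ends up inside the "Current message:" part:
#
#   Previous conversation:
#   user: Hi
#   assistant: Hello!
#   Current message: System: You are a helpful assistant.
#
#   What's the weather like today?
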
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, chat_request: ChatRequest):
    """Handle an OpenAI-style chat completion request."""
    # Extract the auth token from the Authorization header
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        return Response(
            content=json.dumps({"error": "No valid Authorization header provided"}),
            status_code=401
        )
    auth_token = auth_header.replace("Bearer ", "")

    # Create a conversation ID
    conversation_id = str(uuid.uuid4())

    # Flatten the incoming messages into a single prompt
    full_message = process_messages(chat_request.messages)

    # Build the upstream request payload
    request_data = {
        "requestId": str(uuid.uuid4()),
        "deploymentConversationId": conversation_id,
        "message": full_message,
        "isDesktop": True,
        "chatConfig": {
            "timezone": "Asia/Shanghai",
            "language": "zh-CN"
        },
        "llmName": MODEL_MAPPING.get(chat_request.model, chat_request.model),
        "externalApplicationId": str(uuid.uuid4())
    }

    # Stream the upstream SSE response back as OpenAI-style chunks
    async def generate_stream():
        headers = get_headers(auth_token)
        for retry in range(MAX_RETRIES):
            try:
                async with httpx.AsyncClient() as client:
                    async with client.stream(
                        "POST",
                        f"{BASE_URL}/api/_chatLLMSendMessageSSE",
                        headers=headers,
                        content=json.dumps(request_data),
                        timeout=TIMEOUT
                    ) as response:
                        async for line in response.aiter_lines():
                            if not line.strip():
                                continue
                            try:
                                data = json.loads(line)
                                # Forward text segments, skipping "Thinking..." events
                                if data.get("type") == "text" and data.get("title") != "Thinking...":
                                    chunk = {
                                        "id": str(uuid.uuid4()),
                                        "object": "chat.completion.chunk",
                                        "created": int(time.time()),
                                        "model": chat_request.model,
                                        "choices": [{
                                            "delta": {
                                                "role": "assistant",
                                                "content": data.get("segment", "")
                                            },
                                            "index": 0
                                        }]
                                    }
                                    yield f"data: {json.dumps(chunk)}\n\n"
                                if data.get("end"):
                                    # Send the final chunk marking the end of the stream
                                    chunk = {
                                        "id": str(uuid.uuid4()),
                                        "object": "chat.completion.chunk",
                                        "created": int(time.time()),
                                        "model": chat_request.model,
                                        "choices": [{
                                            "delta": {"content": ""},
                                            "index": 0,
                                            "finish_reason": "stop"
                                        }]
                                    }
                                    yield f"data: {json.dumps(chunk)}\n\n"
                                    yield "data: [DONE]\n\n"
                                    return  # finished successfully; exit the retry loop
                            except json.JSONDecodeError:
                                continue
            except (httpx.TimeoutException, httpx.RequestError) as e:
                if retry == MAX_RETRIES - 1:  # last attempt
                    yield f"data: {json.dumps({'error': str(e)})}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                await asyncio.sleep(RETRY_DELAY)

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream"
    )

@app.get("/health")
async def health_check():
    """Health check."""
    return {"status": "ok", "version": "1.0.0"}

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler."""
    error_message = str(exc)
    return Response(
        content=json.dumps({
            "error": {
                "message": error_message,
                "type": exc.__class__.__name__,
                "code": 500
            }
        }),
        status_code=500,
        media_type="application/json"
    )

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
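
# Example client call against this proxy: a minimal sketch, assuming the server
# is running locally on port 8000 (as in the uvicorn.run call above), that the
# routes are registered as above, and that the Bearer token is the Abacus
# session cookie that get_headers() forwards upstream. The model name and
# message content are placeholders.
#
#   import httpx, json
#
#   with httpx.stream(
#       "POST",
#       "http://localhost:8000/v1/chat/completions",
#       headers={"Authorization": "Bearer <abacus-session-cookie>"},
#       json={
#           "model": "claude-3.5-sonnet-abacus",
#           "stream": True,
#           "messages": [{"role": "user", "content": "Hello"}],
#       },
#       timeout=60,
#   ) as response:
#       for line in response.iter_lines():
#           if line.startswith("data: ") and line.strip() != "data: [DONE]":
#               print(json.loads(line[len("data: "):]))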