Spaces:
Running
Running
kevin
commited on
Commit
·
feb939c
1
Parent(s):
a66c30a
think buddy
Browse files- Dockerfile +17 -0
- README.md +1 -0
- core/__init__.py +0 -0
- core/app.py +69 -0
- core/auth.py +12 -0
- core/config.py +79 -0
- core/logger.py +20 -0
- core/models.py +16 -0
- core/refresh_token.py +105 -0
- core/router.py +241 -0
- core/utils.py +267 -0
- main.py +6 -0
- requirements.txt +11 -0
Dockerfile
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10.16-slim
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
# 复制所需文件到容器中
|
6 |
+
COPY . .
|
7 |
+
|
8 |
+
|
9 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
10 |
+
ENV APP_SECRET="sk-123456"
|
11 |
+
ENV REQUEST_TIMEOUT=30
|
12 |
+
|
13 |
+
# Expose port
|
14 |
+
EXPOSE 8001
|
15 |
+
|
16 |
+
# Run the application
|
17 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]
|
README.md
CHANGED
@@ -5,6 +5,7 @@ colorFrom: blue
|
|
5 |
colorTo: red
|
6 |
sdk: docker
|
7 |
pinned: false
|
|
|
8 |
---
|
9 |
|
10 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
5 |
colorTo: red
|
6 |
sdk: docker
|
7 |
pinned: false
|
8 |
+
app_port: 8001
|
9 |
---
|
10 |
|
11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
core/__init__.py
ADDED
File without changes
|
core/app.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
from fastapi import FastAPI, Request
|
4 |
+
from starlette.middleware.cors import CORSMiddleware
|
5 |
+
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
6 |
+
from fastapi.responses import JSONResponse
|
7 |
+
from core.config import get_settings
|
8 |
+
from core.logger import setup_logger
|
9 |
+
from core.refresh_token import TokenManager
|
10 |
+
from core.router import router
|
11 |
+
|
12 |
+
settings = get_settings()
|
13 |
+
logger = setup_logger(__name__)
|
14 |
+
|
15 |
+
# print(settings.SECRET)
|
16 |
+
def create_app() -> FastAPI:
|
17 |
+
app = FastAPI(
|
18 |
+
title=settings.PROJECT_NAME,
|
19 |
+
version="0.0.1",
|
20 |
+
description=settings.DESCRIPTION,
|
21 |
+
)
|
22 |
+
# 配置中间件
|
23 |
+
app.add_middleware(
|
24 |
+
CORSMiddleware,
|
25 |
+
allow_origins=["*"],
|
26 |
+
allow_credentials=True,
|
27 |
+
allow_methods=["*"],
|
28 |
+
allow_headers=["*"],
|
29 |
+
)
|
30 |
+
|
31 |
+
# # 添加可信主机中间件
|
32 |
+
app.add_middleware(
|
33 |
+
TrustedHostMiddleware,
|
34 |
+
allowed_hosts=["*"] # 在生产环境中应该限制允许的主机
|
35 |
+
)
|
36 |
+
# 添加路由
|
37 |
+
app.include_router(router, prefix="/api/v1")
|
38 |
+
app.include_router(router, prefix="/v1") # 兼容性路由
|
39 |
+
|
40 |
+
@app.exception_handler(Exception)
|
41 |
+
async def global_exception_handler(request: Request, exc: Exception):
|
42 |
+
logger.error(f"An error occurred: {str(exc)}", exc_info=True)
|
43 |
+
return JSONResponse(
|
44 |
+
status_code=500,
|
45 |
+
content={
|
46 |
+
"message": "An internal server error occurred.",
|
47 |
+
"detail": str(exc)
|
48 |
+
},
|
49 |
+
)
|
50 |
+
|
51 |
+
# # 创建 TokenManager 实例
|
52 |
+
token_manager = TokenManager()
|
53 |
+
@app.on_event("startup")
|
54 |
+
async def startup_event():
|
55 |
+
# 在应用启动时创建任务
|
56 |
+
app.state.refresh_task = asyncio.create_task(token_manager.start_auto_refresh())
|
57 |
+
|
58 |
+
@app.on_event("shutdown")
|
59 |
+
async def shutdown_event():
|
60 |
+
# 在应用关闭时取消任务
|
61 |
+
if hasattr(app.state, 'refresh_task'):
|
62 |
+
app.state.refresh_task.cancel()
|
63 |
+
try:
|
64 |
+
await app.state.refresh_task
|
65 |
+
except asyncio.CancelledError:
|
66 |
+
pass
|
67 |
+
return app
|
68 |
+
|
69 |
+
app = create_app()
|
core/auth.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import Depends, HTTPException
|
2 |
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
3 |
+
from core.config import get_settings
|
4 |
+
|
5 |
+
settings = get_settings()
|
6 |
+
APP_SECRET = settings.APP_SECRET
|
7 |
+
security = HTTPBearer()
|
8 |
+
|
9 |
+
def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
10 |
+
if credentials.credentials != APP_SECRET:
|
11 |
+
raise HTTPException(status_code=403, detail="Invalid SECRET")
|
12 |
+
return credentials.credentials
|
core/config.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import platform
|
3 |
+
import uuid
|
4 |
+
from typing import List, Dict
|
5 |
+
|
6 |
+
import httpx
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
from pydantic_settings import BaseSettings
|
9 |
+
|
10 |
+
load_dotenv()
|
11 |
+
|
12 |
+
|
13 |
+
class Settings(BaseSettings):
|
14 |
+
APP_SECRET: str = os.getenv('APP_SECRET', '')
|
15 |
+
TOKEN: str = os.getenv('TOKEN', '')
|
16 |
+
PROJECT_NAME: str = os.getenv('PROJECT_NAME', 'FastAPI')
|
17 |
+
DESCRIPTION: str = os.getenv('DESCRIPTION', 'FastAPI template')
|
18 |
+
FIREBASE_API_KEY: str = os.getenv('FIREBASE_API_KEY', '')
|
19 |
+
REFRESH_TOKEN: str = os.getenv('REFRESH_TOKEN', '')
|
20 |
+
AUTHORIZATION_TOKEN: str = os.getenv('AUTHORIZATION_TOKEN', '')
|
21 |
+
|
22 |
+
ALLOWED_MODELS: List[Dict[str, str]] = [
|
23 |
+
{"id": "gpt-4o", "name": "GPT-4o [thinkbuddy]"},
|
24 |
+
{"id": "claude-3-5-sonnet", "name": "Claude 3.5 Sonnet v2 [thinkbuddy]"},
|
25 |
+
{"id": "gemini-2-flash", "name": "Gemini 2 Flash [thinkbuddy]"},
|
26 |
+
{"id": "nova-pro", "name": "Nova Pro [thinkbuddy]"},
|
27 |
+
{"id": "deepseek-v3", "name": "DeepSeek v3 [thinkbuddy]"},
|
28 |
+
{"id": "llama-3-3", "name": "Llama 3.3 (70B) [thinkbuddy]"},
|
29 |
+
{"id": "mistral-large-2", "name": "Mistral Large v2 [thinkbuddy]"},
|
30 |
+
{"id": "command-r-plus", "name": "Command R+ [thinkbuddy]"},
|
31 |
+
{"id": "o1-preview", "name": "o1-preview [thinkbuddy]"},
|
32 |
+
{"id": "o1-mini", "name": "o1-mini [thinkbuddy]"},
|
33 |
+
{"id": "gemini-2-thinking", "name": "Gemini 2 Thinking [thinkbuddy]"},
|
34 |
+
{"id": "claude-3-5-haiku", "name": "Claude 3.5 Haiku [thinkbuddy]"},
|
35 |
+
{"id": "gemini-1-5-flash", "name": "Gemini 1.5 Flash [thinkbuddy]"},
|
36 |
+
{"id": "gpt-4o-mini", "name": "GPT-4o mini [thinkbuddy]"},
|
37 |
+
{"id": "nova-lite", "name": "Nova Lite [thinkbuddy]"},
|
38 |
+
{"id": "gemini-1-5-pro", "name": "Gemini 1.5 Pro [thinkbuddy]"},
|
39 |
+
{"id": "claude-3-opus", "name": "Claude 3 Opus [thinkbuddy]"},
|
40 |
+
{"id": "gpt-4-turbo", "name": "GPT-4 Turbo [thinkbuddy]"},
|
41 |
+
{"id": "llama-3-1", "name": "Llama 3.1 [thinkbuddy]"}
|
42 |
+
]
|
43 |
+
MODEL_MAPPING: Dict[str, str] = {
|
44 |
+
"gpt-4o": "gpt-4o",
|
45 |
+
"o1-preview": "o1-preview",
|
46 |
+
"claude-3-5-sonnet": "claude-3-5-sonnet",
|
47 |
+
"o1-mini": "o1-mini",
|
48 |
+
"gemini-1.5-pro": "gemini-1.5-pro",
|
49 |
+
"gemini-2.0-flash": "gemini-2.0-flash",
|
50 |
+
}
|
51 |
+
HEADERS: Dict[str, str] = {
|
52 |
+
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
|
53 |
+
'Accept-Encoding': 'gzip, deflate, br, zstd',
|
54 |
+
'Content-Type': 'application/json',
|
55 |
+
'sec-ch-ua-platform': 'Windows',
|
56 |
+
'authorization': f"Bearer {os.getenv('TOKEN', '')}",
|
57 |
+
'sec-ch-ua': '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"',
|
58 |
+
'dnt': '1',
|
59 |
+
'sec-ch-ua-mobile': '?0',
|
60 |
+
'origin': 'https://thinkbuddy.ai',
|
61 |
+
'sec-fetch-site': 'same-site',
|
62 |
+
'sec-fetch-mode': 'cors',
|
63 |
+
'sec-fetch-dest': 'empty',
|
64 |
+
'referer': 'https://thinkbuddy.ai/',
|
65 |
+
'accept-language': 'en-US,en;q=0.9,zh-CN;q=0.8,zh-TW;q=0.7,zh;q=0.6',
|
66 |
+
'priority': 'u=1, i'
|
67 |
+
}
|
68 |
+
|
69 |
+
class Config:
|
70 |
+
env_file = '.env'
|
71 |
+
case_sensitive = True
|
72 |
+
|
73 |
+
_settings = None
|
74 |
+
|
75 |
+
def get_settings():
|
76 |
+
global _settings
|
77 |
+
if _settings is None:
|
78 |
+
_settings = Settings()
|
79 |
+
return _settings
|
core/logger.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
def setup_logger(name):
|
4 |
+
logger = logging.getLogger(name)
|
5 |
+
if not logger.handlers:
|
6 |
+
logger.setLevel(logging.INFO)
|
7 |
+
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
8 |
+
|
9 |
+
# 控制台处理器
|
10 |
+
console_handler = logging.StreamHandler()
|
11 |
+
console_handler.setFormatter(formatter)
|
12 |
+
logger.addHandler(console_handler)
|
13 |
+
|
14 |
+
# 文件处理器 - 错误级别
|
15 |
+
# error_file_handler = logging.FileHandler('error.log')
|
16 |
+
# error_file_handler.setFormatter(formatter)
|
17 |
+
# error_file_handler.setLevel(logging.ERROR)
|
18 |
+
# logger.addHandler(error_file_handler)
|
19 |
+
|
20 |
+
return logger
|
core/models.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
from pydantic import BaseModel
|
3 |
+
|
4 |
+
|
5 |
+
class Message(BaseModel):
|
6 |
+
role: str
|
7 |
+
content: str | list
|
8 |
+
|
9 |
+
|
10 |
+
class ChatRequest(BaseModel):
|
11 |
+
model: str
|
12 |
+
messages: List[Message]
|
13 |
+
stream: Optional[bool] = False
|
14 |
+
temperature: Optional[float] = 0.7
|
15 |
+
top_p: Optional[float] = 0.9
|
16 |
+
max_tokens: Optional[int] = 8192
|
core/refresh_token.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 用于刷新idToken的定时任务
|
2 |
+
import asyncio
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from typing import Optional
|
6 |
+
from datetime import datetime
|
7 |
+
from core.utils import sign_in_with_idp, handle_firebase_response, refresh_token_via_rest
|
8 |
+
|
9 |
+
|
10 |
+
class TokenManager:
|
11 |
+
def __init__(self):
|
12 |
+
self.id_token: Optional[str] = None
|
13 |
+
self.last_refresh_time: Optional[float] = None
|
14 |
+
self.refresh_interval = 30 * 60 # 30分钟,单位:秒
|
15 |
+
self.is_running = False
|
16 |
+
|
17 |
+
async def get_token(self) -> str:
|
18 |
+
"""
|
19 |
+
获取当前的 idToken,如果不存在或已过期则刷新
|
20 |
+
"""
|
21 |
+
if not self.id_token or self._should_refresh():
|
22 |
+
await self.refresh_token()
|
23 |
+
return self.id_token
|
24 |
+
|
25 |
+
def _should_refresh(self) -> bool:
|
26 |
+
"""
|
27 |
+
检查是否需要刷新 token
|
28 |
+
"""
|
29 |
+
if not self.last_refresh_time:
|
30 |
+
return True
|
31 |
+
return time.time() - self.last_refresh_time >= self.refresh_interval
|
32 |
+
|
33 |
+
async def refresh_token(self):
|
34 |
+
"""
|
35 |
+
刷新 idToken
|
36 |
+
"""
|
37 |
+
try:
|
38 |
+
if os.getenv("REFRESH_TOKEN", "") == "" or os.getenv("REFRESH_TOKEN", "") == "None":
|
39 |
+
response = await sign_in_with_idp()
|
40 |
+
result = await handle_firebase_response(response) # idToken 实际就是bearer token
|
41 |
+
else:
|
42 |
+
result = await refresh_token_via_rest(os.getenv("REFRESH_TOKEN"))
|
43 |
+
if result is not None:
|
44 |
+
self.id_token = result
|
45 |
+
# 修改配置中的 TOKEN
|
46 |
+
print(f"Before Token is {os.getenv('TOKEN', '')}")
|
47 |
+
os.environ["TOKEN"] = self.id_token
|
48 |
+
print(f"Now Token is {os.getenv('TOKEN', '')}")
|
49 |
+
self.last_refresh_time = time.time()
|
50 |
+
print(f"Token refreshed at {datetime.now()}")
|
51 |
+
else:
|
52 |
+
print(f"Failed to refresh token: {result['error']}")
|
53 |
+
except Exception as e:
|
54 |
+
print(f"Error refreshing token: {str(e)}")
|
55 |
+
|
56 |
+
async def start_auto_refresh(self):
|
57 |
+
"""
|
58 |
+
启动自动刷新任务
|
59 |
+
"""
|
60 |
+
if self.is_running:
|
61 |
+
return
|
62 |
+
|
63 |
+
self.is_running = True
|
64 |
+
while self.is_running:
|
65 |
+
try:
|
66 |
+
await self.refresh_token()
|
67 |
+
# 等待到下次刷新时间
|
68 |
+
await asyncio.sleep(self.refresh_interval)
|
69 |
+
except Exception as e:
|
70 |
+
print(f"Auto refresh error: {str(e)}")
|
71 |
+
# 发生错误时等待短暂时间后重试
|
72 |
+
await asyncio.sleep(60)
|
73 |
+
|
74 |
+
async def stop_auto_refresh(self):
|
75 |
+
"""
|
76 |
+
停止自动刷新任务
|
77 |
+
"""
|
78 |
+
self.is_running = False
|
79 |
+
|
80 |
+
# 使用示例
|
81 |
+
# async def main():
|
82 |
+
# # 创建 TokenManager 实例
|
83 |
+
# token_manager = TokenManager()
|
84 |
+
#
|
85 |
+
# try:
|
86 |
+
# # 启动自动刷新任务
|
87 |
+
# refresh_task = asyncio.create_task(token_manager.start_auto_refresh())
|
88 |
+
#
|
89 |
+
# # 模拟应用运行
|
90 |
+
# while True:
|
91 |
+
# # 获取当前 token
|
92 |
+
# token = await token_manager.get_token()
|
93 |
+
# print(f"Current token: {token[:20]}...")
|
94 |
+
#
|
95 |
+
# # 等待一段时间再次获取
|
96 |
+
# await asyncio.sleep(300) # 每5分钟打印一次当前token
|
97 |
+
#
|
98 |
+
# except KeyboardInterrupt:
|
99 |
+
# # 处理 Ctrl+C
|
100 |
+
# await token_manager.stop_auto_refresh()
|
101 |
+
# await refresh_task
|
102 |
+
#
|
103 |
+
# # 运行示例
|
104 |
+
# if __name__ == "__main__":
|
105 |
+
# asyncio.run(main())
|
core/router.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
from typing import Optional
|
6 |
+
from uuid import uuid4
|
7 |
+
|
8 |
+
import httpx
|
9 |
+
from fastapi import File, UploadFile
|
10 |
+
from fastapi import APIRouter, Response, Request, Depends, HTTPException
|
11 |
+
from fastapi.responses import StreamingResponse
|
12 |
+
|
13 |
+
|
14 |
+
from core.auth import verify_app_secret
|
15 |
+
from core.config import get_settings
|
16 |
+
from core.logger import setup_logger
|
17 |
+
from core.models import ChatRequest
|
18 |
+
from core.utils import process_streaming_response
|
19 |
+
from playsound import playsound # 用于播放音频
|
20 |
+
|
21 |
+
logger = setup_logger(__name__)
|
22 |
+
router = APIRouter()
|
23 |
+
ALLOWED_MODELS = get_settings().ALLOWED_MODELS
|
24 |
+
|
25 |
+
@router.get("/models")
|
26 |
+
async def list_models():
|
27 |
+
return {"object": "list", "data": ALLOWED_MODELS, "success": True}
|
28 |
+
|
29 |
+
@router.options("/chat/completions")
|
30 |
+
async def chat_completions_options():
|
31 |
+
return Response(
|
32 |
+
status_code=200,
|
33 |
+
headers={
|
34 |
+
"Access-Control-Allow-Origin": "*",
|
35 |
+
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
36 |
+
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
37 |
+
},
|
38 |
+
)
|
39 |
+
# 识图
|
40 |
+
|
41 |
+
# 识图
|
42 |
+
# 文本转语音
|
43 |
+
@router.post("/audio/speech")
|
44 |
+
async def speech(request: Request):
|
45 |
+
url = 'https://api.thinkbuddy.ai/v1/content/speech/tts'
|
46 |
+
request_headers = {**get_settings().HEADERS,
|
47 |
+
'authorization': f"Bearer {os.getenv('TOKEN', '')}",
|
48 |
+
'Accept': 'application/json, text/plain, */*',
|
49 |
+
}
|
50 |
+
# data = {
|
51 |
+
# "input": "这是一张插图,显示了一杯饮料,可能是奶昔、冰沙或其他冷饮。杯子上有一个盖子和一根吸管,表明这是一种便于携带和饮用的饮品。这种设计通常用于提供咖啡、冰茶或果汁等饮品。杯子颜色简约,可能用于说明饮品的内容或品牌。",
|
52 |
+
# "voice": "nova" # alloy echo fable onyx nova shimmer
|
53 |
+
# }
|
54 |
+
body = await request.json()
|
55 |
+
try:
|
56 |
+
async with httpx.AsyncClient(http2=True) as client:
|
57 |
+
response = await client.post(url, headers=request_headers, json=body)
|
58 |
+
response.raise_for_status()
|
59 |
+
|
60 |
+
# 假设响应是音频数据,保存为文件
|
61 |
+
if response.status_code == 200:
|
62 |
+
# 保存音频文件
|
63 |
+
with open('output.mp3', 'wb') as f:
|
64 |
+
f.write(response.content)
|
65 |
+
print("音频文件已保存为 output.mp3")
|
66 |
+
|
67 |
+
# 异步播放音频
|
68 |
+
# 使用 asyncio.to_thread 来避免阻塞事件循环
|
69 |
+
# await asyncio.to_thread(playsound, 'output.mp3')
|
70 |
+
return True
|
71 |
+
else:
|
72 |
+
print(f"请求失败,状态码: {response.status_code}")
|
73 |
+
print(f"响应内容: {response.text}")
|
74 |
+
return False
|
75 |
+
|
76 |
+
except httpx.RequestError as e:
|
77 |
+
print(f"请求错误: {e}")
|
78 |
+
print("错误堆栈:")
|
79 |
+
traceback.print_exc()
|
80 |
+
return False
|
81 |
+
except httpx.HTTPStatusError as e:
|
82 |
+
print(f"HTTP 错误: {e}")
|
83 |
+
print("错误堆栈:")
|
84 |
+
traceback.print_exc()
|
85 |
+
return False
|
86 |
+
except Exception as e:
|
87 |
+
print(f"发生错误: {e}")
|
88 |
+
print("错误堆栈:")
|
89 |
+
traceback.print_exc()
|
90 |
+
return False
|
91 |
+
|
92 |
+
|
93 |
+
# 语音转文本
|
94 |
+
@router.post("/audio/transcriptions")
|
95 |
+
async def transcriptions(request: Request, file: UploadFile = File(...)):
|
96 |
+
url = 'https://api.thinkbuddy.ai/v1/content/transcribe'
|
97 |
+
params = {'enhance': 'true'}
|
98 |
+
try:
|
99 |
+
# 读取文件内容
|
100 |
+
content = await safe_read_file(file)
|
101 |
+
# 获取原始 content-type
|
102 |
+
content_type = request.headers.get('content-type')
|
103 |
+
# files = {
|
104 |
+
# 'file': (str(uuid4()),
|
105 |
+
# content,
|
106 |
+
# file.content_type or 'application/octet-stream')
|
107 |
+
# }
|
108 |
+
files = {
|
109 |
+
'file': ('file.mp4', content, 'audio/mp4'),
|
110 |
+
'model': (None, 'whisper-1')
|
111 |
+
}
|
112 |
+
# 记录请求信息
|
113 |
+
logger.info(f"Received upload request for file: {file.filename}")
|
114 |
+
logger.info(f"Content-Type: {request.headers.get('content-type')}")
|
115 |
+
request_headers = {**get_settings().HEADERS,
|
116 |
+
'authorization': f"Bearer {os.getenv('TOKEN', '')}",
|
117 |
+
'Accept': 'application/json, text/plain, */*',
|
118 |
+
'Content-Type': content_type,
|
119 |
+
}
|
120 |
+
# 设置较长的超时时间
|
121 |
+
timeout = httpx.Timeout(
|
122 |
+
connect=30.0, # 连接超时
|
123 |
+
read=300.0, # 读取超时
|
124 |
+
write=30.0, # 写入超时
|
125 |
+
pool=30.0 # 连接池超时
|
126 |
+
)
|
127 |
+
# 使用httpx发送异步请求
|
128 |
+
async with httpx.AsyncClient(http2=True, timeout=timeout) as client:
|
129 |
+
response = await client.post(url,
|
130 |
+
params=params,
|
131 |
+
headers=request_headers,
|
132 |
+
files=files)
|
133 |
+
response.raise_for_status()
|
134 |
+
return response.json()
|
135 |
+
|
136 |
+
except httpx.TimeoutException:
|
137 |
+
raise HTTPException(status_code=504, detail="请求目标服务器超时")
|
138 |
+
except httpx.HTTPStatusError as e:
|
139 |
+
raise HTTPException(status_code=e.response.status_code, detail=str(e))
|
140 |
+
except Exception as e:
|
141 |
+
traceback.print_tb(e.__traceback__)
|
142 |
+
raise HTTPException(status_code=500, detail=str(e))
|
143 |
+
finally:
|
144 |
+
# 清理资源
|
145 |
+
await file.close()
|
146 |
+
async def safe_read_file(file: UploadFile) -> Optional[bytes]:
|
147 |
+
"""安全地读取文件内容"""
|
148 |
+
try:
|
149 |
+
return await file.read()
|
150 |
+
except Exception as e:
|
151 |
+
logger.error(f"Error reading file: {str(e)}")
|
152 |
+
return None
|
153 |
+
# 文件上传
|
154 |
+
@router.post("/upload")
|
155 |
+
async def upload_file(request: Request, file: UploadFile = File(...)):
|
156 |
+
try:
|
157 |
+
# 读取文件内容
|
158 |
+
content = await safe_read_file(file)
|
159 |
+
# 获取原始 content-type
|
160 |
+
content_type = request.headers.get('content-type')
|
161 |
+
files = {
|
162 |
+
'file': (
|
163 |
+
# str(uuid4()),
|
164 |
+
file.filename, # 使用原始文件名而不是 UUID
|
165 |
+
content,
|
166 |
+
file.content_type )
|
167 |
+
}
|
168 |
+
# 记录请求信息
|
169 |
+
logger.info(f"Received upload request for file: {file.filename}")
|
170 |
+
logger.info(f"Content-Type: {request.headers.get('content-type')}")
|
171 |
+
request_headers = {**get_settings().HEADERS,
|
172 |
+
'authorization': f"Bearer {os.getenv('TOKEN', '')}",
|
173 |
+
'Accept': 'application/json, text/plain, */*',
|
174 |
+
'Content-Type': content_type,
|
175 |
+
}
|
176 |
+
# 使用httpx发送异步请求
|
177 |
+
async with httpx.AsyncClient() as client:
|
178 |
+
response = await client.post(f"https://api.thinkbuddy.ai/v1/uploads/images", headers=request_headers,files=files, timeout=100)
|
179 |
+
response.raise_for_status()
|
180 |
+
return response.json()
|
181 |
+
|
182 |
+
except httpx.TimeoutException:
|
183 |
+
raise HTTPException(status_code=504, detail="请求目标服务器超时")
|
184 |
+
except httpx.HTTPStatusError as e:
|
185 |
+
# raise HTTPException(status_code=e.response.status_code, detail=str(e))
|
186 |
+
print(f"HTTPStatusError发生错误: {e}")
|
187 |
+
print("错误堆栈:")
|
188 |
+
traceback.print_exc()
|
189 |
+
except Exception as e:
|
190 |
+
# traceback.print_tb(e.__traceback__)
|
191 |
+
# raise HTTPException(status_code=500, detail=str(e))
|
192 |
+
print(f"发生错误: {e}")
|
193 |
+
print("错误堆栈:")
|
194 |
+
traceback.print_exc()
|
195 |
+
finally:
|
196 |
+
# 清理资源
|
197 |
+
await file.close()
|
198 |
+
|
199 |
+
@router.post("/chat/completions")
|
200 |
+
async def chat_completions(
|
201 |
+
request: ChatRequest, app_secret: str = Depends(verify_app_secret)
|
202 |
+
):
|
203 |
+
logger.info("Entering chat_completions route")
|
204 |
+
# logger.info(f"Received request: {request}")
|
205 |
+
logger.info(f"Received request json format: {json.dumps(request.dict(), indent=4)}")
|
206 |
+
logger.info(f"App secret: {app_secret}")
|
207 |
+
logger.info(f"Received chat completion request for model: {request.model}")
|
208 |
+
|
209 |
+
if request.model not in [model["id"] for model in ALLOWED_MODELS]:
|
210 |
+
raise HTTPException(
|
211 |
+
status_code=400,
|
212 |
+
detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(model['id'] for model in ALLOWED_MODELS)}",
|
213 |
+
)
|
214 |
+
|
215 |
+
if request.stream:
|
216 |
+
logger.info("Streaming response")
|
217 |
+
return StreamingResponse(
|
218 |
+
process_streaming_response(request, app_secret),
|
219 |
+
media_type="text/event-stream",
|
220 |
+
headers={
|
221 |
+
"Cache-Control": "no-cache",
|
222 |
+
"Connection": "keep-alive",
|
223 |
+
"Transfer-Encoding": "chunked"
|
224 |
+
}
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
logger.info("Non-streaming response")
|
228 |
+
# return await process_non_streaming_response(request)
|
229 |
+
|
230 |
+
@router.route('/')
|
231 |
+
@router.route('/healthz')
|
232 |
+
@router.route('/ready')
|
233 |
+
@router.route('/alive')
|
234 |
+
@router.route('/status')
|
235 |
+
@router.get("/health")
|
236 |
+
async def health_check(request: Request):
|
237 |
+
return Response(content=json.dumps({"status": "ok"}), media_type="application/json")
|
238 |
+
|
239 |
+
@router.post("/env")
|
240 |
+
async def environment(app_secret: str = Depends(verify_app_secret)):
|
241 |
+
return Response(content=json.dumps({"token": os.getenv("TOKEN", ""), "refresh_token": os.getenv("REFRESH_TOKEN", ""), "key": os.getenv("FIREBASE_API_KEY", "")}), media_type="application/json")
|
core/utils.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import codecs
|
2 |
+
import hashlib
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import ssl
|
6 |
+
import uuid
|
7 |
+
from datetime import datetime
|
8 |
+
# from http.client import HTTPException
|
9 |
+
from typing import Dict, Any, Optional
|
10 |
+
|
11 |
+
import httpx
|
12 |
+
from fastapi import HTTPException
|
13 |
+
from httpx import ConnectError, TransportError
|
14 |
+
from starlette import status
|
15 |
+
|
16 |
+
from core.config import get_settings
|
17 |
+
from core.logger import setup_logger
|
18 |
+
from core.models import ChatRequest
|
19 |
+
|
20 |
+
settings = get_settings()
|
21 |
+
logger = setup_logger(__name__)
|
22 |
+
|
23 |
+
def decode_unicode_escape(s):
|
24 |
+
# 检查输入是否为字典类型
|
25 |
+
if isinstance(s, dict):
|
26 |
+
return s
|
27 |
+
# 如果需要,将输入转换为字符串
|
28 |
+
if not isinstance(s, (str, bytes)):
|
29 |
+
s = str(s)
|
30 |
+
# 如果是字符串,转换为字节
|
31 |
+
if isinstance(s, str):
|
32 |
+
s = s.encode('utf-8')
|
33 |
+
return codecs.decode(s, 'unicode_escape')
|
34 |
+
|
35 |
+
FIREBASE_API_KEY = settings.FIREBASE_API_KEY
|
36 |
+
async def refresh_token_via_rest(refresh_token):
|
37 |
+
# Firebase Auth REST API endpoint
|
38 |
+
url = f"https://securetoken.googleapis.com/v1/token?key={FIREBASE_API_KEY}"
|
39 |
+
|
40 |
+
payload = {
|
41 |
+
'grant_type': 'refresh_token',
|
42 |
+
'refresh_token': refresh_token
|
43 |
+
}
|
44 |
+
|
45 |
+
try:
|
46 |
+
async with httpx.AsyncClient() as client:
|
47 |
+
response = await client.post(url, json=payload)
|
48 |
+
if response.status_code == 200:
|
49 |
+
data = response.json()
|
50 |
+
print(json.dumps(data, indent=2))
|
51 |
+
# return {
|
52 |
+
# 'id_token': data['id_token'],
|
53 |
+
# 'refresh_token': data.get('refresh_token'),
|
54 |
+
# 'expires_in': data['expires_in']
|
55 |
+
# }
|
56 |
+
return data['id_token']
|
57 |
+
else:
|
58 |
+
print(f"刷新失败: {response.text}")
|
59 |
+
return None
|
60 |
+
except Exception as e:
|
61 |
+
print(f"请求异常: {e}")
|
62 |
+
return None
|
63 |
+
|
64 |
+
|
65 |
+
async def sign_in_with_idp():
|
66 |
+
url = "https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp"
|
67 |
+
|
68 |
+
# 查询参数
|
69 |
+
params = {
|
70 |
+
"key": FIREBASE_API_KEY
|
71 |
+
}
|
72 |
+
|
73 |
+
# 请求头
|
74 |
+
headers = {
|
75 |
+
"X-Client-Version": "Node/JsCore/10.5.2/FirebaseCore-web",
|
76 |
+
"X-Firebase-gmpid": "1:123807869619:web:43b278a622ed6322789ec6",
|
77 |
+
"Content-Type": "application/json",
|
78 |
+
"User-Agent": "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)"
|
79 |
+
}
|
80 |
+
|
81 |
+
# 请求体
|
82 |
+
data = {
|
83 |
+
"requestUri": "http://localhost",
|
84 |
+
"returnSecureToken": True,
|
85 |
+
"postBody": f"&id_token={settings.AUTHORIZATION_TOKEN}&providerId=google.com"
|
86 |
+
}
|
87 |
+
print("Request Headers:", json.dumps(headers, indent=2)) # 格式化打印
|
88 |
+
print("Request Body:", json.dumps(data, indent=2)) # 格式化打印
|
89 |
+
print("Request params:", json.dumps(params, indent=2)) # 格式化打印
|
90 |
+
|
91 |
+
async with httpx.AsyncClient() as client:
|
92 |
+
response = await client.post(
|
93 |
+
url,
|
94 |
+
params=params,
|
95 |
+
headers=headers,
|
96 |
+
json=data
|
97 |
+
)
|
98 |
+
# 检查状态码
|
99 |
+
if response.status_code == 200:
|
100 |
+
return response.json()
|
101 |
+
else:
|
102 |
+
raise Exception(f"Request failed with status code: {response.status_code}")
|
103 |
+
|
104 |
+
async def handle_firebase_response(response) -> str:
|
105 |
+
try:
|
106 |
+
# 如果响应是字典(已经解析的 JSON)
|
107 |
+
if isinstance(response, dict):
|
108 |
+
print(json.dumps(response, indent=2))
|
109 |
+
if response.get('error', {}).get('code') == 400:
|
110 |
+
print("Invalid id_token in IdP response")
|
111 |
+
# 保存refresh_token到配置中
|
112 |
+
if 'refreshToken' in response:
|
113 |
+
os.environ["REFRESH_TOKEN"] = response['refreshToken']
|
114 |
+
if 'idToken' in response:
|
115 |
+
return response['idToken']
|
116 |
+
else:
|
117 |
+
raise ValueError("dict case Response does not contain idToken")
|
118 |
+
|
119 |
+
# 如果响应是 Response 对象
|
120 |
+
elif hasattr(response, 'status_code'):
|
121 |
+
if response.status_code == 200:
|
122 |
+
data = response.json()
|
123 |
+
print(data)
|
124 |
+
# 保存refresh_token到配置中
|
125 |
+
if 'refreshToken' in data:
|
126 |
+
os.environ["REFRESH_TOKEN"] = data['refreshToken']
|
127 |
+
if 'idToken' in data:
|
128 |
+
return data['idToken']
|
129 |
+
else:
|
130 |
+
raise ValueError("response case Response does not contain idToken")
|
131 |
+
|
132 |
+
# 处理其他状态码
|
133 |
+
elif response.status_code == 400:
|
134 |
+
error_data = response.json()
|
135 |
+
raise ValueError(f"Bad Request: {error_data.get('error', {}).get('message', 'Unknown error')}")
|
136 |
+
elif response.status_code == 401:
|
137 |
+
raise ValueError("Unauthorized: Invalid credentials")
|
138 |
+
elif response.status_code == 403:
|
139 |
+
raise ValueError("Forbidden: Insufficient permissions")
|
140 |
+
elif response.status_code == 404:
|
141 |
+
raise ValueError("Not Found: Resource doesn't exist")
|
142 |
+
else:
|
143 |
+
raise ValueError(f"Unexpected status code: {response.status_code}")
|
144 |
+
|
145 |
+
else:
|
146 |
+
raise ValueError(f"Unexpected response type: {type(response)}")
|
147 |
+
|
148 |
+
except json.JSONDecodeError:
|
149 |
+
raise ValueError("Invalid JSON response")
|
150 |
+
except Exception as e:
|
151 |
+
raise ValueError(f"Error processing response: {str(e)}")
|
152 |
+
|
153 |
+
# SHA-256
|
154 |
+
def _sha256_hash(text):
|
155 |
+
sha256 = hashlib.sha256()
|
156 |
+
sha256.update(text.encode('utf-8'))
|
157 |
+
return sha256.hexdigest()
|
158 |
+
|
159 |
+
# 处理字典列表
|
160 |
+
def sha256_hash_messages(messages):
|
161 |
+
# 只提取 role 为 "user" 的消息的 content 字段
|
162 |
+
message_data = [str(msg['content']) for msg in messages if msg['role'] == "user"]
|
163 |
+
print("Filtered contents:", message_data) # 调试用
|
164 |
+
json_str = json.dumps(message_data, sort_keys=True)
|
165 |
+
print("JSON string:", json_str) # 调试用
|
166 |
+
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
|
167 |
+
|
168 |
+
|
169 |
+
def create_chat_completion_data(
|
170 |
+
content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
|
171 |
+
) -> Dict[str, Any]:
|
172 |
+
return {
|
173 |
+
"id": f"chatcmpl-{uuid.uuid4()}",
|
174 |
+
"object": "chat.completion.chunk",
|
175 |
+
"created": timestamp,
|
176 |
+
"model": model,
|
177 |
+
"choices": [
|
178 |
+
{
|
179 |
+
"index": 0,
|
180 |
+
"delta": {"content": content, "role": "assistant"},
|
181 |
+
"finish_reason": finish_reason,
|
182 |
+
}
|
183 |
+
],
|
184 |
+
"usage": None,
|
185 |
+
}
|
186 |
+
|
187 |
+
async def process_streaming_response(request: ChatRequest, app_secret: str):
|
188 |
+
# 创建自定义 SSL 上下文
|
189 |
+
ssl_context = ssl.create_default_context()
|
190 |
+
ssl_context.check_hostname = True
|
191 |
+
ssl_context.verify_mode = ssl.CERT_REQUIRED
|
192 |
+
async with httpx.AsyncClient(
|
193 |
+
verify=ssl_context,
|
194 |
+
# timeout=30.0, # 增加超时时间
|
195 |
+
# http2=True # 启用 HTTP/2
|
196 |
+
) as client:
|
197 |
+
try:
|
198 |
+
request_headers = {**settings.HEADERS, 'authorization': f"Bearer {os.getenv('TOKEN', '')}"} # 从环境变量中获取新的TOKEN
|
199 |
+
|
200 |
+
# 直接使用 request.model_dump() 或 request.dict() 获取字典格式的数据
|
201 |
+
request_data = request.model_dump() # 如果使用较新版本的 Pydantic
|
202 |
+
# # 获取请求数据
|
203 |
+
# request_data = {
|
204 |
+
# "model": request.model,
|
205 |
+
# "messages": [msg.dict() for msg in request.messages],
|
206 |
+
# "temperature": request.temperature,
|
207 |
+
# "top_p": request.top_p,
|
208 |
+
# "max_tokens": request.max_tokens,
|
209 |
+
# "stream": request.stream
|
210 |
+
# }
|
211 |
+
print("Request Headers:", json.dumps(request_headers, indent=2)) # 格式化打印
|
212 |
+
print("Request Body:", json.dumps(request.json(), indent=4, ensure_ascii=False)) # 格式化打印
|
213 |
+
async with client.stream(
|
214 |
+
"POST",
|
215 |
+
f"https://api.thinkbuddy.ai/v1/chat/completions",
|
216 |
+
headers=request_headers,
|
217 |
+
json=request_data,
|
218 |
+
timeout=100,
|
219 |
+
) as response:
|
220 |
+
response.raise_for_status()
|
221 |
+
timestamp = int(datetime.now().timestamp())
|
222 |
+
async for line in response.aiter_lines():
|
223 |
+
print(f"{type(line)}: {line}")
|
224 |
+
if line and line.startswith("data: "):
|
225 |
+
try:
|
226 |
+
if line.strip() == 'data: [DONE]':
|
227 |
+
await response.aclose()
|
228 |
+
break
|
229 |
+
data_str = line[6:] # 去掉 'data: ' 前缀
|
230 |
+
|
231 |
+
# 解析JSON
|
232 |
+
json_data = json.loads(data_str)
|
233 |
+
if 'choices' in json_data and len(json_data['choices']) > 0:
|
234 |
+
delta = json_data['choices'][0].get('delta', {})
|
235 |
+
if 'content' in delta:
|
236 |
+
print(delta['content'], end='', flush=True)
|
237 |
+
yield f"data: {json.dumps(create_chat_completion_data(delta['content'], request.model, timestamp))}\n\n"
|
238 |
+
|
239 |
+
except json.JSONDecodeError as e:
|
240 |
+
print(f"JSON解析错误: {e}")
|
241 |
+
print(f"原始数据: {line}")
|
242 |
+
continue
|
243 |
+
|
244 |
+
yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
|
245 |
+
yield "data: [DONE]\n\n"
|
246 |
+
except ConnectError as e:
|
247 |
+
logger.error(f"Connection error details: {str(e)}")
|
248 |
+
raise HTTPException(
|
249 |
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
250 |
+
detail="Service temporarily unavailable. Please try again later."
|
251 |
+
)
|
252 |
+
except TransportError as e:
|
253 |
+
logger.error(f"Transport error details: {str(e)}")
|
254 |
+
raise HTTPException(
|
255 |
+
status_code=status.HTTP_502_BAD_GATEWAY,
|
256 |
+
detail="Network transport error occurred."
|
257 |
+
)
|
258 |
+
except httpx.HTTPStatusError as e:
|
259 |
+
# 这里需要处理401错误
|
260 |
+
logger.error(f"HTTP error occurred: {e}")
|
261 |
+
raise HTTPException(status_code=e.response.status_code, detail=str(e))
|
262 |
+
except httpx.RequestError as e:
|
263 |
+
logger.error(f"Error occurred during request: {e}")
|
264 |
+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
265 |
+
finally:
|
266 |
+
await response.aclose()
|
267 |
+
|
main.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from core.app import app
|
2 |
+
import uvicorn
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
uvicorn.run(app, host="0.0.0.0", port=8001, workers=1, loop="asyncio")
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi
|
2 |
+
httpx
|
3 |
+
pydantic
|
4 |
+
pydantic_settings
|
5 |
+
pyinstaller
|
6 |
+
python-dotenv
|
7 |
+
Requests
|
8 |
+
starlette
|
9 |
+
uvicorn
|
10 |
+
|
11 |
+
playsound
|