kevin commited on
Commit
feb939c
·
1 Parent(s): a66c30a

think buddy

Browse files
Files changed (13) hide show
  1. Dockerfile +17 -0
  2. README.md +1 -0
  3. core/__init__.py +0 -0
  4. core/app.py +69 -0
  5. core/auth.py +12 -0
  6. core/config.py +79 -0
  7. core/logger.py +20 -0
  8. core/models.py +16 -0
  9. core/refresh_token.py +105 -0
  10. core/router.py +241 -0
  11. core/utils.py +267 -0
  12. main.py +6 -0
  13. requirements.txt +11 -0
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10.16-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # 复制所需文件到容器中
6
+ COPY . .
7
+
8
+
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+ ENV APP_SECRET="sk-123456"
11
+ ENV REQUEST_TIMEOUT=30
12
+
13
+ # Expose port
14
+ EXPOSE 8001
15
+
16
+ # Run the application
17
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]
README.md CHANGED
@@ -5,6 +5,7 @@ colorFrom: blue
5
  colorTo: red
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
5
  colorTo: red
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8001
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
core/__init__.py ADDED
File without changes
core/app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ from fastapi import FastAPI, Request
4
+ from starlette.middleware.cors import CORSMiddleware
5
+ from starlette.middleware.trustedhost import TrustedHostMiddleware
6
+ from fastapi.responses import JSONResponse
7
+ from core.config import get_settings
8
+ from core.logger import setup_logger
9
+ from core.refresh_token import TokenManager
10
+ from core.router import router
11
+
12
+ settings = get_settings()
13
+ logger = setup_logger(__name__)
14
+
15
+ # print(settings.SECRET)
16
+ def create_app() -> FastAPI:
17
+ app = FastAPI(
18
+ title=settings.PROJECT_NAME,
19
+ version="0.0.1",
20
+ description=settings.DESCRIPTION,
21
+ )
22
+ # 配置中间件
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ # # 添加可信主机中间件
32
+ app.add_middleware(
33
+ TrustedHostMiddleware,
34
+ allowed_hosts=["*"] # 在生产环境中应该限制允许的主机
35
+ )
36
+ # 添加路由
37
+ app.include_router(router, prefix="/api/v1")
38
+ app.include_router(router, prefix="/v1") # 兼容性路由
39
+
40
+ @app.exception_handler(Exception)
41
+ async def global_exception_handler(request: Request, exc: Exception):
42
+ logger.error(f"An error occurred: {str(exc)}", exc_info=True)
43
+ return JSONResponse(
44
+ status_code=500,
45
+ content={
46
+ "message": "An internal server error occurred.",
47
+ "detail": str(exc)
48
+ },
49
+ )
50
+
51
+ # # 创建 TokenManager 实例
52
+ token_manager = TokenManager()
53
+ @app.on_event("startup")
54
+ async def startup_event():
55
+ # 在应用启动时创建任务
56
+ app.state.refresh_task = asyncio.create_task(token_manager.start_auto_refresh())
57
+
58
+ @app.on_event("shutdown")
59
+ async def shutdown_event():
60
+ # 在应用关闭时取消任务
61
+ if hasattr(app.state, 'refresh_task'):
62
+ app.state.refresh_task.cancel()
63
+ try:
64
+ await app.state.refresh_task
65
+ except asyncio.CancelledError:
66
+ pass
67
+ return app
68
+
69
+ app = create_app()
core/auth.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Depends, HTTPException
2
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
3
+ from core.config import get_settings
4
+
5
+ settings = get_settings()
6
+ APP_SECRET = settings.APP_SECRET
7
+ security = HTTPBearer()
8
+
9
+ def verify_app_secret(credentials: HTTPAuthorizationCredentials = Depends(security)):
10
+ if credentials.credentials != APP_SECRET:
11
+ raise HTTPException(status_code=403, detail="Invalid SECRET")
12
+ return credentials.credentials
core/config.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import platform
3
+ import uuid
4
+ from typing import List, Dict
5
+
6
+ import httpx
7
+ from dotenv import load_dotenv
8
+ from pydantic_settings import BaseSettings
9
+
10
+ load_dotenv()
11
+
12
+
13
+ class Settings(BaseSettings):
14
+ APP_SECRET: str = os.getenv('APP_SECRET', '')
15
+ TOKEN: str = os.getenv('TOKEN', '')
16
+ PROJECT_NAME: str = os.getenv('PROJECT_NAME', 'FastAPI')
17
+ DESCRIPTION: str = os.getenv('DESCRIPTION', 'FastAPI template')
18
+ FIREBASE_API_KEY: str = os.getenv('FIREBASE_API_KEY', '')
19
+ REFRESH_TOKEN: str = os.getenv('REFRESH_TOKEN', '')
20
+ AUTHORIZATION_TOKEN: str = os.getenv('AUTHORIZATION_TOKEN', '')
21
+
22
+ ALLOWED_MODELS: List[Dict[str, str]] = [
23
+ {"id": "gpt-4o", "name": "GPT-4o [thinkbuddy]"},
24
+ {"id": "claude-3-5-sonnet", "name": "Claude 3.5 Sonnet v2 [thinkbuddy]"},
25
+ {"id": "gemini-2-flash", "name": "Gemini 2 Flash [thinkbuddy]"},
26
+ {"id": "nova-pro", "name": "Nova Pro [thinkbuddy]"},
27
+ {"id": "deepseek-v3", "name": "DeepSeek v3 [thinkbuddy]"},
28
+ {"id": "llama-3-3", "name": "Llama 3.3 (70B) [thinkbuddy]"},
29
+ {"id": "mistral-large-2", "name": "Mistral Large v2 [thinkbuddy]"},
30
+ {"id": "command-r-plus", "name": "Command R+ [thinkbuddy]"},
31
+ {"id": "o1-preview", "name": "o1-preview [thinkbuddy]"},
32
+ {"id": "o1-mini", "name": "o1-mini [thinkbuddy]"},
33
+ {"id": "gemini-2-thinking", "name": "Gemini 2 Thinking [thinkbuddy]"},
34
+ {"id": "claude-3-5-haiku", "name": "Claude 3.5 Haiku [thinkbuddy]"},
35
+ {"id": "gemini-1-5-flash", "name": "Gemini 1.5 Flash [thinkbuddy]"},
36
+ {"id": "gpt-4o-mini", "name": "GPT-4o mini [thinkbuddy]"},
37
+ {"id": "nova-lite", "name": "Nova Lite [thinkbuddy]"},
38
+ {"id": "gemini-1-5-pro", "name": "Gemini 1.5 Pro [thinkbuddy]"},
39
+ {"id": "claude-3-opus", "name": "Claude 3 Opus [thinkbuddy]"},
40
+ {"id": "gpt-4-turbo", "name": "GPT-4 Turbo [thinkbuddy]"},
41
+ {"id": "llama-3-1", "name": "Llama 3.1 [thinkbuddy]"}
42
+ ]
43
+ MODEL_MAPPING: Dict[str, str] = {
44
+ "gpt-4o": "gpt-4o",
45
+ "o1-preview": "o1-preview",
46
+ "claude-3-5-sonnet": "claude-3-5-sonnet",
47
+ "o1-mini": "o1-mini",
48
+ "gemini-1.5-pro": "gemini-1.5-pro",
49
+ "gemini-2.0-flash": "gemini-2.0-flash",
50
+ }
51
+ HEADERS: Dict[str, str] = {
52
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
53
+ 'Accept-Encoding': 'gzip, deflate, br, zstd',
54
+ 'Content-Type': 'application/json',
55
+ 'sec-ch-ua-platform': 'Windows',
56
+ 'authorization': f"Bearer {os.getenv('TOKEN', '')}",
57
+ 'sec-ch-ua': '"Google Chrome";v="131", "Chromium";v="131", "Not_A Brand";v="24"',
58
+ 'dnt': '1',
59
+ 'sec-ch-ua-mobile': '?0',
60
+ 'origin': 'https://thinkbuddy.ai',
61
+ 'sec-fetch-site': 'same-site',
62
+ 'sec-fetch-mode': 'cors',
63
+ 'sec-fetch-dest': 'empty',
64
+ 'referer': 'https://thinkbuddy.ai/',
65
+ 'accept-language': 'en-US,en;q=0.9,zh-CN;q=0.8,zh-TW;q=0.7,zh;q=0.6',
66
+ 'priority': 'u=1, i'
67
+ }
68
+
69
+ class Config:
70
+ env_file = '.env'
71
+ case_sensitive = True
72
+
73
+ _settings = None
74
+
75
+ def get_settings():
76
+ global _settings
77
+ if _settings is None:
78
+ _settings = Settings()
79
+ return _settings
core/logger.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ def setup_logger(name):
4
+ logger = logging.getLogger(name)
5
+ if not logger.handlers:
6
+ logger.setLevel(logging.INFO)
7
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
8
+
9
+ # 控制台处理器
10
+ console_handler = logging.StreamHandler()
11
+ console_handler.setFormatter(formatter)
12
+ logger.addHandler(console_handler)
13
+
14
+ # 文件处理器 - 错误级别
15
+ # error_file_handler = logging.FileHandler('error.log')
16
+ # error_file_handler.setFormatter(formatter)
17
+ # error_file_handler.setLevel(logging.ERROR)
18
+ # logger.addHandler(error_file_handler)
19
+
20
+ return logger
core/models.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from pydantic import BaseModel
3
+
4
+
5
+ class Message(BaseModel):
6
+ role: str
7
+ content: str | list
8
+
9
+
10
+ class ChatRequest(BaseModel):
11
+ model: str
12
+ messages: List[Message]
13
+ stream: Optional[bool] = False
14
+ temperature: Optional[float] = 0.7
15
+ top_p: Optional[float] = 0.9
16
+ max_tokens: Optional[int] = 8192
core/refresh_token.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 用于刷新idToken的定时任务
2
+ import asyncio
3
+ import os
4
+ import time
5
+ from typing import Optional
6
+ from datetime import datetime
7
+ from core.utils import sign_in_with_idp, handle_firebase_response, refresh_token_via_rest
8
+
9
+
10
+ class TokenManager:
11
+ def __init__(self):
12
+ self.id_token: Optional[str] = None
13
+ self.last_refresh_time: Optional[float] = None
14
+ self.refresh_interval = 30 * 60 # 30分钟,单位:秒
15
+ self.is_running = False
16
+
17
+ async def get_token(self) -> str:
18
+ """
19
+ 获取当前的 idToken,如果不存在或已过期则刷新
20
+ """
21
+ if not self.id_token or self._should_refresh():
22
+ await self.refresh_token()
23
+ return self.id_token
24
+
25
+ def _should_refresh(self) -> bool:
26
+ """
27
+ 检查是否需要刷新 token
28
+ """
29
+ if not self.last_refresh_time:
30
+ return True
31
+ return time.time() - self.last_refresh_time >= self.refresh_interval
32
+
33
+ async def refresh_token(self):
34
+ """
35
+ 刷新 idToken
36
+ """
37
+ try:
38
+ if os.getenv("REFRESH_TOKEN", "") == "" or os.getenv("REFRESH_TOKEN", "") == "None":
39
+ response = await sign_in_with_idp()
40
+ result = await handle_firebase_response(response) # idToken 实际就是bearer token
41
+ else:
42
+ result = await refresh_token_via_rest(os.getenv("REFRESH_TOKEN"))
43
+ if result is not None:
44
+ self.id_token = result
45
+ # 修改配置中的 TOKEN
46
+ print(f"Before Token is {os.getenv('TOKEN', '')}")
47
+ os.environ["TOKEN"] = self.id_token
48
+ print(f"Now Token is {os.getenv('TOKEN', '')}")
49
+ self.last_refresh_time = time.time()
50
+ print(f"Token refreshed at {datetime.now()}")
51
+ else:
52
+ print(f"Failed to refresh token: {result['error']}")
53
+ except Exception as e:
54
+ print(f"Error refreshing token: {str(e)}")
55
+
56
+ async def start_auto_refresh(self):
57
+ """
58
+ 启动自动刷新任务
59
+ """
60
+ if self.is_running:
61
+ return
62
+
63
+ self.is_running = True
64
+ while self.is_running:
65
+ try:
66
+ await self.refresh_token()
67
+ # 等待到下次刷新时间
68
+ await asyncio.sleep(self.refresh_interval)
69
+ except Exception as e:
70
+ print(f"Auto refresh error: {str(e)}")
71
+ # 发生错误时等待短暂时间后重试
72
+ await asyncio.sleep(60)
73
+
74
+ async def stop_auto_refresh(self):
75
+ """
76
+ 停止自动刷新任务
77
+ """
78
+ self.is_running = False
79
+
80
+ # 使用示例
81
+ # async def main():
82
+ # # 创建 TokenManager 实例
83
+ # token_manager = TokenManager()
84
+ #
85
+ # try:
86
+ # # 启动自动刷新任务
87
+ # refresh_task = asyncio.create_task(token_manager.start_auto_refresh())
88
+ #
89
+ # # 模拟应用运行
90
+ # while True:
91
+ # # 获取当前 token
92
+ # token = await token_manager.get_token()
93
+ # print(f"Current token: {token[:20]}...")
94
+ #
95
+ # # 等待一段时间再次获取
96
+ # await asyncio.sleep(300) # 每5分钟打印一次当前token
97
+ #
98
+ # except KeyboardInterrupt:
99
+ # # 处理 Ctrl+C
100
+ # await token_manager.stop_auto_refresh()
101
+ # await refresh_task
102
+ #
103
+ # # 运行示例
104
+ # if __name__ == "__main__":
105
+ # asyncio.run(main())
core/router.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import os
4
+ import traceback
5
+ from typing import Optional
6
+ from uuid import uuid4
7
+
8
+ import httpx
9
+ from fastapi import File, UploadFile
10
+ from fastapi import APIRouter, Response, Request, Depends, HTTPException
11
+ from fastapi.responses import StreamingResponse
12
+
13
+
14
+ from core.auth import verify_app_secret
15
+ from core.config import get_settings
16
+ from core.logger import setup_logger
17
+ from core.models import ChatRequest
18
+ from core.utils import process_streaming_response
19
+ from playsound import playsound # 用于播放音频
20
+
21
+ logger = setup_logger(__name__)
22
+ router = APIRouter()
23
+ ALLOWED_MODELS = get_settings().ALLOWED_MODELS
24
+
25
+ @router.get("/models")
26
+ async def list_models():
27
+ return {"object": "list", "data": ALLOWED_MODELS, "success": True}
28
+
29
+ @router.options("/chat/completions")
30
+ async def chat_completions_options():
31
+ return Response(
32
+ status_code=200,
33
+ headers={
34
+ "Access-Control-Allow-Origin": "*",
35
+ "Access-Control-Allow-Methods": "POST, OPTIONS",
36
+ "Access-Control-Allow-Headers": "Content-Type, Authorization",
37
+ },
38
+ )
39
+ # 识图
40
+
41
+ # 识图
42
+ # 文本转语音
43
+ @router.post("/audio/speech")
44
+ async def speech(request: Request):
45
+ url = 'https://api.thinkbuddy.ai/v1/content/speech/tts'
46
+ request_headers = {**get_settings().HEADERS,
47
+ 'authorization': f"Bearer {os.getenv('TOKEN', '')}",
48
+ 'Accept': 'application/json, text/plain, */*',
49
+ }
50
+ # data = {
51
+ # "input": "这是一张插图,显示了一杯饮料,可能是奶昔、冰沙或其他冷饮。杯子上有一个盖子和一根吸管,表明这是一种便于携带和饮用的饮品。这种设计通常用于提供咖啡、冰茶或果汁等饮品。杯子颜色简约,可能用于说明饮品的内容或品牌。",
52
+ # "voice": "nova" # alloy echo fable onyx nova shimmer
53
+ # }
54
+ body = await request.json()
55
+ try:
56
+ async with httpx.AsyncClient(http2=True) as client:
57
+ response = await client.post(url, headers=request_headers, json=body)
58
+ response.raise_for_status()
59
+
60
+ # 假设响应是音频数据,保存为文件
61
+ if response.status_code == 200:
62
+ # 保存音频文件
63
+ with open('output.mp3', 'wb') as f:
64
+ f.write(response.content)
65
+ print("音频文件已保存为 output.mp3")
66
+
67
+ # 异步播放音频
68
+ # 使用 asyncio.to_thread 来避免阻塞事件循环
69
+ # await asyncio.to_thread(playsound, 'output.mp3')
70
+ return True
71
+ else:
72
+ print(f"请求失败,状态码: {response.status_code}")
73
+ print(f"响应内容: {response.text}")
74
+ return False
75
+
76
+ except httpx.RequestError as e:
77
+ print(f"请求错误: {e}")
78
+ print("错误堆栈:")
79
+ traceback.print_exc()
80
+ return False
81
+ except httpx.HTTPStatusError as e:
82
+ print(f"HTTP 错误: {e}")
83
+ print("错误堆栈:")
84
+ traceback.print_exc()
85
+ return False
86
+ except Exception as e:
87
+ print(f"发生错误: {e}")
88
+ print("错误堆栈:")
89
+ traceback.print_exc()
90
+ return False
91
+
92
+
93
+ # 语音转文本
94
+ @router.post("/audio/transcriptions")
95
+ async def transcriptions(request: Request, file: UploadFile = File(...)):
96
+ url = 'https://api.thinkbuddy.ai/v1/content/transcribe'
97
+ params = {'enhance': 'true'}
98
+ try:
99
+ # 读取文件内容
100
+ content = await safe_read_file(file)
101
+ # 获取原始 content-type
102
+ content_type = request.headers.get('content-type')
103
+ # files = {
104
+ # 'file': (str(uuid4()),
105
+ # content,
106
+ # file.content_type or 'application/octet-stream')
107
+ # }
108
+ files = {
109
+ 'file': ('file.mp4', content, 'audio/mp4'),
110
+ 'model': (None, 'whisper-1')
111
+ }
112
+ # 记录请求信息
113
+ logger.info(f"Received upload request for file: {file.filename}")
114
+ logger.info(f"Content-Type: {request.headers.get('content-type')}")
115
+ request_headers = {**get_settings().HEADERS,
116
+ 'authorization': f"Bearer {os.getenv('TOKEN', '')}",
117
+ 'Accept': 'application/json, text/plain, */*',
118
+ 'Content-Type': content_type,
119
+ }
120
+ # 设置较长的超时时间
121
+ timeout = httpx.Timeout(
122
+ connect=30.0, # 连接超时
123
+ read=300.0, # 读取超时
124
+ write=30.0, # 写入超时
125
+ pool=30.0 # 连接池超时
126
+ )
127
+ # 使用httpx发送异步请求
128
+ async with httpx.AsyncClient(http2=True, timeout=timeout) as client:
129
+ response = await client.post(url,
130
+ params=params,
131
+ headers=request_headers,
132
+ files=files)
133
+ response.raise_for_status()
134
+ return response.json()
135
+
136
+ except httpx.TimeoutException:
137
+ raise HTTPException(status_code=504, detail="请求目标服务器超时")
138
+ except httpx.HTTPStatusError as e:
139
+ raise HTTPException(status_code=e.response.status_code, detail=str(e))
140
+ except Exception as e:
141
+ traceback.print_tb(e.__traceback__)
142
+ raise HTTPException(status_code=500, detail=str(e))
143
+ finally:
144
+ # 清理资源
145
+ await file.close()
146
+ async def safe_read_file(file: UploadFile) -> Optional[bytes]:
147
+ """安全地读取文件内容"""
148
+ try:
149
+ return await file.read()
150
+ except Exception as e:
151
+ logger.error(f"Error reading file: {str(e)}")
152
+ return None
153
+ # 文件上传
154
+ @router.post("/upload")
155
+ async def upload_file(request: Request, file: UploadFile = File(...)):
156
+ try:
157
+ # 读取文件内容
158
+ content = await safe_read_file(file)
159
+ # 获取原始 content-type
160
+ content_type = request.headers.get('content-type')
161
+ files = {
162
+ 'file': (
163
+ # str(uuid4()),
164
+ file.filename, # 使用原始文件名而不是 UUID
165
+ content,
166
+ file.content_type )
167
+ }
168
+ # 记录请求信息
169
+ logger.info(f"Received upload request for file: {file.filename}")
170
+ logger.info(f"Content-Type: {request.headers.get('content-type')}")
171
+ request_headers = {**get_settings().HEADERS,
172
+ 'authorization': f"Bearer {os.getenv('TOKEN', '')}",
173
+ 'Accept': 'application/json, text/plain, */*',
174
+ 'Content-Type': content_type,
175
+ }
176
+ # 使用httpx发送异步请求
177
+ async with httpx.AsyncClient() as client:
178
+ response = await client.post(f"https://api.thinkbuddy.ai/v1/uploads/images", headers=request_headers,files=files, timeout=100)
179
+ response.raise_for_status()
180
+ return response.json()
181
+
182
+ except httpx.TimeoutException:
183
+ raise HTTPException(status_code=504, detail="请求目标服务器超时")
184
+ except httpx.HTTPStatusError as e:
185
+ # raise HTTPException(status_code=e.response.status_code, detail=str(e))
186
+ print(f"HTTPStatusError发生错误: {e}")
187
+ print("错误堆栈:")
188
+ traceback.print_exc()
189
+ except Exception as e:
190
+ # traceback.print_tb(e.__traceback__)
191
+ # raise HTTPException(status_code=500, detail=str(e))
192
+ print(f"发生错误: {e}")
193
+ print("错误堆栈:")
194
+ traceback.print_exc()
195
+ finally:
196
+ # 清理资源
197
+ await file.close()
198
+
199
+ @router.post("/chat/completions")
200
+ async def chat_completions(
201
+ request: ChatRequest, app_secret: str = Depends(verify_app_secret)
202
+ ):
203
+ logger.info("Entering chat_completions route")
204
+ # logger.info(f"Received request: {request}")
205
+ logger.info(f"Received request json format: {json.dumps(request.dict(), indent=4)}")
206
+ logger.info(f"App secret: {app_secret}")
207
+ logger.info(f"Received chat completion request for model: {request.model}")
208
+
209
+ if request.model not in [model["id"] for model in ALLOWED_MODELS]:
210
+ raise HTTPException(
211
+ status_code=400,
212
+ detail=f"Model {request.model} is not allowed. Allowed models are: {', '.join(model['id'] for model in ALLOWED_MODELS)}",
213
+ )
214
+
215
+ if request.stream:
216
+ logger.info("Streaming response")
217
+ return StreamingResponse(
218
+ process_streaming_response(request, app_secret),
219
+ media_type="text/event-stream",
220
+ headers={
221
+ "Cache-Control": "no-cache",
222
+ "Connection": "keep-alive",
223
+ "Transfer-Encoding": "chunked"
224
+ }
225
+ )
226
+ else:
227
+ logger.info("Non-streaming response")
228
+ # return await process_non_streaming_response(request)
229
+
230
+ @router.route('/')
231
+ @router.route('/healthz')
232
+ @router.route('/ready')
233
+ @router.route('/alive')
234
+ @router.route('/status')
235
+ @router.get("/health")
236
+ async def health_check(request: Request):
237
+ return Response(content=json.dumps({"status": "ok"}), media_type="application/json")
238
+
239
+ @router.post("/env")
240
+ async def environment(app_secret: str = Depends(verify_app_secret)):
241
+ return Response(content=json.dumps({"token": os.getenv("TOKEN", ""), "refresh_token": os.getenv("REFRESH_TOKEN", ""), "key": os.getenv("FIREBASE_API_KEY", "")}), media_type="application/json")
core/utils.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import codecs
2
+ import hashlib
3
+ import json
4
+ import os
5
+ import ssl
6
+ import uuid
7
+ from datetime import datetime
8
+ # from http.client import HTTPException
9
+ from typing import Dict, Any, Optional
10
+
11
+ import httpx
12
+ from fastapi import HTTPException
13
+ from httpx import ConnectError, TransportError
14
+ from starlette import status
15
+
16
+ from core.config import get_settings
17
+ from core.logger import setup_logger
18
+ from core.models import ChatRequest
19
+
20
+ settings = get_settings()
21
+ logger = setup_logger(__name__)
22
+
23
+ def decode_unicode_escape(s):
24
+ # 检查输入是否为字典类型
25
+ if isinstance(s, dict):
26
+ return s
27
+ # 如果需要,将输入转换为字符串
28
+ if not isinstance(s, (str, bytes)):
29
+ s = str(s)
30
+ # 如果是字符串,转换为字节
31
+ if isinstance(s, str):
32
+ s = s.encode('utf-8')
33
+ return codecs.decode(s, 'unicode_escape')
34
+
35
+ FIREBASE_API_KEY = settings.FIREBASE_API_KEY
36
+ async def refresh_token_via_rest(refresh_token):
37
+ # Firebase Auth REST API endpoint
38
+ url = f"https://securetoken.googleapis.com/v1/token?key={FIREBASE_API_KEY}"
39
+
40
+ payload = {
41
+ 'grant_type': 'refresh_token',
42
+ 'refresh_token': refresh_token
43
+ }
44
+
45
+ try:
46
+ async with httpx.AsyncClient() as client:
47
+ response = await client.post(url, json=payload)
48
+ if response.status_code == 200:
49
+ data = response.json()
50
+ print(json.dumps(data, indent=2))
51
+ # return {
52
+ # 'id_token': data['id_token'],
53
+ # 'refresh_token': data.get('refresh_token'),
54
+ # 'expires_in': data['expires_in']
55
+ # }
56
+ return data['id_token']
57
+ else:
58
+ print(f"刷新失败: {response.text}")
59
+ return None
60
+ except Exception as e:
61
+ print(f"请求异常: {e}")
62
+ return None
63
+
64
+
65
+ async def sign_in_with_idp():
66
+ url = "https://identitytoolkit.googleapis.com/v1/accounts:signInWithIdp"
67
+
68
+ # 查询参数
69
+ params = {
70
+ "key": FIREBASE_API_KEY
71
+ }
72
+
73
+ # 请求头
74
+ headers = {
75
+ "X-Client-Version": "Node/JsCore/10.5.2/FirebaseCore-web",
76
+ "X-Firebase-gmpid": "1:123807869619:web:43b278a622ed6322789ec6",
77
+ "Content-Type": "application/json",
78
+ "User-Agent": "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)"
79
+ }
80
+
81
+ # 请求体
82
+ data = {
83
+ "requestUri": "http://localhost",
84
+ "returnSecureToken": True,
85
+ "postBody": f"&id_token={settings.AUTHORIZATION_TOKEN}&providerId=google.com"
86
+ }
87
+ print("Request Headers:", json.dumps(headers, indent=2)) # 格式化打印
88
+ print("Request Body:", json.dumps(data, indent=2)) # 格式化打印
89
+ print("Request params:", json.dumps(params, indent=2)) # 格式化打印
90
+
91
+ async with httpx.AsyncClient() as client:
92
+ response = await client.post(
93
+ url,
94
+ params=params,
95
+ headers=headers,
96
+ json=data
97
+ )
98
+ # 检查状态码
99
+ if response.status_code == 200:
100
+ return response.json()
101
+ else:
102
+ raise Exception(f"Request failed with status code: {response.status_code}")
103
+
104
+ async def handle_firebase_response(response) -> str:
105
+ try:
106
+ # 如果响应是字典(已经解析的 JSON)
107
+ if isinstance(response, dict):
108
+ print(json.dumps(response, indent=2))
109
+ if response.get('error', {}).get('code') == 400:
110
+ print("Invalid id_token in IdP response")
111
+ # 保存refresh_token到配置中
112
+ if 'refreshToken' in response:
113
+ os.environ["REFRESH_TOKEN"] = response['refreshToken']
114
+ if 'idToken' in response:
115
+ return response['idToken']
116
+ else:
117
+ raise ValueError("dict case Response does not contain idToken")
118
+
119
+ # 如果响应是 Response 对象
120
+ elif hasattr(response, 'status_code'):
121
+ if response.status_code == 200:
122
+ data = response.json()
123
+ print(data)
124
+ # 保存refresh_token到配置中
125
+ if 'refreshToken' in data:
126
+ os.environ["REFRESH_TOKEN"] = data['refreshToken']
127
+ if 'idToken' in data:
128
+ return data['idToken']
129
+ else:
130
+ raise ValueError("response case Response does not contain idToken")
131
+
132
+ # 处理其他状态码
133
+ elif response.status_code == 400:
134
+ error_data = response.json()
135
+ raise ValueError(f"Bad Request: {error_data.get('error', {}).get('message', 'Unknown error')}")
136
+ elif response.status_code == 401:
137
+ raise ValueError("Unauthorized: Invalid credentials")
138
+ elif response.status_code == 403:
139
+ raise ValueError("Forbidden: Insufficient permissions")
140
+ elif response.status_code == 404:
141
+ raise ValueError("Not Found: Resource doesn't exist")
142
+ else:
143
+ raise ValueError(f"Unexpected status code: {response.status_code}")
144
+
145
+ else:
146
+ raise ValueError(f"Unexpected response type: {type(response)}")
147
+
148
+ except json.JSONDecodeError:
149
+ raise ValueError("Invalid JSON response")
150
+ except Exception as e:
151
+ raise ValueError(f"Error processing response: {str(e)}")
152
+
153
+ # SHA-256
154
+ def _sha256_hash(text):
155
+ sha256 = hashlib.sha256()
156
+ sha256.update(text.encode('utf-8'))
157
+ return sha256.hexdigest()
158
+
159
+ # 处理字典列表
160
+ def sha256_hash_messages(messages):
161
+ # 只提取 role 为 "user" 的消息的 content 字段
162
+ message_data = [str(msg['content']) for msg in messages if msg['role'] == "user"]
163
+ print("Filtered contents:", message_data) # 调试用
164
+ json_str = json.dumps(message_data, sort_keys=True)
165
+ print("JSON string:", json_str) # 调试用
166
+ return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
167
+
168
+
169
+ def create_chat_completion_data(
170
+ content: str, model: str, timestamp: int, finish_reason: Optional[str] = None
171
+ ) -> Dict[str, Any]:
172
+ return {
173
+ "id": f"chatcmpl-{uuid.uuid4()}",
174
+ "object": "chat.completion.chunk",
175
+ "created": timestamp,
176
+ "model": model,
177
+ "choices": [
178
+ {
179
+ "index": 0,
180
+ "delta": {"content": content, "role": "assistant"},
181
+ "finish_reason": finish_reason,
182
+ }
183
+ ],
184
+ "usage": None,
185
+ }
186
+
187
+ async def process_streaming_response(request: ChatRequest, app_secret: str):
188
+ # 创建自定义 SSL 上下文
189
+ ssl_context = ssl.create_default_context()
190
+ ssl_context.check_hostname = True
191
+ ssl_context.verify_mode = ssl.CERT_REQUIRED
192
+ async with httpx.AsyncClient(
193
+ verify=ssl_context,
194
+ # timeout=30.0, # 增加超时时间
195
+ # http2=True # 启用 HTTP/2
196
+ ) as client:
197
+ try:
198
+ request_headers = {**settings.HEADERS, 'authorization': f"Bearer {os.getenv('TOKEN', '')}"} # 从环境变量中获取新的TOKEN
199
+
200
+ # 直接使用 request.model_dump() 或 request.dict() 获取字典格式的数据
201
+ request_data = request.model_dump() # 如果使用较新版本的 Pydantic
202
+ # # 获取请求数据
203
+ # request_data = {
204
+ # "model": request.model,
205
+ # "messages": [msg.dict() for msg in request.messages],
206
+ # "temperature": request.temperature,
207
+ # "top_p": request.top_p,
208
+ # "max_tokens": request.max_tokens,
209
+ # "stream": request.stream
210
+ # }
211
+ print("Request Headers:", json.dumps(request_headers, indent=2)) # 格式化打印
212
+ print("Request Body:", json.dumps(request.json(), indent=4, ensure_ascii=False)) # 格式化打印
213
+ async with client.stream(
214
+ "POST",
215
+ f"https://api.thinkbuddy.ai/v1/chat/completions",
216
+ headers=request_headers,
217
+ json=request_data,
218
+ timeout=100,
219
+ ) as response:
220
+ response.raise_for_status()
221
+ timestamp = int(datetime.now().timestamp())
222
+ async for line in response.aiter_lines():
223
+ print(f"{type(line)}: {line}")
224
+ if line and line.startswith("data: "):
225
+ try:
226
+ if line.strip() == 'data: [DONE]':
227
+ await response.aclose()
228
+ break
229
+ data_str = line[6:] # 去掉 'data: ' 前缀
230
+
231
+ # 解析JSON
232
+ json_data = json.loads(data_str)
233
+ if 'choices' in json_data and len(json_data['choices']) > 0:
234
+ delta = json_data['choices'][0].get('delta', {})
235
+ if 'content' in delta:
236
+ print(delta['content'], end='', flush=True)
237
+ yield f"data: {json.dumps(create_chat_completion_data(delta['content'], request.model, timestamp))}\n\n"
238
+
239
+ except json.JSONDecodeError as e:
240
+ print(f"JSON解析错误: {e}")
241
+ print(f"原始数据: {line}")
242
+ continue
243
+
244
+ yield f"data: {json.dumps(create_chat_completion_data('', request.model, timestamp, 'stop'))}\n\n"
245
+ yield "data: [DONE]\n\n"
246
+ except ConnectError as e:
247
+ logger.error(f"Connection error details: {str(e)}")
248
+ raise HTTPException(
249
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
250
+ detail="Service temporarily unavailable. Please try again later."
251
+ )
252
+ except TransportError as e:
253
+ logger.error(f"Transport error details: {str(e)}")
254
+ raise HTTPException(
255
+ status_code=status.HTTP_502_BAD_GATEWAY,
256
+ detail="Network transport error occurred."
257
+ )
258
+ except httpx.HTTPStatusError as e:
259
+ # 这里需要处理401错误
260
+ logger.error(f"HTTP error occurred: {e}")
261
+ raise HTTPException(status_code=e.response.status_code, detail=str(e))
262
+ except httpx.RequestError as e:
263
+ logger.error(f"Error occurred during request: {e}")
264
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
265
+ finally:
266
+ await response.aclose()
267
+
main.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from core.app import app
2
+ import uvicorn
3
+
4
+
5
+ if __name__ == '__main__':
6
+ uvicorn.run(app, host="0.0.0.0", port=8001, workers=1, loop="asyncio")
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ httpx
3
+ pydantic
4
+ pydantic_settings
5
+ pyinstaller
6
+ python-dotenv
7
+ Requests
8
+ starlette
9
+ uvicorn
10
+
11
+ playsound