Spaces:
Running
Running
from fastapi import FastAPI, Request | |
from fastapi.responses import JSONResponse, HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from fastapi.middleware.cors import CORSMiddleware | |
from app.models.schemas import ErrorResponse | |
from app.services import GeminiClient | |
from app.utils import ( | |
APIKeyManager, | |
test_api_key, | |
ResponseCacheManager, | |
ActiveRequestsManager, | |
check_version, | |
schedule_cache_cleanup, | |
handle_exception, | |
log | |
) | |
from app.config.persistence import save_settings, load_settings | |
from app.api import router, init_router, dashboard_router, init_dashboard_router | |
from app.vertex.vertex_ai_init import init_vertex_ai | |
from app.vertex.credentials_manager import CredentialManager | |
import app.config.settings as settings | |
from app.config.safety import SAFETY_SETTINGS, SAFETY_SETTINGS_G2 | |
import asyncio | |
import sys | |
import pathlib | |
import os | |
# Template directory, resolved relative to this module rather than the CWD.
BASE_DIR = pathlib.Path(__file__).parent
templates = Jinja2Templates(directory=os.fspath(BASE_DIR / "templates"))

# NOTE(review): `limit` is not a documented FastAPI setting; FastAPI keeps unknown
# keyword arguments in `app.extra`, so this presumably does not cap request body
# size — confirm the intent.
app = FastAPI(limit="50M")
# --------------- CORS middleware ---------------
# With an empty ALLOWED_ORIGINS no CORS middleware is installed at all, so no
# cross-origin request is granted CORS headers.
# NOTE(review): if ALLOWED_ORIGINS contains "*", allow_credentials=True lets any
# origin make credentialed requests — confirm this is intended.
if settings.ALLOWED_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.ALLOWED_ORIGINS,
        allow_methods=["*"],
        allow_headers=["*"],
        allow_credentials=True,
    )
# --------------- Global instances ---------------
# Load persisted settings before the settings values below are read.
load_settings()

# API key manager shared by the routers and the startup key check.
key_manager = APIKeyManager()

# Module-level dict that backs the response cache manager.
response_cache = dict()
response_cache_manager = ResponseCacheManager(
    cache_dict=response_cache,
    expiry_time=settings.CACHE_EXPIRY_TIME,
    max_entries=settings.MAX_CACHE_ENTRIES,
)

# Module-level dict that backs the active-requests manager.
active_requests_pool = dict()
active_requests_manager = ActiveRequestsManager(requests_pool=active_requests_pool)

# Any casing of "true" in the environment variable enables skipping the key check.
SKIP_CHECK_API_KEY = "true" == os.getenv("SKIP_CHECK_API_KEY", "").lower()
# --------------- 工具函数 --------------- | |
# @app.middleware("http") | |
# async def log_requests(request: Request, call_next): | |
# """ | |
# DEBUG用,接收并打印请求内容 | |
# """ | |
# log('info', f"接收到请求: {request.method} {request.url}") | |
# try: | |
# body = await request.json() | |
# log('info', f"请求体: {body}") | |
# except Exception: | |
# log('info', "请求体不是 JSON 格式或者为空") | |
# response = await call_next(request) | |
# return response | |
async def check_remaining_keys_async(keys_to_check: list, initial_invalid_keys: list):
    """
    Validate the remaining API keys in the background.

    Each key in ``keys_to_check`` is tested with ``test_api_key``, pausing
    briefly between requests. Valid keys not yet in ``key_manager.api_keys``
    are appended, and the key stack is reset if any were added. Invalid keys
    are merged with ``initial_invalid_keys`` and with the keys already listed
    in ``settings.INVALID_API_KEYS``; settings are saved only when that set
    actually grows.

    A key whose check raises is logged and left unclassified (neither added
    nor marked invalid), so a single failing request cannot abort the task.

    Args:
        keys_to_check: keys that still need to be tested.
        initial_invalid_keys: keys already found invalid during startup.
    """
    local_invalid_keys = []
    found_valid_keys = False
    log('info', " 开始在后台检查剩余 API Key 是否有效")
    for key in keys_to_check:
        try:
            is_valid = await test_api_key(key)
        except Exception as e:
            # This coroutine runs as a fire-and-forget task: an uncaught error
            # here would silently discard every result gathered so far.
            log('warning', f" API Key {key[:8]}... 检查出错", extra={'error_message': str(e)})
        else:
            if is_valid:
                if key not in key_manager.api_keys:  # avoid adding duplicates
                    key_manager.api_keys.append(key)
                    found_valid_keys = True
            else:
                local_invalid_keys.append(key)
                log('warning', f" API Key {key[:8]}... 无效")
        await asyncio.sleep(0.05)  # brief pause so requests are not fired back to back

    if found_valid_keys:
        key_manager._reset_key_stack()  # new valid keys were added: rebuild the stack

    # Merge persisted invalid keys with the startup and background findings.
    current_invalid_keys_str = settings.INVALID_API_KEYS or ""
    current_invalid_keys_set = {k.strip() for k in current_invalid_keys_str.split(',') if k.strip()}
    new_invalid_keys_set = current_invalid_keys_set | set(initial_invalid_keys) | set(local_invalid_keys)
    # Only persist when the invalid-key set changed.
    if new_invalid_keys_set != current_invalid_keys_set:
        settings.INVALID_API_KEYS = ','.join(sorted(new_invalid_keys_set))
        save_settings()
    log('info', f"密钥检查任务完成。当前总可用密钥数量: {len(key_manager.api_keys)}")


# Route exceptions that reach the interpreter top level through the project's handler.
sys.excepthook = handle_exception
# --------------- Event handlers ---------------
@app.on_event("startup")
async def startup_event():
    """
    Initialise services and validate API keys when the application starts.

    Steps, in order:
    1. Reload persisted settings and the Vertex config, create the
       CredentialManager (stored on ``app.state.credential_manager``) and
       initialise Vertex AI with it.
    2. Schedule cache cleanup and run the version check.
    3. Test the configured keys one by one until the first valid one is found
       (blocking); that key is used to load ``GeminiClient.AVAILABLE_MODELS``.
    4. Unless ``SKIP_CHECK_API_KEY`` is set, check the remaining keys in a
       background task (referenced from ``app.state.key_check_task``) or, if
       none remain, persist the invalid keys found so far. When skipping, the
       remaining keys are added untested.
    5. Initialise the API and dashboard routers with the shared managers.
    """
    # Reload persisted settings first so every step below sees the latest values.
    load_settings()
    # Reload the Vertex config so it also picks up the persisted settings.
    import app.vertex.config as vertex_config
    vertex_config.reload_config()

    credential_manager_instance = CredentialManager()
    app.state.credential_manager = credential_manager_instance
    await init_vertex_ai(credential_manager=credential_manager_instance)

    schedule_cache_cleanup(response_cache_manager, active_requests_manager)
    await check_version()

    # --- API key validation ---
    initial_keys = key_manager.api_keys.copy()
    key_manager.api_keys = []  # refilled as keys are confirmed valid
    first_valid_key = None
    initial_invalid_keys = []
    keys_to_check_later = []

    # Blocking scan: stop at the first key that passes test_api_key.
    for index, key in enumerate(initial_keys):
        is_valid = await test_api_key(key)
        if is_valid:
            log('info', f"找到第一个有效密钥: {key[:8]}...")
            first_valid_key = key
            key_manager.api_keys.append(key)
            key_manager._reset_key_stack()
            # Every key after the first valid one is checked later.
            keys_to_check_later = initial_keys[index + 1:]
            break
        log('warning', f"密钥 {key[:8]}... 无效")
        initial_invalid_keys.append(key)

    if not first_valid_key:
        log('error', "启动时未能找到任何有效 API 密钥!")
        keys_to_check_later = []  # no valid key: nothing to check in the background
    else:
        # Load the model list with the first valid key; failure is non-fatal.
        try:
            all_models = await GeminiClient.list_available_models(first_valid_key)
            GeminiClient.AVAILABLE_MODELS = [model.replace("models/", "") for model in all_models]
            log('info', f"使用密钥 {first_valid_key[:8]}... 加载可用模型成功")
        except Exception as e:
            log('warning', f"使用密钥 {first_valid_key[:8]}... 加载可用模型失败", extra={'error_message': str(e)})

    if not SKIP_CHECK_API_KEY:
        if keys_to_check_later:
            # Keep a strong reference: the event loop holds tasks only weakly, so
            # an unreferenced task can be garbage-collected before it finishes.
            app.state.key_check_task = asyncio.create_task(
                check_remaining_keys_async(keys_to_check_later, initial_invalid_keys)
            )
        else:
            # No background check will run, so persist the startup findings here.
            current_invalid_keys_str = settings.INVALID_API_KEYS or ""
            current_invalid_keys_set = {k.strip() for k in current_invalid_keys_str.split(',') if k.strip()}
            new_invalid_keys_set = current_invalid_keys_set | set(initial_invalid_keys)
            if new_invalid_keys_set != current_invalid_keys_set:
                settings.INVALID_API_KEYS = ','.join(sorted(new_invalid_keys_set))
                save_settings()
                log('info', f"更新初始无效密钥列表完成,总无效密钥数: {len(new_invalid_keys_set)}")
    else:
        # Skip mode: accept the remaining keys without testing them.
        log('info', "跳过 API 密钥检查")
        key_manager.api_keys.extend(keys_to_check_later)
        key_manager._reset_key_stack()

    init_router(
        key_manager,
        response_cache_manager,
        active_requests_manager,
        SAFETY_SETTINGS,
        SAFETY_SETTINGS_G2,
        first_valid_key,
        settings.FAKE_STREAMING,
        settings.FAKE_STREAMING_INTERVAL,
        settings.PASSWORD,
        settings.MAX_REQUESTS_PER_MINUTE,
        settings.MAX_REQUESTS_PER_DAY_PER_IP,
    )
    init_dashboard_router(
        key_manager,
        response_cache_manager,
        active_requests_manager,
        credential_manager_instance,
    )
# --------------- Exception handling ---------------
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """
    Turn any unhandled exception into a JSON 500 response.

    The translated error message is logged together with the status code; the
    response body is an ``ErrorResponse`` carrying ``str(exc)`` and the type
    ``"internal_error"``.
    """
    from app.utils import translate_error
    error_message = translate_error(str(exc))
    extra_log_unhandled_exception = {'status_code': 500, 'error_message': error_message}
    log('error', f"Unhandled exception: {error_message}", extra=extra_log_unhandled_exception)
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(message=str(exc), type="internal_error").dict(),
    )
# --------------- Routes ---------------
app.include_router(router)
app.include_router(dashboard_router)

# Serve static assets. The directory is resolved from BASE_DIR, like the
# templates directory, so startup does not depend on the process's working
# directory (assumes this module lives in the `app/` package directory).
app.mount(
    "/assets",
    StaticFiles(directory=str(BASE_DIR / "templates" / "assets")),
    name="assets",
)
# Root route path: "/<DASHBOARD_URL>" when configured, otherwise "/".
dashboard_path = f"/{settings.DASHBOARD_URL}" if settings.DASHBOARD_URL else "/"


@app.get(dashboard_path, response_class=HTMLResponse)
async def root(request: Request):
    """
    Root route: render the ``index.html`` template.

    The template receives the request and ``api_url``, which is the request's
    base URL forced to https with ``v1`` appended.
    """
    base_url = str(request.base_url)
    # Upgrade only the scheme prefix. A blanket str.replace would turn an
    # already-https URL into "httpss://" and rewrite "http" inside the host.
    if base_url.startswith("http://"):
        base_url = "https://" + base_url[len("http://"):]
    api_url = f"{base_url}v1" if base_url.endswith("/") else f"{base_url}/v1"
    return templates.TemplateResponse(
        "index.html", {"request": request, "api_url": api_url}
    )