|
from fastapi import FastAPI, HTTPException, Request, status |
|
from fastapi.responses import JSONResponse, HTMLResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.templating import Jinja2Templates |
|
from app.models import ErrorResponse |
|
from app.services import GeminiClient |
|
from app.utils import ( |
|
APIKeyManager, |
|
test_api_key, |
|
format_log_message, |
|
log_manager, |
|
ResponseCacheManager, |
|
ActiveRequestsManager, |
|
clean_expired_stats, |
|
update_api_call_stats, |
|
check_version, |
|
schedule_cache_cleanup, |
|
handle_exception, |
|
log |
|
) |
|
from app.api import router, init_router, dashboard_router, init_dashboard_router |
|
from app.vertex.vertex import router as vertex_router |
|
from app.vertex.vertex import init_vertex_ai |
|
import app.config.settings as settings |
|
from app.config.safety import SAFETY_SETTINGS, SAFETY_SETTINGS_G2 |
|
import os |
|
import json |
|
import asyncio |
|
import time |
|
import logging |
|
from datetime import datetime, timedelta |
|
import sys |
|
import pathlib |
|
import threading |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
# Directory that contains this module; bundled resources are located relative to it.
BASE_DIR = pathlib.Path(__file__).parent

# Template loader rooted at the "templates" folder next to this module.
templates = Jinja2Templates(directory=str(BASE_DIR.joinpath("templates")))

# NOTE(review): `limit` is not a documented FastAPI constructor parameter; it is
# presumably swallowed as an extra keyword and does not cap request body size —
# confirm before relying on it.
app = FastAPI(limit="50M")
|
|
|
|
|
|
|
|
|
# Manager for the configured API keys, plus the key picked at import time.
# switch_api_key() rebinds `current_api_key` later on.
key_manager = APIKeyManager()
current_api_key = key_manager.get_available_key()

# Plain dict handed to the cache manager as its backing store (`cache_dict`).
response_cache = {}
response_cache_manager = ResponseCacheManager(
    cache_dict=response_cache,
    expiry_time=settings.CACHE_EXPIRY_TIME,
    max_entries=settings.MAX_CACHE_ENTRIES,
    remove_after_use=settings.REMOVE_CACHE_AFTER_USE,
)

# Plain dict handed to the active-requests manager as its pool (`requests_pool`).
active_requests_pool = {}
active_requests_manager = ActiveRequestsManager(
    requests_pool=active_requests_pool,
)
|
|
|
|
|
|
|
def switch_api_key():
    """Rebind the module-level ``current_api_key`` to the next available key.

    Asks ``key_manager`` for a key. When one is returned it replaces the
    global and the switch is logged using the first eight characters of the
    key. When none is returned an error is logged and the global is left
    untouched.
    """
    global current_api_key

    candidate = key_manager.get_available_key()
    if not candidate:
        log('error', "API key 替换失败,所有API key都已尝试,请重新配置或稍后重试", extra={'key': 'N/A', 'request_type': 'switch_key', 'status_code': 'N/A'})
        return

    current_api_key = candidate
    key_prefix = current_api_key[:8]
    log('info', f"API key 替换为 → {key_prefix}...", extra={'key': key_prefix, 'request_type': 'switch_key'})
|
|
|
async def check_key(key):
    """Check whether a single API key is valid.

    Awaits ``test_api_key(key)``, logs the outcome together with the first
    ten characters of the key, and returns the key itself when it is valid
    or ``None`` when it is not.
    """
    usable = await test_api_key(key)
    verdict = "有效" if usable else "无效"
    log('info', f"API Key {key[:10]}... {verdict}.")
    if usable:
        return key
    return None
|
|
|
def check_key_in_thread(key):
    """Run the asynchronous key check to completion from a worker thread.

    A valid key is appended to ``key_manager.api_keys``, the manager's key
    stack is reset, and the addition is logged.

    Args:
        key: The API key string to validate.

    Returns:
        The key when ``check_key`` reports it valid, otherwise ``None``.

    Raises:
        RuntimeError: If an event loop is already running in the calling
            thread (this function is meant for plain worker threads).
    """
    # asyncio.run creates a fresh event loop, runs the coroutine, then closes
    # the loop and clears it as the thread's current loop. The previous manual
    # new_event_loop/set_event_loop/close sequence left a closed loop
    # registered on the (reused) pool thread.
    valid_key = asyncio.run(check_key(key))

    if valid_key:
        key_manager.api_keys.append(valid_key)
        # Rebuild the manager's key stack so the new key becomes selectable.
        key_manager._reset_key_stack()
        log('info', f"API Key {valid_key[:8]}... 已添加到可用列表")

    return valid_key
|
|
|
async def check_keys():
    """Validate every configured API key in parallel worker threads.

    The current key list is copied and ``key_manager.api_keys`` is emptied;
    ``check_key_in_thread`` then re-adds each key that proves valid. Errors
    raised while checking an individual key are logged and do not abort the
    remaining checks.

    Returns:
        ``key_manager.api_keys`` as it stands after all checks have finished.
    """
    all_keys = key_manager.api_keys.copy()
    # Start from an empty list; valid keys are appended back by the workers.
    key_manager.api_keys = []

    log('info', f"开始在单独线程中检查 {len(all_keys)} 个API密钥...")

    # ThreadPoolExecutor raises ValueError for max_workers=0, so the pool is
    # only created when there is at least one key to check.
    if all_keys:
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=min(10, len(all_keys))) as executor:
            pending = [
                loop.run_in_executor(executor, check_key_in_thread, key)
                for key in all_keys
            ]
            # Await the workers instead of blocking on Future.result(), so
            # the event loop stays responsive while the keys are checked.
            outcomes = await asyncio.gather(*pending, return_exceptions=True)

        for outcome in outcomes:
            if isinstance(outcome, Exception):
                log('error', f"检查密钥时发生错误: {outcome}")

    if not key_manager.api_keys:
        log('error', "没有可用的 API 密钥!如果您不使用ai studio 请忽略这些错误", extra={'key': 'N/A', 'request_type': 'startup', 'status_code': 'N/A'})

    return key_manager.api_keys
|
|
|
|
|
# Install the project's handle_exception as the interpreter-wide hook for
# uncaught exceptions.
sys.excepthook = handle_exception
|
|
|
|
|
|
|
async def _load_available_models(key):
    """Fetch the model list with *key* and store it on ``GeminiClient``.

    Every ``models/`` substring is removed from the returned names. Any
    failure is logged as a warning instead of being raised, so startup can
    continue without a model list.
    """
    try:
        all_models = await GeminiClient.list_available_models(key)
        GeminiClient.AVAILABLE_MODELS = [
            model.replace("models/", "") for model in all_models
        ]
        log('info', "Available models loaded.")
    except Exception as e:
        log('warning', f"无法加载可用模型: {str(e)}")


def _check_keys_in_background(keys):
    """Validate *keys* with a thread pool; meant to run in its own thread.

    Each key is handed to ``check_key_in_thread``, which creates its own
    event loop, so no loop is needed in this coordinating thread. Per-key
    errors are logged; a completion message is logged in all cases.
    """
    try:
        with ThreadPoolExecutor(max_workers=min(10, len(keys))) as executor:
            futures = [executor.submit(check_key_in_thread, key) for key in keys]
            for future in futures:
                try:
                    future.result()
                except Exception as exc:
                    log('error', f"检查密钥时发生错误: {exc}")
    finally:
        log('info', f"后台密钥检查完成,当前可用密钥数量: {len(key_manager.api_keys)}")


@app.on_event("startup")
async def startup_event():
    """Initialise the proxy when the application starts.

    Steps, in order:
      1. Run the version check, initialise Vertex AI and schedule cache
         cleanup.
      2. Test the configured keys one by one until the first valid key is
         found, then use that key to load the available model list.
      3. Hand every key not yet accepted to a daemon thread that validates
         them in the background.
      4. Log the key counts and initialise the API and dashboard routers.
    """
    log('info', "Starting Gemini API proxy...")
    # The version check and init_router used to be invoked twice with
    # identical arguments; each now runs once.
    await check_version()
    init_vertex_ai()
    log('info', "初始化Vertex AI")
    schedule_cache_cleanup(response_cache_manager, active_requests_manager)

    # Empty the manager's list; only keys that pass validation are re-added.
    all_keys = key_manager.api_keys.copy()
    key_manager.api_keys = []

    # Stop at the first valid key so startup is not delayed by the rest.
    valid_key_found = False
    for key in all_keys:
        if not await test_api_key(key):
            continue

        key_manager.api_keys.append(key)
        key_manager._reset_key_stack()
        log('info', f"初始检查: API Key {key[:8]}... 有效,已添加到可用列表")
        valid_key_found = True

        await _load_available_models(key)
        break

    if not valid_key_found:
        log('warning', "初始检查未找到有效密钥,将在后台继续检查")

    # Everything not accepted so far is (re)checked off the event loop.
    remaining_keys = [k for k in all_keys if k not in key_manager.api_keys]
    if remaining_keys:
        threading.Thread(
            target=_check_keys_in_background,
            args=(remaining_keys,),
            daemon=True,
        ).start()
        log('info', f"后台线程已启动,正在检查剩余的 {len(remaining_keys)} 个API密钥...")

    key_manager.show_all_keys()
    log('info', f"当前可用 API 密钥数量:{len(key_manager.api_keys)}")
    log('info', f"最大重试次数设置为:{len(key_manager.api_keys)}")

    init_router(
        key_manager,
        response_cache_manager,
        active_requests_manager,
        SAFETY_SETTINGS,
        SAFETY_SETTINGS_G2,
        current_api_key,
        settings.FAKE_STREAMING,
        settings.FAKE_STREAMING_INTERVAL,
        settings.PASSWORD,
        settings.MAX_REQUESTS_PER_MINUTE,
        settings.MAX_REQUESTS_PER_DAY_PER_IP,
    )

    init_dashboard_router(
        key_manager,
        response_cache_manager,
        active_requests_manager,
    )
|
|
|
|
|
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Turn any otherwise unhandled exception into a JSON 500 response.

    The exception text is passed through ``translate_error`` for the log
    entry, while the response body carries the untranslated ``str(exc)`` in
    an ``ErrorResponse`` of type ``internal_error``.
    """
    from app.utils import translate_error

    raw_message = str(exc)
    translated = translate_error(raw_message)
    log(
        'error',
        f"Unhandled exception: {translated}",
        extra={'status_code': 500, 'error_message': translated},
    )

    payload = ErrorResponse(message=raw_message, type="internal_error").dict()
    return JSONResponse(status_code=500, content=payload)
|
|
|
|
|
|
|
# Register the API routes and the dashboard routes on the application.
# NOTE(review): `vertex_router` is imported at the top of this file but is not
# included here — confirm whether it is registered elsewhere.
app.include_router(router)
app.include_router(dashboard_router)

# Serve the template assets under /assets.
# NOTE(review): this path is relative, so it is resolved against the process
# working directory rather than BASE_DIR — confirm the expected launch directory.
_static_assets = StaticFiles(directory="app/templates/assets")
app.mount("/assets", _static_assets, name="assets")
|
|
|
def _force_https(url):
    """Return *url* with a leading ``http://`` scheme upgraded to ``https://``.

    URLs that already use ``https://`` (or any other scheme) are returned
    unchanged, and only the scheme is touched — never an ``http`` substring
    further along in the host or path.
    """
    plain_scheme = "http://"
    if url.startswith(plain_scheme):
        return "https://" + url[len(plain_scheme):]
    return url


@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Root route: render ``index.html`` with the public API URL.

    The API URL is the request's base URL, forced to HTTPS, with ``v1``
    appended. It is passed to the template as ``api_url``.
    """
    # A blanket str.replace("http", "https") produced "httpss://" for requests
    # that were already HTTPS; only a plain-HTTP scheme is rewritten now.
    base_url = _force_https(str(request.base_url))
    api_url = f"{base_url}v1" if base_url.endswith("/") else f"{base_url}/v1"

    return templates.TemplateResponse(
        "index.html", {"request": request, "api_url": api_url}
    )
|
|