|
import random |
|
from fastapi import HTTPException, Request |
|
import time |
|
import re |
|
from datetime import datetime, timedelta |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
import os |
|
import requests |
|
import httpx |
|
from threading import Lock |
|
import logging |
|
import sys |
|
from collections import deque |
|
|
|
# Setting the DEBUG environment variable to "true" (case-insensitive)
# switches format_log_message to the verbose format below.
DEBUG = os.getenv("DEBUG", "false").lower() == "true"

# Verbose layout: also appends the error_message field.
LOG_FORMAT_DEBUG = (
    '%(asctime)s - %(levelname)s - [%(key)s]-%(request_type)s-'
    '[%(model)s]-%(status_code)s: %(message)s - %(error_message)s'
)

# Default layout used when DEBUG is off.
LOG_FORMAT_NORMAL = (
    '[%(asctime)s] [%(levelname)s] [%(key)s]-%(request_type)s-'
    '[%(model)s]-%(status_code)s: %(message)s'
)

# Application logger: accepts every level. Its console handler prints the
# message verbatim, because callers pass lines already rendered by
# format_log_message.
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)

console_formatter = logging.Formatter('%(message)s')
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_formatter)
logger.addHandler(console_handler)
|
|
|
|
|
class LogManager:
    """Thread-safe, bounded in-memory buffer of recent log entries.

    Args:
        max_logs: Capacity of the buffer. Once it is full, each new entry
            evicts the oldest one.
    """

    def __init__(self, max_logs=100):
        self.logs = deque(maxlen=max_logs)
        self.lock = Lock()

    def add_log(self, log_entry):
        """Append *log_entry*, evicting the oldest entry if the buffer is full."""
        with self.lock:
            self.logs.append(log_entry)

    def get_recent_logs(self, count=50):
        """Return up to *count* of the most recent entries, oldest first.

        Args:
            count: Maximum number of entries to return. Zero or a negative
                value yields an empty list.

        Returns:
            A new list; mutating it does not affect the buffer.
        """
        # Guard explicitly: slicing with [-0:] would return every entry
        # rather than none, and a negative count would return almost all.
        if count <= 0:
            return []
        with self.lock:
            return list(self.logs)[-count:]
|
|
|
|
|
# Module-wide buffer that format_log_message appends every rendered entry to.
log_manager = LogManager(max_logs=100)
|
|
|
def format_log_message(level, message, extra=None):
    """Render a log line and record it in the module's ``log_manager`` buffer.

    Args:
        level: Level name shown in the line (e.g. 'INFO', 'ERROR').
        message: Main message text.
        extra: Optional dict that may supply 'key', 'request_type', 'model',
            'status_code' and 'error_message'. Missing fields default to
            'N/A' ('' for 'error_message').

    Returns:
        The rendered line. LOG_FORMAT_DEBUG is used when DEBUG is true,
        otherwise LOG_FORMAT_NORMAL. The same line is stored under
        'formatted' in the entry appended to ``log_manager``.
    """
    extra = extra or {}
    # Read the clock once so the rendered line and the stored entry carry
    # the same timestamp (two separate now() calls can straddle a second).
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    key = extra.get('key', 'N/A')
    request_type = extra.get('request_type', 'N/A')
    model = extra.get('model', 'N/A')
    status_code = extra.get('status_code', 'N/A')
    error_message = extra.get('error_message', '')

    log_values = {
        'asctime': timestamp,
        'levelname': level,
        'key': key,
        'request_type': request_type,
        'model': model,
        'status_code': status_code,
        'error_message': error_message,
        'message': message,
    }
    log_format = LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL
    formatted_log = log_format % log_values

    log_manager.add_log({
        'timestamp': timestamp,
        'level': level,
        'key': key,
        'request_type': request_type,
        'model': model,
        'status_code': status_code,
        'message': message,
        'error_message': error_message,
        'formatted': formatted_log,
    })
    return formatted_log
|
|
|
|
|
class APIKeyManager:
    """Hand out Gemini API keys in a shuffled rotation.

    Keys are extracted from the ``GEMINI_API_KEYS`` environment variable and
    served from a shuffled stack. Every key returned by ``get_available_key``
    is remembered in ``tried_keys_for_request``, so repeated calls never
    return the same key twice until ``reset_tried_keys_for_request`` is
    called.
    """

    def __init__(self):
        # Pull out every substring shaped like a Google API key. Any text
        # between keys (commas, spaces, newlines, ...) is ignored.
        self.api_keys = re.findall(
            r"AIzaSy[a-zA-Z0-9_-]{33}", os.environ.get('GEMINI_API_KEYS', ""))
        self.key_stack = []
        self._reset_key_stack()

        # NOTE(review): no jobs are added to this scheduler inside this
        # class; other code may use it -- confirm before removing.
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()
        # NOTE(review): this set lives on the instance, so concurrent
        # requests sharing one manager also share it -- confirm intent.
        self.tried_keys_for_request = set()

    def _reset_key_stack(self):
        """Rebuild the key stack as a freshly shuffled copy of all keys."""
        shuffled_keys = self.api_keys[:]
        random.shuffle(shuffled_keys)
        self.key_stack = shuffled_keys

    def _pop_untried_key(self):
        """Pop keys until one not yet tried is found; mark and return it.

        Already-tried keys popped along the way are dropped from the stack.
        Returns None if the stack runs out first.
        """
        while self.key_stack:
            key = self.key_stack.pop()
            if key not in self.tried_keys_for_request:
                self.tried_keys_for_request.add(key)
                return key
        return None

    def get_available_key(self):
        """Return an untried key from the stack, regenerating the stack if needed.

        Returns:
            A key not yet handed out since the last
            ``reset_tried_keys_for_request``. Returns None when no keys are
            configured (an error is logged) or when every key has already
            been tried.
        """
        key = self._pop_untried_key()
        if key is not None:
            return key

        if not self.api_keys:
            log_msg = format_log_message('ERROR', "没有配置任何 API 密钥!")
            logger.error(log_msg)
            return None

        # The stack is exhausted; reshuffle all keys and scan once more.
        self._reset_key_stack()
        return self._pop_untried_key()

    def show_all_keys(self):
        """Log how many keys are configured and a masked form of each key."""
        log_msg = format_log_message('INFO', f"当前可用API key个数: {len(self.api_keys)} ")
        logger.info(log_msg)
        for i, api_key in enumerate(self.api_keys):
            # Only the first 8 and last 3 characters are logged.
            log_msg = format_log_message('INFO', f"API Key{i}: {api_key[:8]}...{api_key[-3:]}")
            logger.info(log_msg)

    def reset_tried_keys_for_request(self):
        """Forget which keys were tried; call this before a new request attempt."""
        self.tried_keys_for_request = set()
|
|
|
|
|
# Non-400 HTTP statuses with a fixed log line and reply. Each value is
# (log level, error_message put in the log extra, log-line text after the
# masked key, string returned to the caller).
_GEMINI_STATUS_ERRORS = {
    429: ('WARNING', "API 密钥配额已用尽或其他原因", "429 官方资源耗尽或其他原因", "API 密钥配额已用尽或其他原因"),
    403: ('ERROR', "权限被拒绝", "403 权限被拒绝", "权限被拒绝"),
    500: ('WARNING', "服务器内部错误", "500 服务器内部错误", "Gemini API 内部错误"),
    503: ('WARNING', "服务不可用", "503 服务不可用", "Gemini API 服务不可用"),
}


def _log_gemini_error(level, message, extra):
    """Render *message* via format_log_message and emit it at ERROR or WARNING."""
    log_msg = format_log_message(level, message, extra=extra)
    if level == 'ERROR':
        logger.error(log_msg)
    else:
        logger.warning(log_msg)


def _is_invalid_api_key_error(error_info):
    """Return True if a Google API error object says the API key itself is invalid.

    Google error bodies carry the numeric HTTP status in ``code``; a bad key
    is identified by an ErrorInfo entry in ``details`` whose ``reason`` is
    "API_KEY_INVALID". The legacy ``code == "invalid_argument"`` check is
    kept for backward compatibility.
    """
    if error_info.get('code') == "invalid_argument":
        return True
    details = error_info.get('details')
    if not isinstance(details, list):
        return False
    return any(
        isinstance(detail, dict) and detail.get('reason') == "API_KEY_INVALID"
        for detail in details
    )


def handle_gemini_error(error, current_api_key, key_manager) -> str:
    """Log a failed Gemini API call and return a short user-facing description.

    Args:
        error: Exception raised by the call. ``requests`` HTTPError,
            ConnectionError and Timeout are classified; anything else is
            reported as an unknown error.
        current_api_key: Key used for the call. Only its first 8 and last 3
            characters are logged. None is treated as an empty key.
        key_manager: Not used here; kept for interface compatibility.

    Returns:
        A description of the failure. Always a ``str``, including for a
        400 body without an "error" object.
    """
    api_key = current_api_key or ""

    if isinstance(error, requests.exceptions.HTTPError):
        response = error.response
        # HTTPError built without a response has no status; report it as unknown.
        status_code = response.status_code if response is not None else None
        key_prefix = api_key[:8]
        key_tag = f"{key_prefix} ... {api_key[-3:]}"

        def key_extra(error_message):
            return {'key': key_prefix, 'status_code': status_code, 'error_message': error_message}

        if status_code == 400:
            try:
                error_data = response.json()
            except ValueError:
                error_message = "400 错误请求:响应不是有效的JSON格式"
                _log_gemini_error('WARNING', error_message, key_extra(error_message))
                return error_message

            # Tolerate bodies that are not dicts or whose "error" is not a dict.
            raw_error = error_data.get('error') if isinstance(error_data, dict) else None
            error_info = raw_error if isinstance(raw_error, dict) else {}

            if _is_invalid_api_key_error(error_info):
                error_message = "无效的 API 密钥"
                _log_gemini_error('ERROR', f"{key_tag} → 无效,可能已过期或被删除", key_extra(error_message))
                return error_message

            error_message = error_info.get('message', 'Bad Request')
            _log_gemini_error('WARNING', f"400 错误请求: {error_message}", key_extra(error_message))
            return f"400 错误请求: {error_message}"

        known = _GEMINI_STATUS_ERRORS.get(status_code)
        if known is not None:
            level, error_message, log_text, reply = known
            _log_gemini_error(level, f"{key_tag} → {log_text}", key_extra(error_message))
            return reply

        error_message = f"未知错误: {status_code}"
        _log_gemini_error('WARNING', f"{key_tag} → {status_code} 未知错误", key_extra(error_message))
        return f"未知错误/模型不可用: {status_code}"

    if isinstance(error, requests.exceptions.ConnectionError):
        error_message = "连接错误"
        _log_gemini_error('WARNING', error_message, {'error_message': error_message})
        return error_message

    if isinstance(error, requests.exceptions.Timeout):
        error_message = "请求超时"
        _log_gemini_error('WARNING', error_message, {'error_message': error_message})
        return error_message

    error_message = f"发生未知错误: {error}"
    _log_gemini_error('ERROR', error_message, {'error_message': error_message})
    return error_message
|
|
|
|
|
async def test_api_key(api_key: str) -> bool:
    """Probe whether *api_key* is accepted by the Gemini API.

    Sends a GET to the v1beta ``models`` endpoint, passing the key as the
    ``key`` query parameter. Returns True when ``raise_for_status()``
    accepts the response. Any exception raised while building, sending or
    checking the request yields False.
    """
    try:
        models_endpoint = f"https://generativelanguage.googleapis.com/v1beta/models?key={api_key}"
        async with httpx.AsyncClient() as http:
            reply = await http.get(models_endpoint)
            reply.raise_for_status()
    except Exception:
        return False
    else:
        return True
|
|
|
|
|
# Counters used by protect_from_abuse: "<bucket>:<period>" -> (count, first_seen_unix_ts).
rate_limit_data = dict()
# Serializes every read-modify-write of rate_limit_data.
rate_limit_lock = Lock()
|
|
|
|
|
# Prune rate_limit_data at most this often (seconds) to keep the scan cheap.
_RATE_LIMIT_PRUNE_INTERVAL = 60
# An entry whose first-seen timestamp is at least this old (seconds) belongs
# to a minute or day bucket that has already ended.
_RATE_LIMIT_ENTRY_TTL = 60 * 60 * 24
_rate_limit_last_prune = 0


def _prune_rate_limit_data(now):
    """Delete rate-limit counters whose bucket can no longer be current.

    Keys embed their minute/day number, so a finished bucket is never looked
    up again. Without pruning, rate_limit_data grows forever. The caller
    must hold rate_limit_lock.
    """
    global _rate_limit_last_prune
    if now - _rate_limit_last_prune < _RATE_LIMIT_PRUNE_INTERVAL:
        return
    _rate_limit_last_prune = now
    stale_keys = [
        key for key, (_, first_seen) in rate_limit_data.items()
        if now - first_seen >= _RATE_LIMIT_ENTRY_TTL
    ]
    for key in stale_keys:
        del rate_limit_data[key]


def protect_from_abuse(request: Request, max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
    """Count this request and raise HTTP 429 if a rate limit is exceeded.

    Args:
        request: Incoming request. Its URL path keys the per-minute counter
            and its client host keys the per-day counter.
        max_requests_per_minute: Limit per path per calendar minute.
        max_requests_per_day_per_ip: Limit per client host per UTC day.

    Raises:
        HTTPException: status 429 when either counter exceeds its limit.
            The request is counted even when it is rejected.
    """
    now = int(time.time())
    minute = now // 60
    day = now // (60 * 60 * 24)

    # NOTE(review): the per-minute bucket is keyed by path only, so all
    # clients share it -- confirm this global cap is intended.
    minute_key = f"{request.url.path}:{minute}"
    # request.client can be None (no client info in the ASGI scope).
    client_host = request.client.host if request.client is not None else "unknown"
    day_key = f"{client_host}:{day}"

    with rate_limit_lock:
        _prune_rate_limit_data(now)

        minute_count, minute_timestamp = rate_limit_data.get(minute_key, (0, now))
        if now - minute_timestamp >= 60:
            minute_count = 0
            minute_timestamp = now
        minute_count += 1
        rate_limit_data[minute_key] = (minute_count, minute_timestamp)

        day_count, day_timestamp = rate_limit_data.get(day_key, (0, now))
        if now - day_timestamp >= 86400:
            day_count = 0
            day_timestamp = now
        day_count += 1
        rate_limit_data[day_key] = (day_count, day_timestamp)

    if minute_count > max_requests_per_minute:
        raise HTTPException(status_code=429, detail={
            "message": "Too many requests per minute", "limit": max_requests_per_minute})
    if day_count > max_requests_per_day_per_ip:
        raise HTTPException(status_code=429, detail={"message": "Too many requests per day from this IP", "limit": max_requests_per_day_per_ip})