| | import random
|
| | from fastapi import HTTPException, Request
|
| | import time
|
| | import re
|
| | from datetime import datetime, timedelta
|
| | from apscheduler.schedulers.background import BackgroundScheduler
|
| | import os
|
| | import requests
|
| | import httpx
|
| | from threading import Lock
|
| | import logging
|
| | import sys
|
| |
|
# Set DEBUG=true (case-insensitive) in the environment to use the verbose log template.
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"

# Templates filled by format_log_message() through %-style dict substitution.
LOG_FORMAT_DEBUG = '%(asctime)s - %(levelname)s - [%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s - %(error_message)s'
LOG_FORMAT_NORMAL = '[%(key)s]-%(request_type)s-[%(model)s]-%(status_code)s: %(message)s'

logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)

# logging.getLogger() returns the same object for the same name across the whole
# process. Re-importing or reloading this module (e.g. under two package paths)
# would otherwise attach a second console handler, and every record would be
# printed twice. Tag our handler by name and only create it once.
# No Formatter is set: messages arrive pre-formatted from format_log_message().
_CONSOLE_HANDLER_NAME = "my_logger_console"
handler = next(
    (h for h in logger.handlers if h.get_name() == _CONSOLE_HANDLER_NAME),
    None,
)
if handler is None:
    handler = logging.StreamHandler()
    handler.set_name(_CONSOLE_HANDLER_NAME)
    logger.addHandler(handler)
|
| |
|
def format_log_message(level, message, extra=None):
    """Render one log line with the module's log template.

    Args:
        level: Text placed in the ``levelname`` slot (e.g. 'INFO', 'ERROR').
        message: The main log text.
        extra: Optional mapping that may supply ``key``, ``request_type``,
            ``model``, ``status_code`` and ``error_message``. Missing entries
            fall back to 'N/A' ('' for ``error_message``); other entries are
            ignored.

    Returns:
        The formatted string. LOG_FORMAT_DEBUG is used when DEBUG is true,
        otherwise LOG_FORMAT_NORMAL; either is filled via %-style dict
        substitution.
    """
    supplied = extra or {}
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    fallbacks = (
        ('key', 'N/A'),
        ('request_type', 'N/A'),
        ('model', 'N/A'),
        ('status_code', 'N/A'),
        ('error_message', ''),
    )
    fields = {name: supplied.get(name, default) for name, default in fallbacks}
    fields.update(asctime=timestamp, levelname=level, message=message)
    template = LOG_FORMAT_DEBUG if DEBUG else LOG_FORMAT_NORMAL
    return template % fields
|
| |
|
| |
|
class APIKeyManager:
    """Hand out Gemini API keys so that one request never retries the same key.

    Keys are read from the GEMINI_API_KEYS environment variable and served from
    a shuffled stack. Every key handed out is recorded in
    ``tried_keys_for_request`` and skipped until reset_tried_keys_for_request()
    is called.

    NOTE(review): ``tried_keys_for_request`` is a single set on the instance, so
    concurrent requests sharing one manager also share it. Confirm this is
    intended.
    """

    def __init__(self):
        self.api_keys = self._parse_keys(os.environ.get('GEMINI_API_KEYS', ""))
        self.key_stack = []
        self._reset_key_stack()

        # NOTE(review): the scheduler is started here, but no job is added in
        # this class. Presumably code elsewhere schedules jobs on it; confirm
        # before removing.
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()
        self.tried_keys_for_request = set()

    @staticmethod
    def _parse_keys(raw):
        """Return every Gemini-style API key found in *raw*, without duplicates.

        A key is "AIzaSy" followed by 33 characters from [A-Za-z0-9_-]; any
        separator between keys works. First-seen order is kept. Duplicates
        would otherwise inflate the count logged by show_all_keys() and place
        the same key on the stack twice.
        """
        return list(dict.fromkeys(re.findall(r"AIzaSy[a-zA-Z0-9_-]{33}", raw)))

    def _reset_key_stack(self):
        """Rebuild the key stack as a freshly shuffled copy of all keys."""
        shuffled_keys = self.api_keys[:]
        random.shuffle(shuffled_keys)
        self.key_stack = shuffled_keys

    def _pop_untried_key(self):
        """Pop keys until one not yet tried for this request is found.

        The returned key is marked as tried. Already-tried keys popped along the
        way are discarded. Returns None once the stack is exhausted.
        """
        while self.key_stack:
            key = self.key_stack.pop()
            if key not in self.tried_keys_for_request:
                self.tried_keys_for_request.add(key)
                return key
        return None

    def get_available_key(self):
        """Return the next key not yet tried for the current request, or None.

        When the stack runs dry it is rebuilt from all keys and searched once
        more. None is returned when no keys are configured (an error is logged)
        or when every key has already been tried for this request.
        """
        key = self._pop_untried_key()
        if key is not None:
            return key

        if not self.api_keys:
            log_msg = format_log_message('ERROR', "没有配置任何 API 密钥!")
            logger.error(log_msg)
            return None

        self._reset_key_stack()
        return self._pop_untried_key()

    def show_all_keys(self):
        """Log the number of configured keys and a masked form of each one."""
        log_msg = format_log_message('INFO', f"当前可用API key个数: {len(self.api_keys)} ")
        logger.info(log_msg)
        for i, api_key in enumerate(self.api_keys):
            # Only the first 8 and last 3 characters are logged.
            log_msg = format_log_message('INFO', f"API Key{i}: {api_key[:8]}...{api_key[-3:]}")
            logger.info(log_msg)

    def reset_tried_keys_for_request(self):
        """Forget which keys were tried; call this when a new request begins."""
        self.tried_keys_for_request = set()
|
| |
|
| |
|
# Non-400 HTTP statuses from Gemini that get a dedicated message. Each entry is
#   status -> (log level, error_message for the log record,
#              text logged after "→", value returned to the caller).
# A returned value of None means "return error_message itself".
_GEMINI_STATUS_MESSAGES = {
    429: ('WARNING', "API 密钥配额已用尽或其他原因", "429 官方资源耗尽或其他原因", None),
    403: ('ERROR', "权限被拒绝", "403 权限被拒绝", None),
    500: ('WARNING', "服务器内部错误", "500 服务器内部错误", "Gemini API 内部错误"),
    503: ('WARNING', "服务不可用", "503 服务不可用", "Gemini API 服务不可用"),
}


def _masked_key(api_key):
    """Shorten an API key to '<first 8 chars> ... <last 3 chars>' for log output."""
    return f"{api_key[:8]} ... {api_key[-3:]}"


def _is_invalid_key_error(error_body):
    """Return True when the "error" object of a 400 response reports a bad API key.

    Google APIs put the numeric HTTP code in ``error.code``, so a string
    comparison there never matched. A rejected key is instead identified by an
    ErrorInfo entry in ``error.details`` whose ``reason`` is "API_KEY_INVALID".
    The original ``code == "invalid_argument"`` test is kept for backward
    compatibility.
    """
    if not isinstance(error_body, dict):
        return False
    if error_body.get('code') == "invalid_argument":
        return True
    details = error_body.get('details')
    if not isinstance(details, list):
        return False
    return any(
        isinstance(detail, dict) and detail.get('reason') == "API_KEY_INVALID"
        for detail in details
    )


def _log_gemini_error(level, message, api_key, status_code, error_message):
    """Format *message* with key/status/error extras and log it at *level* ('WARNING' or 'ERROR')."""
    extra = {'key': api_key[:8], 'status_code': status_code, 'error_message': error_message}
    logger.log(getattr(logging, level), format_log_message(level, message, extra=extra))


def _handle_bad_request(response, api_key):
    """Log a Gemini 400 response and return the message for the caller."""
    status_code = 400
    try:
        error_data = response.json()
    except ValueError:
        error_message = "400 错误请求:响应不是有效的JSON格式"
        _log_gemini_error('WARNING', error_message, api_key, status_code, error_message)
        return error_message

    error_body = error_data.get('error') if isinstance(error_data, dict) else None
    if _is_invalid_key_error(error_body):
        error_message = "无效的 API 密钥"
        _log_gemini_error(
            'ERROR', f"{_masked_key(api_key)} → 无效,可能已过期或被删除",
            api_key, status_code, error_message,
        )
        return error_message

    # Bodies without a usable "error" object previously made this function
    # return None (or raise AttributeError); fall back to 'Bad Request'.
    if isinstance(error_body, dict):
        error_message = error_body.get('message', 'Bad Request')
    else:
        error_message = 'Bad Request'
    _log_gemini_error('WARNING', f"400 错误请求: {error_message}", api_key, status_code, error_message)
    return f"400 错误请求: {error_message}"


def handle_gemini_error(error, current_api_key, key_manager) -> str:
    """Log an exception raised while calling Gemini and return a message for the client.

    Args:
        error: The exception raised by the ``requests`` call.
        current_api_key: The API key used for the failed call; only a masked
            form is logged.
        key_manager: Accepted for interface compatibility; not used here.

    Returns:
        A human-readable (Chinese) description of the failure. Always a str.
    """
    if isinstance(error, requests.exceptions.HTTPError):
        response = error.response
        # HTTPError can be constructed without a response; treat that as an
        # unknown status instead of raising AttributeError.
        status_code = response.status_code if response is not None else None
        if status_code == 400:
            return _handle_bad_request(response, current_api_key)

        entry = _GEMINI_STATUS_MESSAGES.get(status_code)
        if entry is None:
            entry = (
                'WARNING',
                f"未知错误: {status_code}",
                f"{status_code} 未知错误",
                f"未知错误/模型不可用: {status_code}",
            )
        level, error_message, log_text, returned = entry
        _log_gemini_error(
            level, f"{_masked_key(current_api_key)} → {log_text}",
            current_api_key, status_code, error_message,
        )
        return error_message if returned is None else returned

    if isinstance(error, requests.exceptions.ConnectionError):
        level, error_message = 'WARNING', "连接错误"
    elif isinstance(error, requests.exceptions.Timeout):
        level, error_message = 'WARNING', "请求超时"
    else:
        level, error_message = 'ERROR', f"发生未知错误: {error}"
    log_msg = format_log_message(level, error_message, extra={'error_message': error_message})
    logger.log(getattr(logging, level), log_msg)
    return error_message
|
| |
|
| |
|
async def test_api_key(api_key: str) -> bool:
    """Check whether *api_key* is accepted by the Gemini API.

    Sends GET .../v1beta/models with the key in the query string. Returns True
    when the request completes and raise_for_status() does not raise. Any
    exception at all (HTTP error status, network failure, or anything else)
    yields False.
    """
    try:
        endpoint = f"https://generativelanguage.googleapis.com/v1beta/models?key={api_key}"
        async with httpx.AsyncClient() as client:
            reply = await client.get(endpoint)
            reply.raise_for_status()
    except Exception:
        return False
    else:
        return True
|
| |
|
| |
|
# In-memory request counters shared by every call to protect_from_abuse().
# Keys are "<url path>:<minute bucket>" or "<client host>:<day bucket>".
# Values are (count, timestamp of the first request seen in that bucket).
rate_limit_data = {}
rate_limit_lock = Lock()

# Minute bucket in which stale counters were last pruned (None before the first call).
_rate_limit_pruned_minute = None


def _prune_stale_rate_limits(minute, day):
    """Drop counters whose bucket is neither the current minute nor the current day.

    Keys embed their time bucket, so an entry from an elapsed minute or day is
    never read again. Without pruning, the dict grew forever. Keys whose suffix
    after the last ':' is not an integer are left untouched.
    Caller must hold rate_limit_lock.
    """
    current = {str(minute), str(day)}
    stale = [
        key for key in rate_limit_data
        if isinstance(key, str)
        and key.rpartition(':')[2].isdigit()
        and key.rpartition(':')[2] not in current
    ]
    for key in stale:
        del rate_limit_data[key]


def _bump_rate_counter(key, now, window):
    """Increment the counter stored under *key* and return the new count.

    Caller must hold rate_limit_lock.
    """
    count, started = rate_limit_data.get(key, (0, now))
    if now - started >= window:
        # Defensive only: keys already embed the bucket, so an entry should
        # never outlive its window.
        count, started = 0, now
    count += 1
    rate_limit_data[key] = (count, started)
    return count


def protect_from_abuse(request: Request, max_requests_per_minute: int = 30, max_requests_per_day_per_ip: int = 600):
    """Count this request and raise HTTP 429 when a rate limit is exceeded.

    Two fixed-window counters are kept:
      * per URL path per epoch minute (shared by all clients), capped at
        ``max_requests_per_minute``;
      * per client host per epoch day (``now // 86400``), capped at
        ``max_requests_per_day_per_ip``.
    The request is counted even when it is rejected. Stale buckets are pruned
    at most once per minute.

    Raises:
        HTTPException: status 429 with a detail dict when a limit is exceeded.
            The per-minute limit is checked first.
    """
    global _rate_limit_pruned_minute

    now = int(time.time())
    minute = now // 60
    day = now // (60 * 60 * 24)

    # request.client may be None (Starlette documents it as optional); fall back
    # to a shared placeholder instead of raising AttributeError.
    client_host = request.client.host if request.client is not None else "unknown"
    minute_key = f"{request.url.path}:{minute}"
    day_key = f"{client_host}:{day}"

    with rate_limit_lock:
        if _rate_limit_pruned_minute != minute:
            _prune_stale_rate_limits(minute, day)
            _rate_limit_pruned_minute = minute
        minute_count = _bump_rate_counter(minute_key, now, 60)
        day_count = _bump_rate_counter(day_key, now, 86400)

    if minute_count > max_requests_per_minute:
        raise HTTPException(status_code=429, detail={
            "message": "Too many requests per minute", "limit": max_requests_per_minute})
    if day_count > max_requests_per_day_per_ip:
        raise HTTPException(status_code=429, detail={
            "message": "Too many requests per day from this IP", "limit": max_requests_per_day_per_ip})