notdiamond2api2 / app.py
dan92's picture
Upload app.py
d0d6353 verified
raw
history blame
22 kB
import json
import logging
import os
import random
import time
import uuid
import re
import socket
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache, wraps
from typing import Dict, Any, Callable, List, Tuple
import requests
import tiktoken
from flask import Flask, Response, jsonify, request, stream_with_context
from flask_cors import CORS
from requests.adapters import HTTPAdapter
from urllib3.util.connection import create_connection
import urllib3
from cachetools import TTLCache
import threading
from time import sleep
from datetime import datetime, timedelta
import concurrent.futures
from concurrent.futures import TimeoutError
# 新增导入
import register_bot
# Constants
CHAT_COMPLETION_CHUNK = 'chat.completion.chunk'
CHAT_COMPLETION = 'chat.completion'
CONTENT_TYPE_EVENT_STREAM = 'text/event-stream'
_BASE_URL = "https://chat.notdiamond.ai"
_API_BASE_URL = "https://spuckhogycrxcbomznwo.supabase.co"
_USER_AGENT = (
    'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) '
    'AppleWebKit/537.36 (KHTML, like Gecko) '
    'Chrome/128.0.0.0 Safari/537.36'
)
# API_KEY and the PASTE_API_* settings come from the environment.
# API_KEY and PASTE_API_URL are mandatory: fail fast at import time without them.
API_KEY = os.environ.get('API_KEY')
_PASTE_API_URL = os.environ.get('PASTE_API_URL')
_PASTE_API_PASSWORD = os.environ.get('PASTE_API_PASSWORD')
if not API_KEY:
    raise ValueError("API_KEY environment variable must be set")
if not _PASTE_API_URL:
    raise ValueError("PASTE_API_URL environment variable must be set")
# Logging, the Flask application, permissive CORS and a shared worker pool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})
executor = ThreadPoolExecutor(max_workers=10)
# Optional PROXY_URL plus the NOTDIAMOND_DOMAIN -> NOTDIAMOND_IP override;
# NOTDIAMOND_IP is mandatory, so a missing value is logged and aborts startup.
proxy_url = os.environ.get('PROXY_URL')
NOTDIAMOND_IP = os.environ.get('NOTDIAMOND_IP')
NOTDIAMOND_DOMAIN = os.environ.get('NOTDIAMOND_DOMAIN')
if not NOTDIAMOND_IP:
    logger.error("NOTDIAMOND_IP environment variable is not set!")
    raise ValueError("NOTDIAMOND_IP must be set")
# Other code remains unchanged...
@app.route('/', methods=['GET'])
def root():
    """Describe the service: endpoint, auth header, an example body and the model list."""
    model_names = list(MODEL_INFO.keys())
    example_body = {
        "model": "One of: " + ", ".join(model_names),
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, who are you?"},
        ],
        "stream": False,
        "temperature": 0.7,
    }
    usage = {
        "endpoint": "/ai/v1/chat/completions",
        "method": "POST",
        "headers": {"Authorization": "Bearer YOUR_API_KEY"},
        "body": example_body,
    }
    return jsonify({
        "service": "AI Chat Completion Proxy",
        "usage": usage,
        "availableModels": model_names,
        "note": "API key authentication is required for other endpoints.",
    })
# Startup logic adjusted for compatibility with the Flask CLI and Gunicorn:
# start the background health-check thread whether the module is imported or
# run directly, and only serve with app.run() when run as a script.
# NOTE(review): `health_check` is defined further down this file, so evaluating
# it here at import time raises NameError. app.run() also does not return until
# the server stops, so when run as a script every route and helper defined below
# this point is never registered. This whole block should live at the very end
# of the file.
if __name__ != "__main__":
    health_check_thread = threading.Thread(target=health_check, daemon=True)
    health_check_thread.start()
if __name__ == "__main__":
    health_check_thread = threading.Thread(target=health_check, daemon=True)
    health_check_thread.start()
    port = int(os.environ.get("PORT", 3000))
    app.run(debug=False, host='0.0.0.0', port=port, threaded=True)
# Decorator enforcing API-key authentication
def require_api_key(f):
    """Wrap a view so it only runs when the Authorization header carries API_KEY.

    The key is the text after the last ``"Bearer "`` in the header (a header
    without that prefix is compared as-is, as before). Responds with a 401 JSON
    error when the header is missing or the key does not match.
    """
    import hmac  # constant-time comparison for the secret

    @wraps(f)
    def decorated_function(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if not auth_header:
            return jsonify({'error': 'No API key provided'}), 401
        try:
            # Extract the API key from the Bearer token
            provided_key = auth_header.split('Bearer ')[-1].strip()
            # compare_digest does not leak, through response timing, how many
            # leading characters of the secret matched (plain != does).
            key_matches = hmac.compare_digest(
                provided_key.encode('utf-8'), API_KEY.encode('utf-8')
            )
        except Exception:
            return jsonify({'error': 'Invalid Authorization header format'}), 401
        if not key_matches:
            return jsonify({'error': 'Invalid API key'}), 401
        return f(*args, **kwargs)
    return decorated_function
# Module-level caches; entries expire after one hour (3600 s)
refresh_token_cache = TTLCache(ttl=3600, maxsize=1000)
headers_cache = TTLCache(ttl=3600, maxsize=1)
# Lock named for token refreshes (not used in the visible code — presumably used elsewhere)
token_refresh_lock = threading.Lock()
# Custom connection function
def patched_create_connection(address, *args, **kwargs):
    """Open a socket like urllib3's create_connection, pinning NOTDIAMOND_DOMAIN to NOTDIAMOND_IP.

    Connections to any other host are passed through unchanged.
    """
    target_host, target_port = address
    if target_host != NOTDIAMOND_DOMAIN:
        return create_connection(address, *args, **kwargs)
    logger.info(f"Connecting to {NOTDIAMOND_DOMAIN} using IP: {NOTDIAMOND_IP}")
    return create_connection((NOTDIAMOND_IP, target_port), *args, **kwargs)


# Replace urllib3's default connection function process-wide
urllib3.util.connection.create_connection = patched_create_connection
# Custom HTTPAdapter
class CustomHTTPAdapter(HTTPAdapter):
    """requests HTTPAdapter whose pool manager always enables TCP_NODELAY."""

    _NODELAY_OPTION = (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

    def init_poolmanager(self, *args, **kwargs):
        # Build a fresh list: the old `+=` extended a caller-supplied list in
        # place, failed on an explicit None, and appended TCP_NODELAY twice
        # when it was already present.
        options = list(kwargs.get('socket_options') or [])
        if self._NODELAY_OPTION not in options:
            options.append(self._NODELAY_OPTION)
        kwargs['socket_options'] = options
        return super().init_poolmanager(*args, **kwargs)
# Create the custom Session
def create_custom_session():
    """Return a new requests Session routing both http:// and https:// through one CustomHTTPAdapter."""
    adapter = CustomHTTPAdapter()
    session = requests.Session()
    for prefix in ('https://', 'http://'):
        session.mount(prefix, adapter)
    return session
# Rate-limit and retry tuning for account authentication
AUTH_RETRY_DELAY = 60              # delay before retrying authentication, in seconds
AUTH_BACKOFF_FACTOR = 2            # backoff multiplier
AUTH_MAX_RETRIES = 3               # maximum number of retries
AUTH_CHECK_INTERVAL = 5 * 60       # health-check interval, in seconds
AUTH_RATE_LIMIT_WINDOW = 60 * 60   # rate-limit window, in seconds
AUTH_MAX_REQUESTS = 100            # maximum requests per window
class AuthManager:
    """Holds one account's credentials, token state and per-model availability.

    NOTE(review): get_next_auth_manager/update_last_successful read
    ``self.auth_managers`` and call ``reset_all_model_status``, neither of which
    is set or defined in this visible class. They look like multi-account
    manager methods and presumably rely on attributes defined elsewhere — confirm.
    """

    def __init__(self, email: str, password: str):
        self._email: str = email
        self._password: str = password
        self._max_retries: int = 3
        self._retry_delay: int = 1
        self._api_key: str = ""
        self._user_info: Dict[str, Any] = {}
        self._refresh_token: str = ""
        self._access_token: str = ""
        self._token_expiry: float = 0
        self._session: requests.Session = create_custom_session()
        self._logger: logging.Logger = logging.getLogger(__name__)
        # Every known model starts out marked available
        self.model_status = {model: True for model in MODEL_INFO.keys()}
        # Rotation state: index and date of the last account that succeeded
        self.last_successful_index = 0
        self.last_success_date = datetime.now().date()

    def get_next_auth_manager(self, model):
        """Pick an account able to serve *model*, preferring the last successful one.

        On the first call of a new day the rotation state and model flags are
        reset and the first account is returned (None if there are none).
        Otherwise the last successful account is reused when it can serve
        *model*; failing that, the other accounts are tried in round-robin
        order after it. Returns None when no account qualifies.
        """
        current_date = datetime.now().date()
        # New day: reset state and start again from the first account
        if current_date > self.last_success_date:
            self.current_index = 0
            self.last_successful_index = 0
            self.last_success_date = current_date
            self.reset_all_model_status()
            return self.auth_managers[0] if self.auth_managers else None
        managers = self.auth_managers
        if not managers:
            # Previously indexed an empty list (IndexError) / took `% 0`
            return None
        count = len(managers)
        if not 0 <= self.last_successful_index < count:
            # Stale index (e.g. the account list shrank): restart from the first account
            self.last_successful_index = 0
        # Offset 0 is the last successful account; the rest follow in rotation order
        for offset in range(count):
            index = (self.last_successful_index + offset) % count
            candidate = managers[index]
            if candidate.is_model_available(model) and candidate._should_attempt_auth():
                self.last_successful_index = index
                return candidate
        return None

    def update_last_successful(self, index):
        """Record *index* as the last successful account and stamp today's date."""
        self.last_successful_index = index
        self.last_success_date = datetime.now().date()

    # ... (other AuthManager methods unchanged)
# Public model name -> upstream provider and the model id placed in the notdiamond payload
MODEL_INFO = {
    name: {"provider": provider, "mapping": mapping}
    for name, provider, mapping in (
        ("gpt-4o-mini", "openai", "gpt-4o-mini"),
        ("gpt-4o", "openai", "gpt-4o"),
        ("gpt-4-turbo", "openai", "gpt-4-turbo-2024-04-09"),
        ("chatgpt-4o-latest", "openai", "chatgpt-4o-latest"),
        ("gemini-1.5-pro-latest", "google", "models/gemini-1.5-pro-latest"),
        ("gemini-1.5-flash-latest", "google", "models/gemini-1.5-flash-latest"),
        ("llama-3.1-70b-instruct", "togetherai", "meta.llama3-1-70b-instruct-v1:0"),
        ("llama-3.1-405b-instruct", "togetherai", "meta.llama3-1-405b-instruct-v1:0"),
        ("claude-3-5-sonnet-20241022", "anthropic", "anthropic.claude-3-5-sonnet-20241022-v2:0"),
        ("claude-3-5-haiku-20241022", "anthropic", "anthropic.claude-3-5-haiku-20241022-v1:0"),
        ("perplexity", "perplexity", "llama-3.1-sonar-large-128k-online"),
        ("mistral-large-2407", "mistral", "mistral.mistral-large-2407-v1:0"),
    )
}
def stream_notdiamond_response(response, model):
    """Convert a raw notdiamond byte stream into OpenAI-style chunk dicts.

    Each piece of decoded text is yielded as a chunk whose ``delta.content`` is
    the new text; every chunk also carries a non-standard ``context`` field with
    all text received so far. The stream ends with a chunk whose
    ``finish_reason`` is:
      - ``'stop'`` on normal completion;
      - ``'timeout'`` when a chunk arrives more than 30 s after the previous
        one (sent only if some text was received);
      - ``'error'`` if iterating the upstream response raises.

    Args:
        response: Streaming response object exposing ``iter_content``.
        model: Model name echoed into every chunk.
    """
    import codecs

    # Incremental decoding keeps a multi-byte UTF-8 character that straddles two
    # 1024-byte reads intact; decoding each read on its own raised
    # UnicodeDecodeError and the whole read was dropped.
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    full_content = ""
    last_activity = time.time()
    timeout = 30  # max seconds allowed between chunks
    try:
        for chunk in response.iter_content(chunk_size=1024):
            current_time = time.time()
            # The gap is only measured once the next chunk arrives; the blocking
            # read itself is bounded by the request's own timeout.
            if current_time - last_activity > timeout:
                logger.warning("Stream response timeout, sending partial content")
                if full_content:
                    final_chunk = create_openai_chunk('', model, 'timeout')
                    if 'choices' in final_chunk and final_chunk['choices']:
                        final_chunk['choices'][0]['context'] = full_content
                    yield final_chunk
                return
            if not chunk:
                continue
            try:
                new_content = decoder.decode(chunk)
                if not new_content:
                    # Only part of a multi-byte character so far; wait for the rest
                    last_activity = current_time
                    continue
                full_content += new_content
                chunk_data = create_openai_chunk(new_content, model)
                if 'choices' in chunk_data and chunk_data['choices']:
                    chunk_data['choices'][0]['delta']['content'] = new_content
                    chunk_data['choices'][0]['context'] = full_content
                yield chunk_data
                last_activity = current_time
            except Exception as e:
                logger.error(f"Error processing chunk: {e}")
                continue
        # Flush bytes of a character left incomplete at end of stream
        tail = decoder.decode(b'', final=True)
        if tail:
            full_content += tail
            tail_chunk = create_openai_chunk(tail, model)
            if 'choices' in tail_chunk and tail_chunk['choices']:
                tail_chunk['choices'][0]['context'] = full_content
            yield tail_chunk
        final_chunk = create_openai_chunk('', model, 'stop')
        if 'choices' in final_chunk and final_chunk['choices']:
            final_chunk['choices'][0]['context'] = full_content
        yield final_chunk
    except Exception as e:
        logger.error(f"Stream response error: {e}")
        error_chunk = create_openai_chunk('', model, 'error')
        if 'choices' in error_chunk and error_chunk['choices']:
            error_chunk['choices'][0]['context'] = full_content
        yield error_chunk
def _next_untried_account(auth_managers, tried_accounts, model_id):
    """Return the first account not in *tried_accounts* that can serve *model_id*, else None."""
    for candidate in auth_managers:
        if candidate._email in tried_accounts:
            continue
        if candidate.is_model_available(model_id) and candidate._should_attempt_auth():
            return candidate
    return None


def make_request(payload, auth_manager, model_id):
    """Send *payload* to notdiamond, rotating through accounts until one succeeds.

    Accounts come from the global ``multi_auth_manager``. Each account gets up
    to ``max_retries`` attempts. A timeout moves straight on to the next
    account; other failures wait ``retry_delay`` seconds between attempts.

    Args:
        payload: JSON body forwarded to the notdiamond endpoint.
        auth_manager: Ignored. The account is re-selected from
            ``multi_auth_manager``; the parameter is kept for call compatibility.
        model_id: Public model name used for account selection and logging.

    Returns:
        The streaming response of the first attempt that answers HTTP 200 with
        an event-stream content type, or ``None`` when every account failed.
    """
    global multi_auth_manager
    max_retries = 3
    retry_delay = 1
    request_timeout = 30  # request timeout, in seconds
    logger.info(f"尝试发送请求,模型:{model_id}")
    # Was referenced but never initialised (NameError on the first call)
    tried_accounts = set()
    while len(tried_accounts) < len(multi_auth_manager.auth_managers):
        auth_manager = multi_auth_manager.get_next_auth_manager(model_id)
        if auth_manager and auth_manager._email in tried_accounts:
            # get_next_auth_manager keeps returning the last successful account
            # while it still looks usable; the old `continue` spun forever here.
            auth_manager = _next_untried_account(
                multi_auth_manager.auth_managers, tried_accounts, model_id
            )
        if not auth_manager:
            break
        tried_accounts.add(auth_manager._email)
        logger.info(f"尝试使用账号 {auth_manager._email}")
        for attempt in range(max_retries):
            try:
                url = get_notdiamond_url()
                headers = get_notdiamond_headers(auth_manager)
                response = executor.submit(
                    requests.post,
                    url,
                    headers=headers,
                    json=payload,
                    stream=True,
                    timeout=request_timeout
                ).result(timeout=request_timeout)
                content_type = response.headers.get('Content-Type', '')
                # Accept parameters such as "; charset=utf-8" after the media type
                if response.status_code == 200 and content_type.startswith('text/event-stream'):
                    logger.info(f"请求成功,使用账号 {auth_manager._email}")
                    current_index = multi_auth_manager.auth_managers.index(auth_manager)
                    multi_auth_manager.update_last_successful(current_index)
                    return response
                logger.warning(
                    f"Unexpected response for account {auth_manager._email}: "
                    f"status={response.status_code}, content-type={content_type!r}"
                )
                # stream=True keeps the connection checked out until closed
                response.close()
            except (requests.Timeout, concurrent.futures.TimeoutError) as e:
                logger.error(f"Request timeout for account {auth_manager._email}: {e}")
                break
            except Exception as e:
                logger.error(f"Request attempt {attempt + 1} failed for account {auth_manager._email}: {e}")
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
    return None
def health_check():
    """Background loop that checks one account's token every 60 seconds.

    Accounts are visited round-robin. When the local calendar date changes,
    the rotation restarts at the first account and every account's model
    availability is reset. Never returns; errors are logged and the loop
    continues.
    """
    check_index = 0
    last_check_date = datetime.now().date()
    while True:
        try:
            if multi_auth_manager:
                current_date = datetime.now().date()
                # New day: restart the rotation and reset model availability.
                # This replaces polling for 00:00, which the 60 s sleep plus the
                # time spent checking could skip entirely.
                if current_date > last_check_date:
                    check_index = 0
                    last_check_date = current_date
                    logger.info("New day started, resetting health check index")
                    multi_auth_manager.reset_all_model_status()
                    logger.info("Reset model status for all accounts")
                managers = multi_auth_manager.auth_managers
                if managers:
                    # Wrap in case the account list shrank since the last pass
                    check_index %= len(managers)
                    auth_manager = managers[check_index]
                    # Advance first so an account whose check raises cannot stall the rotation
                    check_index = (check_index + 1) % len(managers)
                    email = auth_manager._email
                    if auth_manager._should_attempt_auth():
                        if not auth_manager.ensure_valid_token():
                            logger.warning(f"Auth token validation failed during health check for {email}")
                            auth_manager.clear_auth()
                        else:
                            logger.info(f"Health check passed for {email}")
                    else:
                        logger.info(f"Skipping health check for {email} due to rate limiting")
        except Exception as e:
            logger.error(f"Health check error: {e}")
        sleep(60)  # check one account every 60 seconds
def generate_system_fingerprint():
    """Return a random fingerprint: ``fp_`` followed by 10 lowercase hex characters."""
    random_hex = uuid.uuid4().hex
    return "fp_" + random_hex[:10]
def create_openai_chunk(content, model, finish_reason=None, usage=None):
    """Build one OpenAI-style ``chat.completion.chunk`` dict.

    ``delta`` is ``{"content": content}`` only when *content* is truthy, else
    ``{}``. *usage* is attached as a top-level ``usage`` key when not None.
    """
    delta = {"content": content} if content else {}
    choice = {
        "index": 0,
        "delta": delta,
        "logprobs": None,
        "finish_reason": finish_reason,
    }
    chunk = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": CHAT_COMPLETION_CHUNK,
        "created": int(time.time()),
        "model": model,
        "system_fingerprint": generate_system_fingerprint(),
        "choices": [choice],
    }
    if usage is not None:
        chunk["usage"] = usage
    return chunk
def count_tokens(text, model="gpt-3.5-turbo-0301"):
    """Return how many tokens *text* encodes to under *model*'s tiktoken encoding.

    Models tiktoken does not recognise (KeyError) fall back to ``cl100k_base``.
    Special-token markers such as ``<|endoftext|>`` inside *text* are counted
    as ordinary text rather than raising. *text* comes from API clients, and
    the old default made such a message crash the request with ValueError.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    # tiktoken's default disallowed_special="all" raises ValueError on special-token text
    return len(encoding.encode(text, disallowed_special=()))
def count_message_tokens(messages, model="gpt-3.5-turbo-0301"):
    """Return the total count_tokens over the ``str()`` form of each message."""
    total = 0
    for message in messages:
        total += count_tokens(str(message), model)
    return total
NOTDIAMOND_URLS = os.getenv('NOTDIAMOND_URLS', 'https://not-diamond-workers.t7-cc4.workers.dev/stream-message').split(',')
def get_notdiamond_url():
"""随机选择并返回一个 notdiamond URL。"""
return random.choice(NOTDIAMOND_URLS)
def get_notdiamond_headers(auth_manager):
    """Return the HTTP headers for a notdiamond API request, cached per JWT in headers_cache.

    The JWT is read once so the cache key and the ``authorization`` header are
    guaranteed to refer to the same token. The original called
    ``get_jwt_value()`` twice, so a value change between calls could cache a
    header under the wrong key.
    """
    jwt = auth_manager.get_jwt_value()
    cache_key = f'notdiamond_headers_{jwt}'
    try:
        return headers_cache[cache_key]
    except KeyError:
        pass
    headers = {
        'accept': 'text/event-stream',
        'accept-language': 'zh-CN,zh;q=0.9',
        'content-type': 'application/json',
        'user-agent': _USER_AGENT,
        'authorization': f'Bearer {jwt}'
    }
    headers_cache[cache_key] = headers
    return headers
def generate_stream_response(response, model, prompt_tokens):
    """Yield Server-Sent-Event lines for a streamed completion.

    Each chunk from stream_notdiamond_response is emitted as ``data: <json>``
    with a running ``usage`` block; the stream ends with ``data: [DONE]``.
    """
    completion_tokens = 0
    for chunk in stream_notdiamond_response(response, model):
        delta_text = chunk['choices'][0]['delta'].get('content', '')
        completion_tokens += count_tokens(delta_text, model)
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        chunk['usage'] = usage
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"
def handle_non_stream_response(response, model, prompt_tokens):
    """Read the whole upstream stream and return a single ``chat.completion`` JSON response.

    Args:
        response: Streaming response object exposing ``iter_content``.
        model: Model name echoed into the reply and used for token counting.
        prompt_tokens: Prompt token count reported in ``usage``.

    Raises:
        Any error raised while reading or decoding the stream is logged and re-raised.
    """
    try:
        # Decode once after collecting all bytes. Decoding each 1024-byte read
        # separately raised UnicodeDecodeError whenever a multi-byte UTF-8
        # character straddled two reads, and `+=` on str was quadratic.
        raw_body = b"".join(
            chunk for chunk in response.iter_content(chunk_size=1024) if chunk
        )
        full_content = raw_body.decode('utf-8')
        completion_tokens = count_tokens(full_content, model)
        total_tokens = prompt_tokens + completion_tokens
        response_data = {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": CHAT_COMPLETION,
            "created": int(time.time()),
            "model": model,
            "system_fingerprint": generate_system_fingerprint(),
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": full_content
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens
            }
        }
        return jsonify(response_data)
    except Exception as e:
        logger.error(f"Error processing non-stream response: {e}")
        raise
@app.route('/ai/v1/chat/completions', methods=['POST'])
@require_api_key
def handle_request():
    """Main chat-completions route: validate the request, forward it, relay the reply.

    Replies with SSE when ``stream`` is true, otherwise with one
    ``chat.completion`` JSON body. Error statuses:
      - 401: no account manager is configured;
      - 400: the body is not a JSON object or the model is unknown;
      - 403: no account can serve the model;
      - 503: the upstream call failed;
      - 500: anything unexpected.
    """
    global multi_auth_manager
    if not multi_auth_manager:
        return jsonify({'error': 'Unauthorized'}), 401
    try:
        # silent=True returns None for a missing/invalid JSON body instead of raising
        request_data = request.get_json(silent=True)
        if not isinstance(request_data, dict):
            return jsonify({
                'error': {
                    'message': 'Request body must be a JSON object',
                    'type': 'invalid_request_error'
                }
            }), 400
        model_id = request_data.get('model', '')
        # Validate up front: MODEL_INFO[model_id] below used to raise KeyError -> 500
        if not isinstance(model_id, str) or model_id not in MODEL_INFO:
            return jsonify({
                'error': {
                    'message': f"Unsupported model: {model_id!r}",
                    'type': 'invalid_request_error'
                }
            }), 400
        auth_manager = multi_auth_manager.ensure_valid_token(model_id)
        if not auth_manager:
            return jsonify({'error': 'No available accounts for this model'}), 403
        stream = request_data.get('stream', False)
        prompt_tokens = count_message_tokens(
            request_data.get('messages', []),
            model_id
        )
        payload = {
            'model': MODEL_INFO[model_id]['mapping'],
            'messages': request_data.get('messages', []),
            'temperature': request_data.get('temperature', 1),
            'max_tokens': request_data.get('max_tokens'),
            'presence_penalty': request_data.get('presence_penalty'),
            'frequency_penalty': request_data.get('frequency_penalty'),
            'top_p': request_data.get('top_p', 1),
        }
        response = make_request(payload, auth_manager, model_id)
        if response is None:
            # make_request returns None when every account failed; previously this
            # surfaced as an AttributeError deep inside the response handlers.
            return jsonify({
                'error': {
                    'message': 'Error communicating with the API',
                    'type': 'api_error',
                    'details': 'No account returned a successful response'
                }
            }), 503
        if stream:
            return Response(
                stream_with_context(generate_stream_response(response, model_id, prompt_tokens)),
                content_type=CONTENT_TYPE_EVENT_STREAM
            )
        return handle_non_stream_response(response, model_id, prompt_tokens)
    except requests.RequestException as e:
        logger.error(f"Request error: {e}")
        return jsonify({
            'error': {
                'message': 'Error communicating with the API',
                'type': 'api_error',
                'details': str(e)
            }
        }), 503
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return jsonify({
            'error': {
                'message': 'Internal Server Error',
                'type': 'server_error',
                'details': str(e)
            }
        }), 500
@app.route('/ai/v1/models', methods=['GET'])
@require_api_key
def list_models():
    """Return every MODEL_INFO entry in OpenAI ``/v1/models`` list format."""
    data = []
    for name in MODEL_INFO:
        data.append({
            "id": name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "notdiamond",
            "permission": [],
            "root": name,
            "parent": None,
        })
    return jsonify({"object": "list", "data": data})