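# OpenAI-compatible proxy for the on-demand.io chat API.
# Exposes /v1/chat/completions (stream and sync) and /v1/models, maps OpenAI
# model names to OnDemand endpoint IDs, and rotates a pool of OnDemand API
# keys, temporarily disabling any key that the upstream API rejects.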
from flask import Flask, request, Response, jsonify
import requests
import uuid
import time
import json
import threading
import logging
import os
# ====== Read the private key configured as a Hugging Face Secret ======
PRIVATE_KEY = os.environ.get("PRIVATE_KEY", "114514")
SAFE_HEADER = "X-API-KEY"
# Global access check applied to every API endpoint
def check_private_key():
if request.path in ["/", "/favicon.ico"]:
return
key = request.headers.get(SAFE_HEADER)
if not key or key != PRIVATE_KEY:
return jsonify({"error": "Unauthorized, must provide correct X-API-KEY"}), 401
# Apply API-key authentication to all routes (registration below is left commented out)
app = Flask(__name__)
#app.before_request(check_private_key)
# ========== API key pool (one key per line) ==========
ONDEMAND_APIKEYS = [
"7oGmV4VoDgkRFUoJzlgEULWLEB0OyF7H",
]
BAD_KEY_RETRY_INTERVAL = 600  # seconds
SESSION_TIMEOUT = 600  # conversation idle timeout (10 minutes)
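# A key marked bad is retried automatically once BAD_KEY_RETRY_INTERVAL has
# elapsed; a conversation idle longer than SESSION_TIMEOUT is dropped and a
# fresh OnDemand session is created on the next request.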
# ========== OnDemand model mapping ==========
MODEL_MAP = {
"gpto3-mini": "predefined-openai-gpto3-mini",
"gpt-4o": "predefined-openai-gpt4o",
"gpt-4.1": "predefined-openai-gpt4.1",
"gpt-4.1-mini": "predefined-openai-gpt4.1-mini",
"gpt-4.1-nano": "predefined-openai-gpt4.1-nano",
"gpt-4o-mini": "predefined-openai-gpt4o-mini",
"deepseek-v3": "predefined-deepseek-v3",
"deepseek-r1": "predefined-deepseek-r1",
"claude-3.7-sonnet": "predefined-claude-3.7-sonnet",
"gemini-2.0-flash": "predefined-gemini-2.0-flash",
}
DEFAULT_ONDEMAND_MODEL = "predefined-openai-gpt4o"
# ==========================================
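# KeyManager keeps one "current" key and one OnDemand session at a time.
# The pair is reused for consecutive requests until the conversation times out
# or the key is marked bad, at which point the next healthy key in the pool is
# selected and a new session is created on the following request.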
class KeyManager:
def __init__(self, key_list):
self.key_list = list(key_list)
self.lock = threading.Lock()
self.key_status = {k: {"bad": False, "bad_ts": None} for k in self.key_list}
self.idx = 0
        # The key and session currently in use
self.current_key = None
self.current_session = None
self.last_used_time = None
def display_key(self, key):
return f"{key[:6]}...{key[-4:]}"
def get(self):
with self.lock:
now = time.time()
            # Check whether the conversation has timed out
            if self.current_key and self.last_used_time and (now - self.last_used_time > SESSION_TIMEOUT):
                print(f"[Conversation timeout] last used: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(self.last_used_time))}")
                print(f"[Conversation timeout] current time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(now))}")
                print(f"[Conversation timeout] idle for over {SESSION_TIMEOUT//60} minutes, switching to a new session")
                self.current_key = None
                self.current_session = None
            # If a key is already in use, keep using it
            if self.current_key:
                if not self.key_status[self.current_key]["bad"]:
                    print(f"[Chat request] [reusing API KEY: {self.display_key(self.current_key)}] [status: OK]")
                    self.last_used_time = now
                    return self.current_key
                else:
                    # The current key has been marked bad and must be replaced
                    self.current_key = None
                    self.current_session = None
            # No current key (or it is invalid): pick a new one
            total = len(self.key_list)
            for _ in range(total):
                key = self.key_list[self.idx]
                self.idx = (self.idx + 1) % total
                s = self.key_status[key]
                if not s["bad"]:
                    print(f"[Chat request] [switching to API KEY: {self.display_key(key)}] [status: OK]")
                    self.current_key = key
                    self.current_session = None  # force a new session
                    self.last_used_time = now
                    return key
                if s["bad"] and s["bad_ts"]:
                    ago = now - s["bad_ts"]
                    if ago >= BAD_KEY_RETRY_INTERVAL:
                        print(f"[Key auto-recovery] API KEY: {self.display_key(key)} retry interval elapsed, marking as good")
                        self.key_status[key]["bad"] = False
                        self.key_status[key]["bad_ts"] = None
                        self.current_key = key
                        self.current_session = None  # force a new session
                        self.last_used_time = now
                        return key
print("【警告】全部KEY已被禁用,强制选用第一个KEY继续尝试:", self.display_key(self.key_list[0]))
for k in self.key_list:
self.key_status[k]["bad"] = False
self.key_status[k]["bad_ts"] = None
self.idx = 0
self.current_key = self.key_list[0]
self.current_session = None # 强制创建新会话
self.last_used_time = now
print(f"【对话请求】【使用API KEY: {self.display_key(self.current_key)}】【状态:强制尝试(全部异常)】")
return self.current_key
def mark_bad(self, key):
with self.lock:
if key in self.key_status and not self.key_status[key]["bad"]:
print(f"【禁用KEY】API KEY: {self.display_key(key)},接口返回无效(将在{BAD_KEY_RETRY_INTERVAL//60}分钟后自动重试)")
self.key_status[key]["bad"] = True
self.key_status[key]["bad_ts"] = time.time()
if self.current_key == key:
self.current_key = None
self.current_session = None
def get_session(self, apikey):
with self.lock:
if not self.current_session:
try:
self.current_session = create_session(apikey)
print(f"【创建新会话】SESSION ID: {self.current_session}")
except Exception as e:
print(f"【创建会话失败】错误: {str(e)}")
raise
self.last_used_time = time.time()
return self.current_session
keymgr = KeyManager(ONDEMAND_APIKEYS)
ONDEMAND_API_BASE = "https://api.on-demand.io/chat/v1"
def get_endpoint_id(openai_model):
m = str(openai_model or "").lower().replace(" ", "")
return MODEL_MAP.get(m, DEFAULT_ONDEMAND_MODEL)
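# e.g. get_endpoint_id("GPT-4o") -> "predefined-openai-gpt4o"; names that are
# not in MODEL_MAP fall back to DEFAULT_ONDEMAND_MODEL.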
def create_session(apikey, external_user_id=None, plugin_ids=None):
url = f"{ONDEMAND_API_BASE}/sessions"
payload = {"externalUserId": external_user_id or str(uuid.uuid4())}
if plugin_ids is not None:
payload["pluginIds"] = plugin_ids
headers = {"apikey": apikey, "Content-Type": "application/json"}
resp = requests.post(url, json=payload, headers=headers, timeout=20)
resp.raise_for_status()
return resp.json()["data"]["id"]
def format_openai_sse_delta(chunk):
    return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    data = request.get_json(silent=True)
    if not data or "messages" not in data:
        return jsonify({"error": "Request body is missing the messages field"}), 400
messages = data["messages"]
openai_model = data.get("model", "gpt-4o")
endpoint_id = get_endpoint_id(openai_model)
is_stream = bool(data.get("stream", False))
user_msg = None
for msg in reversed(messages):
if msg.get("role") == "user":
user_msg = msg.get("content")
break
if user_msg is None:
return jsonify({"error": "未找到用户消息"}), 400
    def with_valid_key(func):
        bad_cnt = 0
        max_retry = len(keymgr.key_list) * 2
        while bad_cnt < max_retry:
            key = keymgr.get()
            try:
                return func(key)
            except Exception as e:
                # Only rotate keys on HTTP errors that suggest the key itself is bad
                r = getattr(e, "response", None)
                if r is not None and r.status_code in (401, 403, 429, 500):
                    keymgr.mark_bad(key)
                    bad_cnt += 1
                    continue
                raise
        return jsonify({"error": "No usable API key available; add new keys or contact support"}), 500
    if is_stream:
        def generate():
            def do_once(apikey):
                # Get or create a session via the KeyManager
                sid = keymgr.get_session(apikey)
                url = f"{ONDEMAND_API_BASE}/sessions/{sid}/query"
                payload = {
                    "query": user_msg,
                    "endpointId": endpoint_id,
                    "pluginIds": [],
                    "responseMode": "stream"
                }
                headers = {"apikey": apikey, "Content-Type": "application/json", "Accept": "text/event-stream"}
                # Issue the request eagerly so HTTP errors are raised here, inside
                # with_valid_key's retry loop, before any data has been streamed.
                resp = requests.post(url, json=payload, headers=headers, stream=True, timeout=120)
                if resp.status_code != 200:
                    resp.close()
                    raise requests.HTTPError(response=resp)

                def stream_body():
                    first_chunk = True
                    with resp:
                        for line in resp.iter_lines():
                            if not line:
                                continue
                            line = line.decode("utf-8")
                            if not line.startswith("data:"):
                                continue
                            datapart = line[5:].strip()
                            if datapart == "[DONE]":
                                break
                            elif datapart.startswith("[ERROR]:"):
                                err_json = datapart[len("[ERROR]:"):].strip()
                                yield format_openai_sse_delta({"error": err_json})
                                break
                            try:
                                js = json.loads(datapart)
                            except Exception:
                                continue
                            if js.get("eventType") == "fulfillment":
                                delta = js.get("answer", "")
                                chunk = {
                                    "id": "chatcmpl-" + str(uuid.uuid4())[:8],
                                    "object": "chat.completion.chunk",
                                    "created": int(time.time()),
                                    "model": openai_model,
                                    "choices": [{
                                        "delta": {
                                            "role": "assistant",
                                            "content": delta
                                        } if first_chunk else {
                                            "content": delta
                                        },
                                        "index": 0,
                                        "finish_reason": None
                                    }]
                                }
                                yield format_openai_sse_delta(chunk)
                                first_chunk = False
                    yield "data: [DONE]\n\n"

                return stream_body()
            yield from with_valid_key(do_once)
        return Response(generate(), content_type='text/event-stream')
def nonstream(apikey):
        # Get or create a session via the KeyManager
sid = keymgr.get_session(apikey)
url = f"{ONDEMAND_API_BASE}/sessions/{sid}/query"
payload = {
"query": user_msg,
"endpointId": endpoint_id,
"pluginIds": [],
"responseMode": "sync"
}
headers = {"apikey": apikey, "Content-Type": "application/json"}
resp = requests.post(url, json=payload, headers=headers, timeout=120)
if resp.status_code != 200:
raise requests.HTTPError(response=resp)
ai_response = resp.json()["data"]["answer"]
resp_obj = {
"id": "chatcmpl-" + str(uuid.uuid4())[:8],
"object": "chat.completion",
"created": int(time.time()),
"model": openai_model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": ai_response},
"finish_reason": "stop"
}
],
"usage": {}
}
return jsonify(resp_obj)
return with_valid_key(nonstream)
@app.route("/v1/models", methods=["GET"])
def models():
model_objs = []
for mdl in MODEL_MAP.keys():
model_objs.append({
"id": mdl,
"object": "model",
"owned_by": "ondemand-proxy"
})
uniq = {m["id"]: m for m in model_objs}.values()
return jsonify({
"object": "list",
"data": list(uniq)
})
if __name__ == "__main__":
log_fmt = '[%(asctime)s] %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=log_fmt)
print("======== OnDemand KEY池数量:", len(ONDEMAND_APIKEYS), "========")
app.run(host="0.0.0.0", port=7860, debug=False)
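# Example request (assumes the server is running locally on port 7860; the
# X-API-KEY header is only enforced if the app.before_request hook above is
# uncommented):
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -H "X-API-KEY: 114514" \
#     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'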