Spaces · keungliang committed: Update app.py
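This commit strips app.py from a roughly 282-line OpenAI-compatible Flask proxy down to a single 71-line stream_response generator. The Markdown segmentation helpers (parse_markdown_content, chunk_paragraph_by_list_item) are removed in favor of forwarding each upstream content chunk unchanged; a buffer variable is added and flushed before the closing chunk when [ANSWER_DONE] arrives; and the non-streaming normal_response path, the /hf/v1/models and /hf/v1/chat/completions routes, and the __main__ launcher are all deleted.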
app.py CHANGED
@@ -1,124 +1,18 @@
-from flask import Flask, request, Response, json
-import requests
-from uuid import uuid4
-import time
-import os
-import re
-from flask_cors import CORS
-
-app = Flask(__name__)
-CORS(app)  # Enable CORS support
-
-# Read the API key from an environment variable
-API_KEY = os.environ.get('API_KEY')
-if not API_KEY:
-    raise ValueError("API_KEY environment variable is required")
-
-MODEL_MAPPING = {
-    "deepseek": "deepseek/deepseek-chat",
-    "gpt-4o-mini": "openai/gpt-4o-mini",
-    "gemini-flash-1.5": "google/gemini-flash-1.5",
-    "deepseek-reasoner": "deepseek-reasoner",
-    "minimax-01": "minimax/minimax-01"
-}
-
-def verify_api_key():
-    auth_header = request.headers.get('Authorization')
-    if not auth_header:
-        return False
-    try:
-        # Support the Bearer token format
-        if auth_header.startswith('Bearer '):
-            token = auth_header.split(' ')[1]
-        else:
-            token = auth_header
-        return token == API_KEY
-    except:
-        return False
-
-def make_heck_request(question, session_id, messages, actual_model):
-    previous_question = previous_answer = None
-    if len(messages) >= 2:
-        for i in range(len(messages)-2, -1, -1):
-            if messages[i]["role"] == "user":
-                previous_question = messages[i]["content"]
-                if i+1 < len(messages) and messages[i+1]["role"] == "assistant":
-                    previous_answer = messages[i+1]["content"]
-                break
-
-    payload = {
-        "model": actual_model,
-        "question": question,
-        "language": "Chinese",
-        "sessionId": session_id,
-        "previousQuestion": previous_question,
-        "previousAnswer": previous_answer
-    }
-
-    headers = {
-        "Content-Type": "application/json",
-        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
-    }
-
-    return requests.post(
-        "https://gateway.aiapilab.com/api/ha/v1/chat",
-        json=payload,
-        headers=headers,
-        stream=True
-    )
-
-# The two helpers below emit output in segments to preserve line breaks, lists, and other Markdown formatting
-def parse_markdown_content(content):
-    """
-    Split into paragraphs on blank lines first, then detect list items (such as "1. ", "2. ")
-    inside each paragraph and split them out so they can be emitted segment by segment.
-    """
-    paragraph_regex = re.compile(r'(?:\r?\n){2,}')  # two or more consecutive newlines act as the separator
-    paragraphs = paragraph_regex.split(content)
-
-    for paragraph in paragraphs:
-        # Split further on list items (digit plus dot)
-        yield from chunk_paragraph_by_list_item(paragraph)
-
-def chunk_paragraph_by_list_item(paragraph):
-    """
-    Separate the list items ("1. ", "2. ", etc.) in a paragraph from the rest of the text and yield them piece by piece.
-    """
-    list_item_regex = re.compile(r'(^|\n)\s*\d+\.\s+')
-    last_index = 0
-
-    for match in list_item_regex.finditer(paragraph):
-        # Emit the text before the list marker first
-        if match.start() > last_index:
-            text_before = paragraph[last_index:match.start()]
-            if text_before.strip():
-                yield text_before + "\n"
-        # Emit the list marker itself (e.g. "1. ")
-        yield match.group(0)
-        last_index = match.end()
-
-    # Emit any remaining trailing text
-    if last_index < len(paragraph):
-        yield paragraph[last_index:] + "\n"
-
-    # Append one more blank line at the end of the paragraph
-    yield "\n"
-
 def stream_response(question, session_id, messages, request_model, actual_model):
     resp = make_heck_request(question, session_id, messages, actual_model)
     is_answering = False
+    buffer = ""
 
     for line in resp.iter_lines():
         if line:
             line = line.decode('utf-8')
             if not line.startswith('data: '):
                 continue
 
             content = line[6:].strip()
 
             if content == "[ANSWER_START]":
                 is_answering = True
-                # Send the role declaration chunk first
                 chunk = {
                     "id": session_id,
                     "object": "chat.completion.chunk",
@@ -131,9 +25,23 @@ def stream_response(question, session_id, messages, request_model, actual_model)
                 }
                 yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                 continue
 
             if content == "[ANSWER_DONE]":
+                # If the buffer still holds content, flush it first
+                if buffer:
+                    chunk = {
+                        "id": session_id,
+                        "object": "chat.completion.chunk",
+                        "created": int(time.time()),
+                        "model": request_model,
+                        "choices": [{
+                            "index": 0,
+                            "delta": {"content": buffer},
+                        }]
+                    }
+                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
+
+                # End-of-stream marker
                 chunk = {
                     "id": session_id,
                     "object": "chat.completion.chunk",
@@ -147,136 +55,17 @@ def stream_response(question, session_id, messages, request_model, actual_model)
                 }
                 yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                 break
 
-            # Once we are sure we are in the answering phase and the content is not a system marker such as [RELATE_Q...], emit it in segments
             if is_answering and content and not content.startswith("[RELATE_Q"):
-                for sub_content in parse_markdown_content(content):
-                    chunk = {
-                        "id": session_id,
-                        "object": "chat.completion.chunk",
-                        "created": int(time.time()),
-                        "model": request_model,
-                        "choices": [{
-                            "index": 0,
-                            "delta": {"content": sub_content},
-                        }]
-                    }
-                    yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
+                # Emit the content directly, without any extra formatting pass
+                chunk = {
+                    "id": session_id,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": request_model,
+                    "choices": [{
+                        "index": 0,
+                        "delta": {"content": content},
+                    }]
+                }
+                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
-
-def normal_response(question, session_id, messages, request_model, actual_model):
-    resp = make_heck_request(question, session_id, messages, actual_model)
-    full_content = []
-    is_answering = False
-
-    for line in resp.iter_lines():
-        if line:
-            line = line.decode('utf-8')
-            if line.startswith('data: '):
-                content = line[6:].strip()
-                if content == "[ANSWER_START]":
-                    is_answering = True
-                elif content == "[ANSWER_DONE]":
-                    break
-                elif is_answering:
-                    full_content.append(content)
-
-    response = {
-        "id": session_id,
-        "object": "chat.completion",
-        "created": int(time.time()),
-        "model": request_model,
-        "choices": [{
-            "index": 0,
-            "message": {
-                "role": "assistant",
-                "content": "".join(full_content)
-            },
-            "finish_reason": "stop"
-        }]
-    }
-    return response
-
-@app.route("/hf/v1/models", methods=["GET"])
-def list_models():
-    models = []
-    for model_id, _ in MODEL_MAPPING.items():
-        models.append({
-            "id": model_id,
-            "object": "model",
-            "created": int(time.time()),
-            "owned_by": "heck",
-        })
-
-    return {
-        "object": "list",
-        "data": models
-    }
-
-@app.route("/hf/v1/chat/completions", methods=["POST"])
-def chat_completions():
-    # Validate the API key
-    if not verify_api_key():
-        return {"error": "Invalid API Key"}, 401
-
-    data = request.json
-
-    if not data or "model" not in data:
-        return {"error": "Invalid request - missing model"}, 400
-
-    if not data.get("messages"):
-        return {"error": "Invalid request - missing messages"}, 400
-
-    # Validate the message format
-    for msg in data["messages"]:
-        if not isinstance(msg, dict):
-            return {"error": "Invalid message format"}, 400
-        if "role" not in msg or "content" not in msg:
-            return {"error": "Invalid message format"}, 400
-
-        # Check the content type
-        if isinstance(msg["content"], list):
-            # If content is a list, make sure each element has a text field
-            for item in msg["content"]:
-                if not isinstance(item, dict) or "text" not in item:
-                    return {"error": "Invalid content format"}, 400
-            # Extract and join all text fields
-            msg["content"] = " ".join(item["text"] for item in msg["content"])
-        elif not isinstance(msg["content"], str):
-            return {"error": "Invalid content type"}, 400
-
-    model = MODEL_MAPPING.get(data["model"])
-    if not model:
-        return {"error": "Unsupported Model"}, 400
-
-    try:
-        question = next((msg["content"] for msg in reversed(data["messages"])
-                         if msg["role"] == "user"), None)
-    except Exception as e:
-        return {"error": "Failed to extract question"}, 400
-
-    if not question:
-        return {"error": "No user message found"}, 400
-
-    session_id = str(uuid4())
-
-    try:
-        if data.get("stream"):
-            return Response(
-                stream_response(question, session_id, data["messages"],
-                                data["model"], model),
-                mimetype="text/event-stream"
-            )
-        else:
-            return normal_response(question, session_id, data["messages"],
-                                   data["model"], model)
-    except Exception as e:
-        return {"error": f"Internal server error: {str(e)}"}, 500
-
-if __name__ == "__main__":
-    # Get the port from an environment variable, defaulting to 7860 (the HF Spaces default)
-    port = int(os.environ.get("PORT", 7860))
-    app.run(host='0.0.0.0', port=port)
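Both the old and the new stream_response parse the same upstream SSE protocol: [ANSWER_START] opens the answer, [ANSWER_DONE] closes it, and [RELATE_Q...] lines (apparently related-question suggestions) are skipped. Here is a minimal standalone sketch of that loop; the helper name extract_answer is mine, and the demo lines are canned, not captured from the real gateway:

def extract_answer(sse_lines):
    # Mirrors the marker handling in stream_response / normal_response
    is_answering = False
    for raw in sse_lines:
        if not raw.startswith("data: "):
            continue
        content = raw[6:].strip()
        if content == "[ANSWER_START]":
            is_answering = True
        elif content == "[ANSWER_DONE]":
            break
        elif is_answering and content and not content.startswith("[RELATE_Q"):
            yield content

demo = [
    "data: [ANSWER_START]",
    "data: Hello",
    "data: world",
    "data: [RELATE_Q1]Anything else?",
    "data: [ANSWER_DONE]",
]
print(list(extract_answer(demo)))  # -> ['Hello', 'world']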
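Finally, a sketch of how a client consumed the /hf/v1/chat/completions route that this commit removes. The base URL is an assumption taken from the deleted __main__ block (HF Spaces' default port 7860); everything else follows the removed route's contract:

import json
import os

import requests

BASE_URL = "http://localhost:7860"  # assumed; matches the removed default port

resp = requests.post(
    f"{BASE_URL}/hf/v1/chat/completions",
    headers={"Authorization": f"Bearer {os.environ['API_KEY']}"},
    json={
        "model": "deepseek",  # any key from MODEL_MAPPING
        "stream": True,
        "messages": [{"role": "user", "content": "Hello"}],
    },
    stream=True,
)

# Each SSE line is "data: <chat.completion.chunk JSON>"
for raw in resp.iter_lines():
    if not raw:
        continue
    line = raw.decode("utf-8")
    if line.startswith("data: "):
        chunk = json.loads(line[6:])
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)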