keungliang committed
Commit 74a279e · verified · 1 Parent(s): 52d60d4

Update app.py

Files changed (1):
  1. app.py +32 -243
app.py CHANGED
@@ -1,124 +1,18 @@
- from flask import Flask, request, Response, json
- import requests
- from uuid import uuid4
- import time
- import os
- import re
- from flask_cors import CORS
-
- app = Flask(__name__)
- CORS(app)  # enable CORS support
-
- # Read the API key from an environment variable
- API_KEY = os.environ.get('API_KEY')
- if not API_KEY:
-     raise ValueError("API_KEY environment variable is required")
-
- MODEL_MAPPING = {
-     "deepseek": "deepseek/deepseek-chat",
-     "gpt-4o-mini": "openai/gpt-4o-mini",
-     "gemini-flash-1.5": "google/gemini-flash-1.5",
-     "deepseek-reasoner": "deepseek-reasoner",
-     "minimax-01": "minimax/minimax-01"
- }
-
- def verify_api_key():
-     auth_header = request.headers.get('Authorization')
-     if not auth_header:
-         return False
-     try:
-         # Support the "Bearer <token>" format
-         if auth_header.startswith('Bearer '):
-             token = auth_header.split(' ')[1]
-         else:
-             token = auth_header
-         return token == API_KEY
-     except:
-         return False
-
- def make_heck_request(question, session_id, messages, actual_model):
-     previous_question = previous_answer = None
-     if len(messages) >= 2:
-         for i in range(len(messages)-2, -1, -1):
-             if messages[i]["role"] == "user":
-                 previous_question = messages[i]["content"]
-                 if i+1 < len(messages) and messages[i+1]["role"] == "assistant":
-                     previous_answer = messages[i+1]["content"]
-                 break
-
-     payload = {
-         "model": actual_model,
-         "question": question,
-         "language": "Chinese",
-         "sessionId": session_id,
-         "previousQuestion": previous_question,
-         "previousAnswer": previous_answer
-     }
-
-     headers = {
-         "Content-Type": "application/json",
-         "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
-     }
-
-     return requests.post(
-         "https://gateway.aiapilab.com/api/ha/v1/chat",
-         json=payload,
-         headers=headers,
-         stream=True
-     )
-
- # The two helpers below emit output in segments so that Markdown
- # formatting such as line breaks and lists is preserved.
- def parse_markdown_content(content):
-     """
-     Split the content into paragraphs on blank lines, then detect list
-     items inside each paragraph (e.g. "1. ", "2. ") and split them out
-     so each piece can be emitted separately.
-     """
-     paragraph_regex = re.compile(r'(?:\r?\n){2,}')  # two or more consecutive newlines separate paragraphs
-     paragraphs = paragraph_regex.split(content)
-
-     for paragraph in paragraphs:
-         # Further split on list items ("<number>. ")
-         yield from chunk_paragraph_by_list_item(paragraph)
-
- def chunk_paragraph_by_list_item(paragraph):
-     """
-     Separate list items ("1. ", "2. ", ...) from the other text in the
-     paragraph and yield them piece by piece.
-     """
-     list_item_regex = re.compile(r'(^|\n)\s*\d+\.\s+')
-     last_index = 0
-
-     for match in list_item_regex.finditer(paragraph):
-         # Emit the text before the list marker first
-         if match.start() > last_index:
-             text_before = paragraph[last_index:match.start()]
-             if text_before.strip():
-                 yield text_before + "\n"
-         # Emit the list marker itself (e.g. "1. ")
-         yield match.group(0)
-         last_index = match.end()
-
-     # Emit any remaining text after the last list item
-     if last_index < len(paragraph):
-         yield paragraph[last_index:] + "\n"
-
-     # Append one more blank line at the end of the paragraph
-     yield "\n"
-
  def stream_response(question, session_id, messages, request_model, actual_model):
      resp = make_heck_request(question, session_id, messages, actual_model)
      is_answering = False

      for line in resp.iter_lines():
          if line:
              line = line.decode('utf-8')
              if not line.startswith('data: '):
                  continue
-
              content = line[6:].strip()
-
              if content == "[ANSWER_START]":
                  is_answering = True
-                 # Send the role-declaration chunk first
                  chunk = {
                      "id": session_id,
                      "object": "chat.completion.chunk",
@@ -131,9 +25,23 @@ def stream_response(question, session_id, messages, request_model, actual_model)
                  }
                  yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                  continue
-
              if content == "[ANSWER_DONE]":
-                 # The final closing chunk
                  chunk = {
                      "id": session_id,
                      "object": "chat.completion.chunk",
@@ -147,136 +55,17 @@ def stream_response(question, session_id, messages, request_model, actual_model)
                  }
                  yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                  break
-
-             # Once we are in the answer phase and the content is not a system marker such as [RELATE_Q...], emit it in segments
              if is_answering and content and not content.startswith("[RELATE_Q"):
-                 # Follows the "paragraph + list item" chunking logic from xgrok
-                 for sub_content in parse_markdown_content(content):
-                     # Skip segments that are entirely whitespace
-                     if not sub_content.strip():
-                         continue
-                     chunk = {
-                         "id": session_id,
-                         "object": "chat.completion.chunk",
-                         "created": int(time.time()),
-                         "model": request_model,
-                         "choices": [{
-                             "index": 0,
-                             "delta": {"content": sub_content},
-                         }]
-                     }
-                     yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
-
- def normal_response(question, session_id, messages, request_model, actual_model):
-     resp = make_heck_request(question, session_id, messages, actual_model)
-     full_content = []
-     is_answering = False
-
-     for line in resp.iter_lines():
-         if line:
-             line = line.decode('utf-8')
-             if line.startswith('data: '):
-                 content = line[6:].strip()
-                 if content == "[ANSWER_START]":
-                     is_answering = True
-                 elif content == "[ANSWER_DONE]":
-                     break
-                 elif is_answering:
-                     full_content.append(content)
-
-     response = {
-         "id": session_id,
-         "object": "chat.completion",
-         "created": int(time.time()),
-         "model": request_model,
-         "choices": [{
-             "index": 0,
-             "message": {
-                 "role": "assistant",
-                 "content": "".join(full_content)
-             },
-             "finish_reason": "stop"
-         }]
-     }
-     return response
-
- @app.route("/hf/v1/models", methods=["GET"])
- def list_models():
-     models = []
-     for model_id, _ in MODEL_MAPPING.items():
-         models.append({
-             "id": model_id,
-             "object": "model",
-             "created": int(time.time()),
-             "owned_by": "heck",
-         })
-
-     return {
-         "object": "list",
-         "data": models
-     }
-
- @app.route("/hf/v1/chat/completions", methods=["POST"])
- def chat_completions():
-     # API key validation
-     if not verify_api_key():
-         return {"error": "Invalid API Key"}, 401
-
-     data = request.json
-
-     if not data or "model" not in data:
-         return {"error": "Invalid request - missing model"}, 400
-
-     if not data.get("messages"):
-         return {"error": "Invalid request - missing messages"}, 400
-
-     # Validate the message format
-     for msg in data["messages"]:
-         if not isinstance(msg, dict):
-             return {"error": "Invalid message format"}, 400
-         if "role" not in msg or "content" not in msg:
-             return {"error": "Invalid message format"}, 400
-
-         # Check the type of content
-         if isinstance(msg["content"], list):
-             # If content is a list, make sure every element has a "text" field
-             for item in msg["content"]:
-                 if not isinstance(item, dict) or "text" not in item:
-                     return {"error": "Invalid content format"}, 400
-             # Extract all "text" fields and join them
-             msg["content"] = " ".join(item["text"] for item in msg["content"])
-         elif not isinstance(msg["content"], str):
-             return {"error": "Invalid content type"}, 400
-
-     model = MODEL_MAPPING.get(data["model"])
-     if not model:
-         return {"error": "Unsupported Model"}, 400
-
-     try:
-         question = next((msg["content"] for msg in reversed(data["messages"])
-                         if msg["role"] == "user"), None)
-     except Exception as e:
-         return {"error": "Failed to extract question"}, 400
-
-     if not question:
-         return {"error": "No user message found"}, 400
-
-     session_id = str(uuid4())
-
-     try:
-         if data.get("stream"):
-             return Response(
-                 stream_response(question, session_id, data["messages"],
-                                 data["model"], model),
-                 mimetype="text/event-stream"
-             )
-         else:
-             return normal_response(question, session_id, data["messages"],
-                                    data["model"], model)
-     except Exception as e:
-         return {"error": f"Internal server error: {str(e)}"}, 500
-
- if __name__ == "__main__":
-     # Read the port from an environment variable; default 7860 (the HF Spaces default)
-     port = int(os.environ.get("PORT", 7860))
-     app.run(host='0.0.0.0', port=port)
  def stream_response(question, session_id, messages, request_model, actual_model):
      resp = make_heck_request(question, session_id, messages, actual_model)
      is_answering = False
+     buffer = ""

      for line in resp.iter_lines():
          if line:
              line = line.decode('utf-8')
              if not line.startswith('data: '):
                  continue
+
              content = line[6:].strip()
+
              if content == "[ANSWER_START]":
                  is_answering = True

                  chunk = {
                      "id": session_id,
                      "object": "chat.completion.chunk",

                  }
                  yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                  continue
+
              if content == "[ANSWER_DONE]":
+                 # If the buffer still holds content, flush it first
+                 if buffer:
+                     chunk = {
+                         "id": session_id,
+                         "object": "chat.completion.chunk",
+                         "created": int(time.time()),
+                         "model": request_model,
+                         "choices": [{
+                             "index": 0,
+                             "delta": {"content": buffer},
+                         }]
+                     }
+                     yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
+
+                 # The closing marker chunk
                  chunk = {
                      "id": session_id,
                      "object": "chat.completion.chunk",

                  }
                  yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                  break
+

              if is_answering and content and not content.startswith("[RELATE_Q"):
+                 # Emit the content directly, without extra formatting
+                 chunk = {
+                     "id": session_id,
+                     "object": "chat.completion.chunk",
+                     "created": int(time.time()),
+                     "model": request_model,
+                     "choices": [{
+                         "index": 0,
+                         "delta": {"content": content},
+                     }]
+                 }
+                 yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
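
For illustration, here is a minimal, self-contained sketch of how the removed Markdown chunking helpers segment a reply. The helper bodies are copied from the removed lines above; the sample string and the expected output in the final comment are illustrative assumptions, not part of the commit.

import re

# Copied from the removed helpers in app.py above.
def parse_markdown_content(content):
    paragraph_regex = re.compile(r'(?:\r?\n){2,}')
    paragraphs = paragraph_regex.split(content)
    for paragraph in paragraphs:
        yield from chunk_paragraph_by_list_item(paragraph)

def chunk_paragraph_by_list_item(paragraph):
    list_item_regex = re.compile(r'(^|\n)\s*\d+\.\s+')
    last_index = 0
    for match in list_item_regex.finditer(paragraph):
        # Text before the list marker comes out first, then the marker itself
        if match.start() > last_index:
            text_before = paragraph[last_index:match.start()]
            if text_before.strip():
                yield text_before + "\n"
        yield match.group(0)
        last_index = match.end()
    if last_index < len(paragraph):
        yield paragraph[last_index:] + "\n"
    yield "\n"

# Hypothetical input, not taken from the commit:
sample = "Intro text.\n\n1. first item\n2. second item"
print(list(parse_markdown_content(sample)))
# ['Intro text.\n', '\n', '1. ', 'first item\n', '\n2. ', 'second item\n', '\n']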
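
And a minimal client sketch against the endpoints this app defines. The endpoint paths, the Bearer auth check, and the chat.completion.chunk shape come from the code above; the host, the port (7860 is only the app's default), the placeholder key "sk-test", the model choice, and the prompt are assumptions.

import json
import requests

BASE_URL = "http://localhost:7860/hf/v1"       # assumed host; 7860 is the app's default port
HEADERS = {"Authorization": "Bearer sk-test"}  # "sk-test" is a placeholder API_KEY

# List the models exposed via MODEL_MAPPING
print(requests.get(f"{BASE_URL}/models", headers=HEADERS).json())

# Streamed chat completion: the server replies as Server-Sent Events,
# one "data: {...}" line per OpenAI-style chunk
resp = requests.post(
    f"{BASE_URL}/chat/completions",
    headers=HEADERS,
    json={
        "model": "deepseek",
        "stream": True,
        "messages": [{"role": "user", "content": "Hello"}],
    },
    stream=True,
)
for raw in resp.iter_lines():
    if not raw:
        continue
    line = raw.decode("utf-8")
    if line.startswith("data: "):
        chunk = json.loads(line[6:])
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)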