Update app.py
Browse files
app.py
CHANGED
@@ -7,32 +7,7 @@ import socketio
|
|
7 |
import requests
|
8 |
import logging
|
9 |
from threading import Event
|
10 |
-
import
|
11 |
-
from tiktoken import Encoding
|
12 |
-
|
13 |
-
def local_encoding_for_model(model_name: str):
|
14 |
-
"""
|
15 |
-
从本地加载编码文件并返回一个 Encoding 对象。
|
16 |
-
"""
|
17 |
-
local_encoding_path = '/app/cl100k_base.tiktoken'
|
18 |
-
if os.path.exists(local_encoding_path):
|
19 |
-
with open(local_encoding_path, 'rb') as f:
|
20 |
-
encoding_data = f.read() # 读取本地编码文件的字节内容
|
21 |
-
|
22 |
-
# 构造一个 Encoding 对象
|
23 |
-
return Encoding(
|
24 |
-
name="cl100k_base", # 编码的名称
|
25 |
-
pat_str="", # 正则表达式(如果有)
|
26 |
-
mergeable_ranks={}, # 合并的 rank 数据(通常是从文件或其他地方加载)
|
27 |
-
special_tokens={}, # 特殊 token 映射
|
28 |
-
explicit_n_vocab=None # 可选的词汇表大小
|
29 |
-
)
|
30 |
-
else:
|
31 |
-
raise FileNotFoundError(f"Local encoding file not found at {local_encoding_path}")
|
32 |
-
|
33 |
-
# 替换 tiktoken 的 encoding_for_model 函数
|
34 |
-
tiktoken.encoding_for_model = local_encoding_for_model
|
35 |
-
|
36 |
|
37 |
app = Flask(__name__)
|
38 |
logging.basicConfig(level=logging.INFO)
|
@@ -100,14 +75,20 @@ def normalize_content(content):
|
|
100 |
# 如果是其他类型,返回空字符串
|
101 |
return ""
|
102 |
|
103 |
-
def
|
104 |
"""
|
105 |
-
|
106 |
-
|
|
|
107 |
"""
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
@app.route('/')
|
113 |
def root():
|
@@ -145,8 +126,8 @@ def messages():
|
|
145 |
# 使用 normalize_content 递归处理 msg['content']
|
146 |
previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
|
147 |
|
148 |
-
# 动态计算输入的 token
|
149 |
-
input_tokens =
|
150 |
|
151 |
msg_id = str(uuid.uuid4())
|
152 |
response_event = Event()
|
@@ -248,8 +229,8 @@ def messages():
|
|
248 |
if sio.connected:
|
249 |
sio.disconnect()
|
250 |
|
251 |
-
# 动态计算输出的 token
|
252 |
-
output_tokens =
|
253 |
|
254 |
yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
|
255 |
yield create_event("message_delta", {
|
@@ -323,8 +304,8 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
|
|
323 |
# 等待响应完成
|
324 |
response_event.wait(timeout=30)
|
325 |
|
326 |
-
# 动态计算输出的 token
|
327 |
-
output_tokens =
|
328 |
|
329 |
# 生成完整的响应
|
330 |
full_response = {
|
|
|
7 |
import requests
|
8 |
import logging
|
9 |
from threading import Event
|
10 |
+
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
app = Flask(__name__)
|
13 |
logging.basicConfig(level=logging.INFO)
|
|
|
75 |
# 如果是其他类型,返回空字符串
|
76 |
return ""
|
77 |
|
78 |
+
def calculate_tokens(text):
|
79 |
"""
|
80 |
+
改进的 token 计算方法。
|
81 |
+
- 对于英文和有空格的文本,使用空格分词。
|
82 |
+
- 对于中文等没有空格的文本,使用字符级分词。
|
83 |
"""
|
84 |
+
# 首先判断文本是否包含大量非 ASCII 字符(如中文)
|
85 |
+
if re.search(r'[^\x00-\x7F]', text):
|
86 |
+
# 如果包含非 ASCII 字符,使用字符级分词
|
87 |
+
return len(text)
|
88 |
+
else:
|
89 |
+
# 否则使用空格分词
|
90 |
+
tokens = text.split()
|
91 |
+
return len(tokens)
|
92 |
|
93 |
@app.route('/')
|
94 |
def root():
|
|
|
126 |
# 使用 normalize_content 递归处理 msg['content']
|
127 |
previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
|
128 |
|
129 |
+
# 动态计算输入的 token 数量
|
130 |
+
input_tokens = calculate_tokens(previous_messages)
|
131 |
|
132 |
msg_id = str(uuid.uuid4())
|
133 |
response_event = Event()
|
|
|
229 |
if sio.connected:
|
230 |
sio.disconnect()
|
231 |
|
232 |
+
# 动态计算输出的 token 数量
|
233 |
+
output_tokens = calculate_tokens(''.join(response_text))
|
234 |
|
235 |
yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
|
236 |
yield create_event("message_delta", {
|
|
|
304 |
# 等待响应完成
|
305 |
response_event.wait(timeout=30)
|
306 |
|
307 |
+
# 动态计算输出的 token 数量
|
308 |
+
output_tokens = calculate_tokens(''.join(response_text))
|
309 |
|
310 |
# 生成完整的响应
|
311 |
full_response = {
|