Tuchuanhuhuhu committed on
Commit
14e3e6a
1 Parent(s): 55e027c

使用jieba估计实时传输模式的token计数

Browse files
Files changed (3) hide show
  1. presets.py +1 -1
  2. requirements.txt +1 -0
  3. utils.py +11 -2
presets.py CHANGED
@@ -34,7 +34,7 @@ pre code {
34
  standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
35
  error_retrieve_prompt = "连接超时,无法获取对话。请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
36
  summarize_prompt = "请总结以上对话,不超过100字。" # 总结对话时的 prompt
37
- max_token_streaming = 2000 # 流式对话时的最大 token 数
38
  timeout_streaming = 5 # 流式对话时的超时时间
39
  max_token_all = 3500 # 非流式对话时的最大 token 数
40
  timeout_all = 200 # 非流式对话时的超时时间
 
34
  standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
35
  error_retrieve_prompt = "连接超时,无法获取对话。请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
36
  summarize_prompt = "请总结以上对话,不超过100字。" # 总结对话时的 prompt
37
+ max_token_streaming = 400 # 流式对话时的最大 token 数
38
  timeout_streaming = 5 # 流式对话时的超时时间
39
  max_token_all = 3500 # 非流式对话时的最大 token 数
40
  timeout_all = 200 # 非流式对话时的超时时间
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gradio
2
  mdtex2html
3
  pypinyin
 
 
1
  gradio
2
  mdtex2html
3
  pypinyin
4
+ jieba
utils.py CHANGED
@@ -12,6 +12,7 @@ import csv
12
  import mdtex2html
13
  from pypinyin import lazy_pinyin
14
  from presets import *
 
15
 
16
  if TYPE_CHECKING:
17
  from typing import TypedDict
@@ -45,6 +46,10 @@ def postprocess(
45
  )
46
  return y
47
 
 
 
 
 
48
  def parse_text(text):
49
  lines = text.split("\n")
50
  lines = [line for line in lines if line != ""]
@@ -89,7 +94,7 @@ def construct_assistant(text):
89
  return construct_text("assistant", text)
90
 
91
  def construct_token_message(token, stream=False):
92
- extra = "【仅包含回答的计数】 " if stream else ""
93
  return f"{extra}Token 计数: {token}"
94
 
95
  def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
@@ -125,6 +130,10 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
125
  counter = 0
126
  status_text = "OK"
127
  history.append(construct_user(inputs))
 
 
 
 
128
  try:
129
  response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
130
  except requests.exceptions.ConnectTimeout:
@@ -148,7 +157,7 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
148
  # decode each line as response data is in bytes
149
  if chunklength > 6 and "delta" in chunk['choices'][0]:
150
  finish_reason = chunk['choices'][0]['finish_reason']
151
- status_text = construct_token_message(sum(previous_token_count)+token_counter, stream=True)
152
  if finish_reason == "stop":
153
  yield get_return_value()
154
  break
 
12
  import mdtex2html
13
  from pypinyin import lazy_pinyin
14
  from presets import *
15
+ import jieba
16
 
17
  if TYPE_CHECKING:
18
  from typing import TypedDict
 
46
  )
47
  return y
48
 
49
def count_words(input_str):
    """Roughly estimate the token count of *input_str* by segmenting it
    with jieba and counting the resulting words.

    NOTE(review): word count is only an approximation of the model's
    actual token count — confirm acceptable for the streaming display.
    """
    return len(jieba.lcut(input_str))
52
+
53
  def parse_text(text):
54
  lines = text.split("\n")
55
  lines = [line for line in lines if line != ""]
 
94
  return construct_text("assistant", text)
95
 
96
def construct_token_message(token, stream=False):
    """Build the token-count status line shown to the user.

    In streaming mode the count is only an estimate, so a disclaimer
    prefix is prepended to the message.
    """
    if stream:
        extra = "【粗略计数(因为实时传输回答)】 "
    else:
        extra = ""
    return f"{extra}Token 计数: {token}"
99
 
100
  def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
 
130
  counter = 0
131
  status_text = "OK"
132
  history.append(construct_user(inputs))
133
+ if len(previous_token_count) == 0:
134
+ rough_user_token_count = count_words(inputs) + count_words(system_prompt)
135
+ else:
136
+ rough_user_token_count = count_words(inputs)
137
  try:
138
  response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
139
  except requests.exceptions.ConnectTimeout:
 
157
  # decode each line as response data is in bytes
158
  if chunklength > 6 and "delta" in chunk['choices'][0]:
159
  finish_reason = chunk['choices'][0]['finish_reason']
160
+ status_text = construct_token_message(sum(previous_token_count)+token_counter+rough_user_token_count, stream=True)
161
  if finish_reason == "stop":
162
  yield get_return_value()
163
  break