Tuchuanhuhuhu committed · Commit 893df38
Parent(s): 26d3937
加入了更多诊断信息 (Added more diagnostic information)
Files changed:
- ChuanhuChatbot.py +4 -3
- requirements.txt +2 -0
- utils.py +34 -12
ChuanhuChatbot.py CHANGED
@@ -44,9 +44,10 @@ gr.Chatbot.postprocess = postprocess
 with gr.Blocks(css=customCSS) as demo:
     gr.HTML(title)
     with gr.Row():
-
-
-
+        with gr.Column(scale=4):
+            keyTxt = gr.Textbox(show_label=False, placeholder=f"在这里输入你的OpenAI API-key...",value=my_api_key, type="password", visible=not HIDE_MY_KEY).style(container=True)
+        with gr.Column(scale=1):
+            use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
     chatbot = gr.Chatbot() # .style(color_map=("#1D51EE", "#585A5B"))
     history = gr.State([])
     token_count = gr.State([])
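The layout change above nests a gr.Column(scale=4) and a gr.Column(scale=1) inside the existing gr.Row(), so the API-key textbox takes roughly four fifths of the row and the streaming checkbox the rest. A minimal, self-contained sketch of that Gradio pattern (component labels here are placeholders, not the app's actual strings):

    # Sketch: a gr.Row() with two gr.Column(scale=...) children splits the width 4:1.
    import gradio as gr

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=4):
                key_box = gr.Textbox(show_label=False, placeholder="OpenAI API key", type="password")
            with gr.Column(scale=1):
                stream_toggle = gr.Checkbox(label="Stream responses", value=True)

    # demo.launch()  # uncomment to preview the layout locally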
requirements.txt CHANGED
@@ -2,3 +2,5 @@ gradio
 mdtex2html
 pypinyin
 jieba
+socksio
+tqdm
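Of the two new dependencies, tqdm is imported in utils.py below to wrap the streaming response iterator with a console progress bar. socksio is presumably added for SOCKS proxy support: httpx, which Gradio uses internally, refuses socks5:// proxies unless socksio is installed. A rough sketch of that proxy path, with the proxy URL and API key as placeholder assumptions:

    # Sketch only (not part of the commit): socksio is what httpx needs for SOCKS proxies.
    # "pip install httpx[socks]" pulls in socksio; the proxy URL below is a placeholder.
    import httpx

    client = httpx.Client(proxies="socks5://127.0.0.1:1080")
    resp = client.get("https://api.openai.com/v1/models",
                      headers={"Authorization": "Bearer sk-..."})
    print(resp.status_code)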
utils.py CHANGED
@@ -13,6 +13,7 @@ import mdtex2html
 from pypinyin import lazy_pinyin
 from presets import *
 import jieba
+from tqdm import tqdm
 
 if TYPE_CHECKING:
     from typing import TypedDict
@@ -47,7 +48,9 @@ def postprocess(
     return y
 
 def count_words(input_str):
+    print("计算输入字数中……")
     words = jieba.lcut(input_str)
+    print("计算完成!")
     return len(words)
 
 def parse_text(text):
@@ -125,10 +128,12 @@ def get_response(openai_api_key, system_prompt, history, temperature, top_p, str
 def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, previous_token_count, top_p, temperature):
     def get_return_value():
         return chatbot, history, status_text, [*previous_token_count, token_counter]
+
+    print("实时回答模式")
     token_counter = 0
     partial_words = ""
     counter = 0
-    status_text = "
+    status_text = "开始实时传输回答……"
     history.append(construct_user(inputs))
     if len(previous_token_count) == 0:
         rough_user_token_count = count_words(inputs) + count_words(system_prompt)
@@ -144,7 +149,7 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
     chatbot.append((parse_text(inputs), ""))
     yield get_return_value()
 
-    for chunk in response.iter_lines():
+    for chunk in tqdm(response.iter_lines()):
         if counter == 0:
             counter += 1
             continue
@@ -159,6 +164,7 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
                 finish_reason = chunk['choices'][0]['finish_reason']
                 status_text = construct_token_message(sum(previous_token_count)+token_counter+rough_user_token_count, stream=True)
                 if finish_reason == "stop":
+                    print("生成完毕")
                     yield get_return_value()
                     break
                 partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
@@ -172,6 +178,7 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
 
 
 def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, previous_token_count, top_p, temperature):
+    print("一次性回答模式")
     history.append(construct_user(inputs))
     try:
         response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False)
@@ -185,22 +192,27 @@ def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, previou
     total_token_count = response["usage"]["total_tokens"]
     previous_token_count.append(total_token_count - sum(previous_token_count))
     status_text = construct_token_message(total_token_count)
+    print("生成一次性回答完毕")
     return chatbot, history, status_text, previous_token_count
 
 
 def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=False, should_check_token_count = True): # repetition_penalty, top_k
     if stream:
+        print("使用流式传输")
         iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
         for chatbot, history, status_text, token_count in iter:
             yield chatbot, history, status_text, token_count
     else:
+        print("不使用流式传输")
        chatbot, history, status_text, token_count = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
        yield chatbot, history, status_text, token_count
+    print(f"传输完毕。当前token计数为{token_count}")
     if stream:
         max_token = max_token_streaming
     else:
         max_token = max_token_all
     if sum(token_count) > max_token and should_check_token_count:
+        print(f"精简token中{token_count}/{max_token}")
         iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=True)
         for chatbot, history, status_text, token_count in iter:
             status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
@@ -208,6 +220,7 @@ def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count
 
 
 def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False):
+    print("重试中……")
     if len(history) == 0:
         yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
         return
@@ -215,11 +228,13 @@ def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, t
     inputs = history.pop()["content"]
     token_count.pop()
     iter = predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=stream)
+    print("重试完毕")
     for x in iter:
         yield x
 
 
 def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=False):
+    print("开始减少token数量……")
     iter = predict(openai_api_key, system_prompt, history, summarize_prompt, chatbot, token_count, top_p, temperature, stream=stream, should_check_token_count=False)
     for chatbot, history, status_text, previous_token_count in iter:
         history = history[-2:]
@@ -227,23 +242,29 @@ def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_cou
     if hidden:
         chatbot.pop()
     yield chatbot, history, construct_token_message(sum(token_count), stream=stream), token_count
+    print("减少token数量完毕")
 
 
 def delete_last_conversation(chatbot, history, previous_token_count, streaming):
     if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
+        print("由于包含报错信息,只删除chatbot记录")
         chatbot.pop()
         return chatbot, history
     if len(history) > 0:
+        print("删除了一组对话历史")
         history.pop()
         history.pop()
     if len(chatbot) > 0:
+        print("删除了一组chatbot对话")
         chatbot.pop()
     if len(previous_token_count) > 0:
+        print("删除了一组对话的token计数记录")
         previous_token_count.pop()
     return chatbot, history, previous_token_count, construct_token_message(sum(previous_token_count), streaming)
 
 
 def save_chat_history(filename, system, history, chatbot):
+    print("保存对话历史中……")
     if filename == "":
         return
     if not filename.endswith(".json"):
@@ -253,13 +274,16 @@ def save_chat_history(filename, system, history, chatbot):
     print(json_s)
     with open(os.path.join(HISTORY_DIR, filename), "w") as f:
         json.dump(json_s, f)
+    print("保存对话历史完毕")
 
 
 def load_chat_history(filename, system, history, chatbot):
+    print("加载对话历史中……")
     try:
         with open(os.path.join(HISTORY_DIR, filename), "r") as f:
             json_s = json.load(f)
         if type(json_s["history"]) == list:
+            print("历史记录格式为旧版,正在转换……")
             new_history = []
             for index, item in enumerate(json_s["history"]):
                 if index % 2 == 0:
@@ -267,16 +291,17 @@ def load_chat_history(filename, system, history, chatbot):
                 else:
                     new_history.append(construct_assistant(item))
             json_s["history"] = new_history
+        print("加载对话历史完毕")
         return filename, json_s["system"], json_s["history"], json_s["chatbot"]
     except FileNotFoundError:
-        print("
+        print("没有找到对话历史文件,不执行任何操作")
        return filename, system, history, chatbot
 
 def sorted_by_pinyin(list):
     return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
 
 def get_file_names(dir, plain=False, filetypes=[".json"]):
-
+    print(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
     files = []
     try:
         for type in filetypes:
@@ -292,9 +317,11 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
         return gr.Dropdown.update(choices=files)
 
 def get_history_names(plain=False):
+    print("获取历史记录文件名列表")
     return get_file_names(HISTORY_DIR, plain)
 
 def load_template(filename, mode=0):
+    print(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
     lines = []
     print("Loading template...")
     if filename.endswith(".json"):
@@ -315,24 +342,19 @@ def load_template(filename, mode=0):
         return {row[0]:row[1] for row in lines}, gr.Dropdown.update(choices=choices, value=choices[0])
 
 def get_template_names(plain=False):
+    print("获取模板文件名列表")
     return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
 
 def get_template_content(templates, selection, original_system_prompt):
+    print(f"获取模板内容,模板字典为{templates},选择为{selection},原始系统提示为{original_system_prompt}")
     try:
         return templates[selection]
     except:
         return original_system_prompt
 
 def reset_state():
+    print("重置状态")
     return [], [], [], construct_token_message(0)
 
-def compose_system(system_prompt):
-    return {"role": "system", "content": system_prompt}
-
-
-def compose_user(user_input):
-    return {"role": "user", "content": user_input}
-
-
 def reset_textbox():
     return gr.update(value='')
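The diff drops compose_system and compose_user at the end; the code shown above calls construct_user and construct_assistant instead, which presumably build the same role/content message dicts. Hypothetical equivalents, for reference only:

    # Hypothetical stand-ins (not shown in this diff) for the helpers the code calls;
    # the dropped compose_* functions returned the same shape of message dict.
    def construct_user(user_input):
        return {"role": "user", "content": user_input}

    def construct_assistant(assistant_reply):
        return {"role": "assistant", "content": assistant_reply}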