优化chatgpt对话的截断策略
Browse files- crazy_functions/谷歌检索小助手.py +2 -1
- request_llm/bridge_chatgpt.py +10 -7
- toolbox.py +46 -0
crazy_functions/谷歌检索小助手.py
CHANGED
@@ -98,7 +98,8 @@ def 谷歌检索小助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
|
|
98 |
history.extend([ "第一批", gpt_say ])
|
99 |
meta_paper_info_list = meta_paper_info_list[10:]
|
100 |
|
101 |
-
chatbot.append(["状态?",
|
|
|
102 |
msg = '正常'
|
103 |
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
104 |
res = write_results_to_file(history)
|
|
|
98 |
history.extend([ "第一批", gpt_say ])
|
99 |
meta_paper_info_list = meta_paper_info_list[10:]
|
100 |
|
101 |
+
chatbot.append(["状态?",
|
102 |
+
"已经全部完成,您可以试试让AI写一个Related Works,例如您可以继续输入Write a \"Related Works\" section about \"你搜索的研究领域\" for me."])
|
103 |
msg = '正常'
|
104 |
yield from update_ui(chatbot=chatbot, history=history, msg=msg) # 刷新界面
|
105 |
res = write_results_to_file(history)
|
request_llm/bridge_chatgpt.py
CHANGED
@@ -21,7 +21,7 @@ import importlib
|
|
21 |
|
22 |
# config_private.py放自己的秘密如API和代理网址
|
23 |
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
24 |
-
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys
|
25 |
proxies, API_KEY, TIMEOUT_SECONDS, MAX_RETRY = \
|
26 |
get_conf('proxies', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY')
|
27 |
|
@@ -145,7 +145,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|
145 |
yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
|
146 |
return
|
147 |
|
148 |
-
history.append(inputs); history.append("
|
149 |
|
150 |
retry = 0
|
151 |
while True:
|
@@ -198,14 +198,17 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|
198 |
chunk_decoded = chunk.decode()
|
199 |
error_msg = chunk_decoded
|
200 |
if "reduce the length" in error_msg:
|
201 |
-
|
202 |
-
history = []
|
|
|
|
|
|
|
203 |
elif "does not exist" in error_msg:
|
204 |
-
chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist.
|
205 |
elif "Incorrect API key" in error_msg:
|
206 |
-
chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY
|
207 |
elif "exceeded your current quota" in error_msg:
|
208 |
-
chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI
|
209 |
elif "bad forward key" in error_msg:
|
210 |
chatbot[-1] = (chatbot[-1][0], "[Local Message] Bad forward key. API2D账户额度不足.")
|
211 |
elif "Not enough point" in error_msg:
|
|
|
21 |
|
22 |
# config_private.py放自己的秘密如API和代理网址
|
23 |
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
|
24 |
+
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history
|
25 |
proxies, API_KEY, TIMEOUT_SECONDS, MAX_RETRY = \
|
26 |
get_conf('proxies', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY')
|
27 |
|
|
|
145 |
yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
|
146 |
return
|
147 |
|
148 |
+
history.append(inputs); history.append("")
|
149 |
|
150 |
retry = 0
|
151 |
while True:
|
|
|
198 |
chunk_decoded = chunk.decode()
|
199 |
error_msg = chunk_decoded
|
200 |
if "reduce the length" in error_msg:
|
201 |
+
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
202 |
+
history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
|
203 |
+
max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])//2) # history至少释放二分之一
|
204 |
+
chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
|
205 |
+
# history = [] # 清除历史
|
206 |
elif "does not exist" in error_msg:
|
207 |
+
chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在, 或者您没有获得体验资格.")
|
208 |
elif "Incorrect API key" in error_msg:
|
209 |
+
chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由, 拒绝服务.")
|
210 |
elif "exceeded your current quota" in error_msg:
|
211 |
+
chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由, 拒绝服务.")
|
212 |
elif "bad forward key" in error_msg:
|
213 |
chatbot[-1] = (chatbot[-1][0], "[Local Message] Bad forward key. API2D账户额度不足.")
|
214 |
elif "Not enough point" in error_msg:
|
toolbox.py
CHANGED
@@ -551,3 +551,49 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
|
|
551 |
return {"message": f"Gradio is running at: {custom_path}"}
|
552 |
app = gr.mount_gradio_app(app, demo, path=custom_path)
|
553 |
uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
551 |
return {"message": f"Gradio is running at: {custom_path}"}
|
552 |
app = gr.mount_gradio_app(app, demo, path=custom_path)
|
553 |
uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
|
554 |
+
|
555 |
+
|
556 |
+
def clip_history(inputs, history, tokenizer, max_token_limit):
|
557 |
+
"""
|
558 |
+
reduce the length of input/history by clipping.
|
559 |
+
this function search for the longest entries to clip, little by little,
|
560 |
+
until the number of token of input/history is reduced under threshold.
|
561 |
+
通过剪辑来缩短输入/历史记录的长度。
|
562 |
+
此函数逐渐地搜索最长的条目进行剪辑,
|
563 |
+
直到输入/历史记录的标记数量降低到阈值以下。
|
564 |
+
"""
|
565 |
+
import numpy as np
|
566 |
+
from request_llm.bridge_all import model_info
|
567 |
+
def get_token_num(txt):
|
568 |
+
return len(tokenizer.encode(txt, disallowed_special=()))
|
569 |
+
input_token_num = get_token_num(inputs)
|
570 |
+
if input_token_num < max_token_limit * 3 / 4:
|
571 |
+
# 当输入部分的token占比小于限制的3/4时,在裁剪时把input的余量留出来
|
572 |
+
max_token_limit = max_token_limit - input_token_num
|
573 |
+
if max_token_limit < 128:
|
574 |
+
# 余量太小了,直接清除历史
|
575 |
+
history = []
|
576 |
+
return history
|
577 |
+
else:
|
578 |
+
# 当输入部分的token占比 > 限制的3/4时,直接清除历史
|
579 |
+
history = []
|
580 |
+
return history
|
581 |
+
|
582 |
+
everything = ['']
|
583 |
+
everything.extend(history)
|
584 |
+
n_token = get_token_num('\n'.join(everything))
|
585 |
+
everything_token = [get_token_num(e) for e in everything]
|
586 |
+
|
587 |
+
# 截断时的颗粒度
|
588 |
+
delta = max(everything_token) // 16
|
589 |
+
|
590 |
+
while n_token > max_token_limit:
|
591 |
+
where = np.argmax(everything_token)
|
592 |
+
encoded = tokenizer.encode(everything[where], disallowed_special=())
|
593 |
+
clipped_encoded = encoded[:len(encoded)-delta]
|
594 |
+
everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
|
595 |
+
everything_token[where] = get_token_num(everything[where])
|
596 |
+
n_token = get_token_num('\n'.join(everything))
|
597 |
+
|
598 |
+
history = everything[1:]
|
599 |
+
return history
|