JohnSmith9982 commited on
Commit
1620ce5
1 Parent(s): 55db78d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +19 -18
  2. presets.py +11 -2
  3. utils.py +65 -49
app.py CHANGED
@@ -42,14 +42,6 @@ else:
42
  gr.Chatbot.postprocess = postprocess
43
 
44
  with gr.Blocks(css=customCSS) as demo:
45
- gr.HTML(title)
46
- gr.HTML('''<center><a href="https://huggingface.co/spaces/JohnSmith9982/ChuanhuChatGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="复制 Space"></a>强烈建议点击上面的按钮复制一份这个Space,在你自己的Space里运行,响应更迅速、也更安全👆</center>''')
47
- with gr.Row():
48
- with gr.Column(scale=4):
49
- keyTxt = gr.Textbox(show_label=False, placeholder=f"在这里输入你的OpenAI API-key...",value=my_api_key, type="password", visible=not HIDE_MY_KEY).style(container=True)
50
- with gr.Column(scale=1):
51
- use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
52
- chatbot = gr.Chatbot() # .style(color_map=("#1D51EE", "#585A5B"))
53
  history = gr.State([])
54
  token_count = gr.State([])
55
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
@@ -57,6 +49,15 @@ with gr.Blocks(css=customCSS) as demo:
57
  FALSECONSTANT = gr.State(False)
58
  topic = gr.State("未命名对话历史记录")
59
 
 
 
 
 
 
 
 
 
 
60
  with gr.Row():
61
  with gr.Column(scale=12):
62
  user_input = gr.Textbox(show_label=False, placeholder="在这里输入").style(
@@ -69,8 +70,9 @@ with gr.Blocks(css=customCSS) as demo:
69
  delLastBtn = gr.Button("🗑️ 删除最近一条对话")
70
  reduceTokenBtn = gr.Button("♻️ 总结对话")
71
  status_display = gr.Markdown("status: ready")
72
- systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入System Prompt...",
73
- label="System prompt", value=initial_prompt).style(container=True)
 
74
  with gr.Accordion(label="加载Prompt模板", open=False):
75
  with gr.Column():
76
  with gr.Row():
@@ -101,28 +103,27 @@ with gr.Blocks(css=customCSS) as demo:
101
  #inputs, top_p, temperature, top_k, repetition_penalty
102
  with gr.Accordion("参数", open=False):
103
  top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05,
104
- interactive=True, label="Top-p (nucleus sampling)",)
105
  temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0,
106
  step=0.1, interactive=True, label="Temperature",)
107
- #top_k = gr.Slider( minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k",)
108
- #repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
109
  gr.Markdown(description)
110
 
111
 
112
- user_input.submit(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
113
  user_input.submit(reset_textbox, [], [user_input])
114
 
115
- submitBtn.click(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
116
  submitBtn.click(reset_textbox, [], [user_input])
117
 
118
  emptyBtn.click(reset_state, outputs=[chatbot, history, token_count, status_display], show_progress=True)
119
 
120
- retryBtn.click(retry, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
121
 
122
- delLastBtn.click(delete_last_conversation, [chatbot, history, token_count, use_streaming_checkbox], [
123
  chatbot, history, token_count, status_display], show_progress=True)
124
 
125
- reduceTokenBtn.click(reduce_token_size, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox], [chatbot, history, status_display, token_count], show_progress=True)
126
 
127
  saveHistoryBtn.click(save_chat_history, [
128
  saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
 
42
  gr.Chatbot.postprocess = postprocess
43
 
44
  with gr.Blocks(css=customCSS) as demo:
 
 
 
 
 
 
 
 
45
  history = gr.State([])
46
  token_count = gr.State([])
47
  promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
 
49
  FALSECONSTANT = gr.State(False)
50
  topic = gr.State("未命名对话历史记录")
51
 
52
+ gr.HTML(title)
53
+ with gr.Row():
54
+ with gr.Column():
55
+ keyTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入你的OpenAI API-key...",value=my_api_key, type="password", visible=not HIDE_MY_KEY, label="API-Key")
56
+ with gr.Column():
57
+ with gr.Row():
58
+ model_select_dropdown = gr.Dropdown(label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0])
59
+ use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
60
+ chatbot = gr.Chatbot() # .style(color_map=("#1D51EE", "#585A5B"))
61
  with gr.Row():
62
  with gr.Column(scale=12):
63
  user_input = gr.Textbox(show_label=False, placeholder="在这里输入").style(
 
70
  delLastBtn = gr.Button("🗑️ 删除最近一条对话")
71
  reduceTokenBtn = gr.Button("♻️ 总结对话")
72
  status_display = gr.Markdown("status: ready")
73
+
74
+ systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入System Prompt...", label="System prompt", value=initial_prompt).style(container=True)
75
+
76
  with gr.Accordion(label="加载Prompt模板", open=False):
77
  with gr.Column():
78
  with gr.Row():
 
103
  #inputs, top_p, temperature, top_k, repetition_penalty
104
  with gr.Accordion("参数", open=False):
105
  top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05,
106
+ interactive=True, label="Top-p (nucleus sampling)",)
107
  temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0,
108
  step=0.1, interactive=True, label="Temperature",)
109
+
 
110
  gr.Markdown(description)
111
 
112
 
113
+ user_input.submit(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
114
  user_input.submit(reset_textbox, [], [user_input])
115
 
116
+ submitBtn.click(predict, [keyTxt, systemPromptTxt, history, user_input, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
117
  submitBtn.click(reset_textbox, [], [user_input])
118
 
119
  emptyBtn.click(reset_state, outputs=[chatbot, history, token_count, status_display], show_progress=True)
120
 
121
+ retryBtn.click(retry, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
122
 
123
+ delLastBtn.click(delete_last_conversation, [chatbot, history, token_count], [
124
  chatbot, history, token_count, status_display], show_progress=True)
125
 
126
+ reduceTokenBtn.click(reduce_token_size, [keyTxt, systemPromptTxt, history, chatbot, token_count, top_p, temperature, use_streaming_checkbox, model_select_dropdown], [chatbot, history, status_display, token_count], show_progress=True)
127
 
128
  saveHistoryBtn.click(save_chat_history, [
129
  saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
presets.py CHANGED
@@ -31,9 +31,18 @@ pre code {
31
  }
32
  """
33
 
34
- standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
35
- error_retrieve_prompt = "连接超时,无法获取对话。请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
36
  summarize_prompt = "请总结以上对话,不超过100字。" # 总结对话时的 prompt
 
 
 
 
 
 
 
 
 
 
 
37
  max_token_streaming = 3500 # 流式对话时的最大 token 数
38
  timeout_streaming = 15 # 流式对话时的超时时间
39
  max_token_all = 3500 # 非流式对话时的最大 token 数
 
31
  }
32
  """
33
 
 
 
34
  summarize_prompt = "请总结以上对话,不超过100字。" # 总结对话时的 prompt
35
+ MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"] # 可选的模型
36
+
37
+ # 错误信息
38
+ standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
39
+ error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
40
+ connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
41
+ read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
42
+ proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
43
+ ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
44
+ no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
45
+
46
  max_token_streaming = 3500 # 流式对话时的最大 token 数
47
  timeout_streaming = 15 # 流式对话时的超时时间
48
  max_token_all = 3500 # 非流式对话时的最大 token 数
utils.py CHANGED
@@ -99,7 +99,7 @@ def construct_assistant(text):
99
  def construct_token_message(token, stream=False):
100
  return f"Token 计数: {token}"
101
 
102
- def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
103
  headers = {
104
  "Content-Type": "application/json",
105
  "Authorization": f"Bearer {openai_api_key}"
@@ -108,7 +108,7 @@ def get_response(openai_api_key, system_prompt, history, temperature, top_p, str
108
  history = [construct_system(system_prompt), *history]
109
 
110
  payload = {
111
- "model": "gpt-3.5-turbo",
112
  "messages": history, # [{"role": "user", "content": f"{inputs}"}],
113
  "temperature": temperature, # 1.0,
114
  "top_p": top_p, # 1.0,
@@ -124,40 +124,40 @@ def get_response(openai_api_key, system_prompt, history, temperature, top_p, str
124
  response = requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=timeout)
125
  return response
126
 
127
- def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, previous_token_count, top_p, temperature):
128
  def get_return_value():
129
- return chatbot, history, status_text, [*previous_token_count, token_counter]
130
 
131
  print("实时回答模式")
132
- token_counter = 0
133
  partial_words = ""
134
  counter = 0
135
  status_text = "开始实时传输回答……"
136
  history.append(construct_user(inputs))
 
 
137
  user_token_count = 0
138
- if len(previous_token_count) == 0:
139
  system_prompt_token_count = count_token(system_prompt)
140
  user_token_count = count_token(inputs) + system_prompt_token_count
141
  else:
142
  user_token_count = count_token(inputs)
 
143
  print(f"输入token计数: {user_token_count}")
 
144
  try:
145
- response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
146
  except requests.exceptions.ConnectTimeout:
147
- history.pop()
148
- status_text = standard_error_msg + "连接超时,无法获取对话。" + error_retrieve_prompt
149
  yield get_return_value()
150
  return
151
  except requests.exceptions.ReadTimeout:
152
- history.pop()
153
- status_text = standard_error_msg + "读取超时,无法获取对话。" + error_retrieve_prompt
154
  yield get_return_value()
155
  return
156
 
157
- chatbot.append((parse_text(inputs), ""))
158
  yield get_return_value()
159
 
160
- for chunk in response.iter_lines():
161
  if counter == 0:
162
  counter += 1
163
  continue
@@ -169,77 +169,93 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
169
  try:
170
  chunk = json.loads(chunk[6:])
171
  except json.JSONDecodeError:
 
172
  status_text = f"JSON解析错误。请重置对话。收到的内容: {chunk}"
173
  yield get_return_value()
174
- break
175
  # decode each line as response data is in bytes
176
  if chunklength > 6 and "delta" in chunk['choices'][0]:
177
  finish_reason = chunk['choices'][0]['finish_reason']
178
- status_text = construct_token_message(sum(previous_token_count)+token_counter+user_token_count, stream=True)
179
  if finish_reason == "stop":
180
- print("生成完毕")
181
  yield get_return_value()
182
  break
183
  try:
184
  partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
185
  except KeyError:
186
- status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(previous_token_count)+token_counter+user_token_count)
187
  yield get_return_value()
188
  break
189
- if token_counter == 0:
190
- history.append(construct_assistant(" " + partial_words))
191
- else:
192
- history[-1] = construct_assistant(partial_words)
193
  chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
194
- token_counter += 1
195
  yield get_return_value()
196
 
197
 
198
- def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, previous_token_count, top_p, temperature):
199
  print("一次性回答模式")
200
  history.append(construct_user(inputs))
 
 
 
201
  try:
202
- response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False)
203
  except requests.exceptions.ConnectTimeout:
204
- status_text = standard_error_msg + error_retrieve_prompt
205
- return chatbot, history, status_text, previous_token_count
 
 
 
 
 
 
206
  response = json.loads(response.text)
207
  content = response["choices"][0]["message"]["content"]
208
- history.append(construct_assistant(content))
209
  chatbot.append((parse_text(inputs), parse_text(content)))
210
  total_token_count = response["usage"]["total_tokens"]
211
- previous_token_count.append(total_token_count - sum(previous_token_count))
212
  status_text = construct_token_message(total_token_count)
213
- print("生成一次性回答完毕")
214
- return chatbot, history, status_text, previous_token_count
215
 
216
 
217
- def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=False, should_check_token_count = True): # repetition_penalty, top_k
218
  print("输入为:" +colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
 
 
 
 
 
 
 
 
 
 
219
  if stream:
220
  print("使用流式传输")
221
- iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
222
- for chatbot, history, status_text, token_count in iter:
223
- yield chatbot, history, status_text, token_count
224
  else:
225
  print("不使用流式传输")
226
- chatbot, history, status_text, token_count = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
227
- yield chatbot, history, status_text, token_count
228
- print(f"传输完毕。当前token计数为{token_count}")
229
- print("回答为:" +colorama.Fore.BLUE + f"{history[-1]['content']}" + colorama.Style.RESET_ALL)
 
230
  if stream:
231
  max_token = max_token_streaming
232
  else:
233
  max_token = max_token_all
234
- if sum(token_count) > max_token and should_check_token_count:
235
- print(f"精简token中{token_count}/{max_token}")
236
- iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=True)
237
- for chatbot, history, status_text, token_count in iter:
238
  status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
239
- yield chatbot, history, status_text, token_count
240
 
241
 
242
- def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False):
243
  print("重试中……")
244
  if len(history) == 0:
245
  yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
@@ -247,15 +263,15 @@ def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, t
247
  history.pop()
248
  inputs = history.pop()["content"]
249
  token_count.pop()
250
- iter = predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=stream)
251
  print("重试完毕")
252
  for x in iter:
253
  yield x
254
 
255
 
256
- def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=False):
257
  print("开始减少token数量……")
258
- iter = predict(openai_api_key, system_prompt, history, summarize_prompt, chatbot, token_count, top_p, temperature, stream=stream, should_check_token_count=False)
259
  for chatbot, history, status_text, previous_token_count in iter:
260
  history = history[-2:]
261
  token_count = previous_token_count[-1:]
@@ -265,7 +281,7 @@ def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_cou
265
  print("减少token数量完毕")
266
 
267
 
268
- def delete_last_conversation(chatbot, history, previous_token_count, streaming):
269
  if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
270
  print("由于包含报错信息,只删除chatbot记录")
271
  chatbot.pop()
@@ -280,7 +296,7 @@ def delete_last_conversation(chatbot, history, previous_token_count, streaming):
280
  if len(previous_token_count) > 0:
281
  print("删除了一组对话的token计数记录")
282
  previous_token_count.pop()
283
- return chatbot, history, previous_token_count, construct_token_message(sum(previous_token_count), streaming)
284
 
285
 
286
  def save_chat_history(filename, system, history, chatbot):
 
99
  def construct_token_message(token, stream=False):
100
  return f"Token 计数: {token}"
101
 
102
+ def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model):
103
  headers = {
104
  "Content-Type": "application/json",
105
  "Authorization": f"Bearer {openai_api_key}"
 
108
  history = [construct_system(system_prompt), *history]
109
 
110
  payload = {
111
+ "model": selected_model,
112
  "messages": history, # [{"role": "user", "content": f"{inputs}"}],
113
  "temperature": temperature, # 1.0,
114
  "top_p": top_p, # 1.0,
 
124
  response = requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=timeout)
125
  return response
126
 
127
+ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
128
  def get_return_value():
129
+ return chatbot, history, status_text, all_token_counts
130
 
131
  print("实时回答模式")
 
132
  partial_words = ""
133
  counter = 0
134
  status_text = "开始实时传输回答……"
135
  history.append(construct_user(inputs))
136
+ history.append(construct_assistant(""))
137
+ chatbot.append((parse_text(inputs), ""))
138
  user_token_count = 0
139
+ if len(all_token_counts) == 0:
140
  system_prompt_token_count = count_token(system_prompt)
141
  user_token_count = count_token(inputs) + system_prompt_token_count
142
  else:
143
  user_token_count = count_token(inputs)
144
+ all_token_counts.append(user_token_count)
145
  print(f"输入token计数: {user_token_count}")
146
+ yield get_return_value()
147
  try:
148
+ response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True, selected_model)
149
  except requests.exceptions.ConnectTimeout:
150
+ status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
 
151
  yield get_return_value()
152
  return
153
  except requests.exceptions.ReadTimeout:
154
+ status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
 
155
  yield get_return_value()
156
  return
157
 
 
158
  yield get_return_value()
159
 
160
+ for chunk in tqdm(response.iter_lines()):
161
  if counter == 0:
162
  counter += 1
163
  continue
 
169
  try:
170
  chunk = json.loads(chunk[6:])
171
  except json.JSONDecodeError:
172
+ print(chunk)
173
  status_text = f"JSON解析错误。请重置对话。收到的内容: {chunk}"
174
  yield get_return_value()
175
+ continue
176
  # decode each line as response data is in bytes
177
  if chunklength > 6 and "delta" in chunk['choices'][0]:
178
  finish_reason = chunk['choices'][0]['finish_reason']
179
+ status_text = construct_token_message(sum(all_token_counts), stream=True)
180
  if finish_reason == "stop":
 
181
  yield get_return_value()
182
  break
183
  try:
184
  partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
185
  except KeyError:
186
+ status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(all_token_counts))
187
  yield get_return_value()
188
  break
189
+ history[-1] = construct_assistant(partial_words)
 
 
 
190
  chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
191
+ all_token_counts[-1] += 1
192
  yield get_return_value()
193
 
194
 
195
+ def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model):
196
  print("一次性回答模式")
197
  history.append(construct_user(inputs))
198
+ history.append(construct_assistant(""))
199
+ chatbot.append((parse_text(inputs), ""))
200
+ all_token_counts.append(count_token(inputs))
201
  try:
202
+ response = get_response(openai_api_key, system_prompt, history, temperature, top_p, False, selected_model)
203
  except requests.exceptions.ConnectTimeout:
204
+ status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
205
+ return chatbot, history, status_text, all_token_counts
206
+ except requests.exceptions.ProxyError:
207
+ status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
208
+ return chatbot, history, status_text, all_token_counts
209
+ except requests.exceptions.SSLError:
210
+ status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
211
+ return chatbot, history, status_text, all_token_counts
212
  response = json.loads(response.text)
213
  content = response["choices"][0]["message"]["content"]
214
+ history[-1] = construct_assistant(content)
215
  chatbot.append((parse_text(inputs), parse_text(content)))
216
  total_token_count = response["usage"]["total_tokens"]
217
+ all_token_counts[-1] = total_token_count - sum(all_token_counts)
218
  status_text = construct_token_message(total_token_count)
219
+ return chatbot, history, status_text, all_token_counts
 
220
 
221
 
222
+ def predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, stream=False, selected_model = MODELS[0], should_check_token_count = True): # repetition_penalty, top_k
223
  print("输入为:" +colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
224
+ if len(openai_api_key) != 51:
225
+ status_text = standard_error_msg + no_apikey_msg
226
+ print(status_text)
227
+ history.append(construct_user(inputs))
228
+ history.append("")
229
+ chatbot.append((parse_text(inputs), ""))
230
+ all_token_counts.append(0)
231
+ yield chatbot, history, status_text, all_token_counts
232
+ return
233
+ yield chatbot, history, "开始生成回答……", all_token_counts
234
  if stream:
235
  print("使用流式传输")
236
+ iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
237
+ for chatbot, history, status_text, all_token_counts in iter:
238
+ yield chatbot, history, status_text, all_token_counts
239
  else:
240
  print("不使用流式传输")
241
+ chatbot, history, status_text, all_token_counts = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, all_token_counts, top_p, temperature, selected_model)
242
+ yield chatbot, history, status_text, all_token_counts
243
+ print(f"传输完毕。当前token计数为{all_token_counts}")
244
+ if len(history) > 1 and history[-1]['content'] != inputs:
245
+ print("回答为:" +colorama.Fore.BLUE + f"{history[-1]['content']}" + colorama.Style.RESET_ALL)
246
  if stream:
247
  max_token = max_token_streaming
248
  else:
249
  max_token = max_token_all
250
+ if sum(all_token_counts) > max_token and should_check_token_count:
251
+ print(f"精简token中{all_token_counts}/{max_token}")
252
+ iter = reduce_token_size(openai_api_key, system_prompt, history, chatbot, all_token_counts, top_p, temperature, stream=False, hidden=True)
253
+ for chatbot, history, status_text, all_token_counts in iter:
254
  status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
255
+ yield chatbot, history, status_text, all_token_counts
256
 
257
 
258
+ def retry(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, selected_model = MODELS[0]):
259
  print("重试中……")
260
  if len(history) == 0:
261
  yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
 
263
  history.pop()
264
  inputs = history.pop()["content"]
265
  token_count.pop()
266
+ iter = predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=stream, selected_model=selected_model)
267
  print("重试完毕")
268
  for x in iter:
269
  yield x
270
 
271
 
272
+ def reduce_token_size(openai_api_key, system_prompt, history, chatbot, token_count, top_p, temperature, stream=False, hidden=False, selected_model = MODELS[0]):
273
  print("开始减少token数量……")
274
+ iter = predict(openai_api_key, system_prompt, history, summarize_prompt, chatbot, token_count, top_p, temperature, stream=stream, selected_model = selected_model, should_check_token_count=False)
275
  for chatbot, history, status_text, previous_token_count in iter:
276
  history = history[-2:]
277
  token_count = previous_token_count[-1:]
 
281
  print("减少token数量完毕")
282
 
283
 
284
+ def delete_last_conversation(chatbot, history, previous_token_count):
285
  if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
286
  print("由于包含报错信息,只删除chatbot记录")
287
  chatbot.pop()
 
296
  if len(previous_token_count) > 0:
297
  print("删除了一组对话的token计数记录")
298
  previous_token_count.pop()
299
+ return chatbot, history, previous_token_count, construct_token_message(sum(previous_token_count))
300
 
301
 
302
  def save_chat_history(filename, system, history, chatbot):