Tuchuanhuhuhu committed
Commit 079c7eb · 1 Parent(s): 4b845f9

Improved the display of online search results

Files changed (1):
  1. chat_func.py +41 -12
chat_func.py CHANGED

@@ -6,6 +6,7 @@ import logging
 import json
 import os
 import requests
+import urllib3
 
 from tqdm import tqdm
 import colorama
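The newly imported urllib3 is used further down in this diff to turn each search-result URL into a short display name. urllib3.util.parse_url is part of urllib3's public API and returns a Url object whose .host attribute is the bare domain; a minimal sketch (the URL is illustrative):

import urllib3

# parse_url returns a Url object; .host is the domain without scheme or path
url = "https://example.com/some/article?id=42"  # illustrative URL
print(urllib3.util.parse_url(url).host)  # prints: example.com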
@@ -99,6 +100,7 @@ def stream_predict(
     top_p,
     temperature,
     selected_model,
+    fake_input=None
 ):
     def get_return_value():
         return chatbot, history, status_text, all_token_counts
@@ -109,7 +111,10 @@ def stream_predict(
     status_text = "开始实时传输回答……"
     history.append(construct_user(inputs))
     history.append(construct_assistant(""))
-    chatbot.append((parse_text(inputs), ""))
+    if fake_input:
+        chatbot.append((parse_text(fake_input), ""))
+    else:
+        chatbot.append((parse_text(inputs), ""))
     user_token_count = 0
     if len(all_token_counts) == 0:
         system_prompt_token_count = count_token(construct_system(system_prompt))
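The new fake_input parameter separates what the model receives from what the user sees: with web search on, inputs is the search-augmented prompt sent to the API, while fake_input carries the user's original question for the chat window. A minimal sketch of the pattern, using a hypothetical append_turn helper in place of the repo's construct_user/parse_text machinery:

def append_turn(chatbot, history, inputs, fake_input=None):
    # the model always receives the full (possibly augmented) prompt...
    history.append({"role": "user", "content": inputs})
    # ...but the chat window shows the original question when one is given
    chatbot.append((fake_input if fake_input else inputs, ""))

chatbot, history = [], []
append_turn(chatbot, history, "augmented prompt with web results", fake_input="original question")
print(chatbot[0][0])          # prints: original question
print(history[0]["content"])  # prints: augmented prompt with web results

The matching changes below update only the answer half of the tuple (chatbot[-1][0] is preserved), so the displayed question is never overwritten by the augmented prompt; and since old_inputs stays an empty, falsy string when web search is off, passing fake_input=old_inputs later in predict works for both paths.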
@@ -184,7 +189,7 @@ def stream_predict(
                     yield get_return_value()
                     break
                 history[-1] = construct_assistant(partial_words)
-                chatbot[-1] = (parse_text(inputs), parse_text(partial_words))
+                chatbot[-1] = (chatbot[-1][0], parse_text(partial_words))
                 all_token_counts[-1] += 1
                 yield get_return_value()
 
@@ -199,11 +204,15 @@ def predict_all(
     top_p,
     temperature,
     selected_model,
+    fake_input=None
 ):
     logging.info("一次性回答模式")
     history.append(construct_user(inputs))
     history.append(construct_assistant(""))
-    chatbot.append((parse_text(inputs), ""))
+    if fake_input:
+        chatbot.append((parse_text(fake_input), ""))
+    else:
+        chatbot.append((parse_text(inputs), ""))
     all_token_counts.append(count_token(construct_user(inputs)))
     try:
         response = get_response(
@@ -229,7 +238,7 @@ def predict_all(
         response = json.loads(response.text)
         content = response["choices"][0]["message"]["content"]
         history[-1] = construct_assistant(content)
-        chatbot[-1] = (parse_text(inputs), parse_text(content))
+        chatbot[-1] = (chatbot[-1][0], parse_text(content))
         total_token_count = response["usage"]["total_tokens"]
         all_token_counts[-1] = total_token_count - sum(all_token_counts)
         status_text = construct_token_message(total_token_count)
@@ -247,7 +256,7 @@ def predict(
     temperature,
     stream=False,
     selected_model=MODELS[0],
-    use_websearch_checkbox=False,
+    use_websearch=False,
     files = None,
     should_check_token_count=True,
 ):  # repetition_penalty, top_k
@@ -262,18 +271,24 @@ def predict(
         history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot)
         yield chatbot, history, status_text, all_token_counts
         return
-    if use_websearch_checkbox:
-        results = ddg(inputs, max_results=3)
+
+    old_inputs = ""
+    link_references = []
+    if use_websearch:
+        search_results = ddg(inputs, max_results=5)
+        old_inputs = inputs
         web_results = []
-        for idx, result in enumerate(results):
+        for idx, result in enumerate(search_results):
             logging.info(f"搜索结果{idx + 1}:{result}")
+            domain_name = urllib3.util.parse_url(result["href"]).host
             web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
-        web_results = "\n\n".join(web_results)
+            link_references.append(f"[{idx+1}]: [{domain_name}]({result['href']})")
         inputs = (
             replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
             .replace("{query}", inputs)
-            .replace("{web_results}", web_results)
+            .replace("{web_results}", "\n\n".join(web_results))
         )
+
     if len(openai_api_key) != 51:
         status_text = standard_error_msg + no_apikey_msg
         logging.info(status_text)
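The search itself goes through ddg from the duckduckgo_search package (its pre-3.x helper), assumed here to return a list of dicts with "title", "body", and "href" keys. A sketch of how the numbered snippets and the Markdown reference lines are built from results of that shape:

import urllib3
from duckduckgo_search import ddg  # pre-3.x duckduckgo_search API

search_results = ddg("example query", max_results=5)
web_results, link_references = [], []
for idx, result in enumerate(search_results):
    # label each link with its bare domain
    domain_name = urllib3.util.parse_url(result["href"]).host
    # numbered snippet substituted into the prompt template
    web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
    # Markdown source link appended under the final answer
    link_references.append(f"[{idx+1}]: [{domain_name}]({result['href']})")

prompt_block = "\n\n".join(web_results)  # fills {web_results} in the template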
@@ -286,8 +301,9 @@ def predict(
         history[-2] = construct_user(inputs)
         yield chatbot, history, status_text, all_token_counts
         return
-    if stream:
-        yield chatbot, history, "开始生成回答……", all_token_counts
+
+    yield chatbot, history, "开始生成回答……", all_token_counts
+
     if stream:
         logging.info("使用流式传输")
         iter = stream_predict(
@@ -300,6 +316,7 @@ def predict(
             top_p,
             temperature,
             selected_model,
+            fake_input=old_inputs
         )
         for chatbot, history, status_text, all_token_counts in iter:
             yield chatbot, history, status_text, all_token_counts
@@ -315,8 +332,10 @@ def predict(
             top_p,
             temperature,
             selected_model,
+            fake_input=old_inputs
         )
         yield chatbot, history, status_text, all_token_counts
+
     logging.info(f"传输完毕。当前token计数为{all_token_counts}")
     if len(history) > 1 and history[-1]["content"] != inputs:
         logging.info(
@@ -325,10 +344,20 @@ def predict(
             + f"{history[-1]['content']}"
            + colorama.Style.RESET_ALL
         )
+
+    if use_websearch:
+        response = history[-1]['content']
+        response += "\n\n" + "\n".join(link_references)
+        logging.info(f"Added link references.")
+        logging.info(response)
+        chatbot[-1] = (parse_text(old_inputs), response)
+        yield chatbot, history, status_text, all_token_counts
+
     if stream:
         max_token = max_token_streaming
     else:
         max_token = max_token_all
+
     if sum(all_token_counts) > max_token and should_check_token_count:
         status_text = f"精简token中{all_token_counts}/{max_token}"
         logging.info(status_text)
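Once generation finishes, the collected reference lines are appended below the answer, so the chat window ends with a numbered source list. A toy example of the resulting message body (URLs illustrative):

response = "The model's answer text..."
link_references = [
    "[1]: [en.wikipedia.org](https://en.wikipedia.org/wiki/Example)",  # illustrative
    "[2]: [example.com](https://example.com/article)",
]
response += "\n\n" + "\n".join(link_references)
print(response)

Note that the answer half of the tuple is set without running response through parse_text, presumably so the appended Markdown links reach the frontend unmodified.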
 