JohnSmith9982 commited on
Commit
a4aaca9
1 Parent(s): 54ff2a0

Upload 19 files

Browse files
app.py CHANGED
@@ -25,9 +25,11 @@ else:
25
  dockerflag = False
26
 
27
  authflag = False
 
28
 
29
- if dockerflag:
30
  my_api_key = os.environ.get("my_api_key")
 
31
  if my_api_key == "empty":
32
  logging.error("Please give a api key!")
33
  sys.exit(1)
@@ -35,6 +37,7 @@ if dockerflag:
35
  username = os.environ.get("USERNAME")
36
  password = os.environ.get("PASSWORD")
37
  if not (isinstance(username, type(None)) or isinstance(password, type(None))):
 
38
  authflag = True
39
  else:
40
  if (
@@ -45,12 +48,15 @@ else:
45
  with open("api_key.txt", "r") as f:
46
  my_api_key = f.read().strip()
47
  if os.path.exists("auth.json"):
 
48
  with open("auth.json", "r", encoding='utf-8') as f:
49
  auth = json.load(f)
50
- username = auth["username"]
51
- password = auth["password"]
52
- if username != "" and password != "":
53
- authflag = True
 
 
54
 
55
  gr.Chatbot.postprocess = postprocess
56
  PromptHelper.compact_text_chunks = compact_text_chunks
@@ -75,19 +81,19 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
75
  with gr.Column(scale=4):
76
  status_display = gr.Markdown(get_geoip(), elem_id="status_display")
77
 
78
- with gr.Row(scale=1).style(equal_height=True):
79
  with gr.Column(scale=5):
80
- with gr.Row(scale=1):
81
  chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
82
- with gr.Row(scale=1):
83
  with gr.Column(scale=12):
84
  user_input = gr.Textbox(
85
- show_label=False, placeholder="在这里输入", interactive=True
86
  ).style(container=False)
87
  with gr.Column(min_width=70, scale=1):
88
  submitBtn = gr.Button("发送", variant="primary")
89
  cancelBtn = gr.Button("取消", variant="secondary", visible=False)
90
- with gr.Row(scale=1):
91
  emptyBtn = gr.Button(
92
  "🧹 新的对话",
93
  )
@@ -107,7 +113,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
107
  visible=not HIDE_MY_KEY,
108
  label="API-Key",
109
  )
110
- usageTxt = gr.Markdown(get_usage(my_api_key), elem_id="usage_display")
111
  model_select_dropdown = gr.Dropdown(
112
  label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
113
  )
@@ -207,7 +213,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
207
  label="Temperature",
208
  )
209
 
210
- with gr.Accordion("网络设置", open=False):
211
  apiurlTxt = gr.Textbox(
212
  show_label=True,
213
  placeholder=f"在这里输入API地址...",
@@ -226,7 +232,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
226
  changeProxyBtn = gr.Button("🔄 设置代理地址")
227
 
228
  gr.Markdown(description)
229
-
230
  chatgpt_predict_args = dict(
231
  fn=predict,
232
  inputs=[
@@ -264,13 +270,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
264
  )
265
 
266
  transfer_input_args = dict(
267
- fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input], show_progress=True
268
  )
269
 
270
  get_usage_args = dict(
271
  fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
272
  )
273
 
 
274
  # Chatbot
275
  cancelBtn.click(cancel_outputing, [], [])
276
 
@@ -287,8 +294,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
287
  )
288
  emptyBtn.click(**reset_textbox_args)
289
 
290
- retryBtn.click(**reset_textbox_args)
291
- retryBtn.click(
292
  retry,
293
  [
294
  user_api_key,
@@ -304,7 +310,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
304
  ],
305
  [chatbot, history, status_display, token_count],
306
  show_progress=True,
307
- )
308
  retryBtn.click(**get_usage_args)
309
 
310
  delFirstBtn.click(
@@ -330,7 +336,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
330
  token_count,
331
  top_p,
332
  temperature,
333
- gr.State(0),
334
  model_select_dropdown,
335
  language_select_dropdown,
336
  ],
@@ -341,6 +347,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
341
 
342
  # ChatGPT
343
  keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
 
344
 
345
  # Template
346
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
@@ -417,7 +424,7 @@ if __name__ == "__main__":
417
  demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
418
  server_name="0.0.0.0",
419
  server_port=7860,
420
- auth=(username, password),
421
  favicon_path="./assets/favicon.ico",
422
  )
423
  else:
@@ -432,7 +439,7 @@ if __name__ == "__main__":
432
  if authflag:
433
  demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
434
  share=False,
435
- auth=(username, password),
436
  favicon_path="./assets/favicon.ico",
437
  inbrowser=True,
438
  )
 
25
  dockerflag = False
26
 
27
  authflag = False
28
+ auth_list = []
29
 
30
+ if not my_api_key:
31
  my_api_key = os.environ.get("my_api_key")
32
+ if dockerflag:
33
  if my_api_key == "empty":
34
  logging.error("Please give a api key!")
35
  sys.exit(1)
 
37
  username = os.environ.get("USERNAME")
38
  password = os.environ.get("PASSWORD")
39
  if not (isinstance(username, type(None)) or isinstance(password, type(None))):
40
+ auth_list.append((os.environ.get("USERNAME"), os.environ.get("PASSWORD")))
41
  authflag = True
42
  else:
43
  if (
 
48
  with open("api_key.txt", "r") as f:
49
  my_api_key = f.read().strip()
50
  if os.path.exists("auth.json"):
51
+ authflag = True
52
  with open("auth.json", "r", encoding='utf-8') as f:
53
  auth = json.load(f)
54
+ for _ in auth:
55
+ if auth[_]["username"] and auth[_]["password"]:
56
+ auth_list.append((auth[_]["username"], auth[_]["password"]))
57
+ else:
58
+ logging.error("请检查auth.json文件中的用户名和密码!")
59
+ sys.exit(1)
60
 
61
  gr.Chatbot.postprocess = postprocess
62
  PromptHelper.compact_text_chunks = compact_text_chunks
 
81
  with gr.Column(scale=4):
82
  status_display = gr.Markdown(get_geoip(), elem_id="status_display")
83
 
84
+ with gr.Row().style(equal_height=True):
85
  with gr.Column(scale=5):
86
+ with gr.Row():
87
  chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
88
+ with gr.Row():
89
  with gr.Column(scale=12):
90
  user_input = gr.Textbox(
91
+ show_label=False, placeholder="在这里输入"
92
  ).style(container=False)
93
  with gr.Column(min_width=70, scale=1):
94
  submitBtn = gr.Button("发送", variant="primary")
95
  cancelBtn = gr.Button("取消", variant="secondary", visible=False)
96
+ with gr.Row():
97
  emptyBtn = gr.Button(
98
  "🧹 新的对话",
99
  )
 
113
  visible=not HIDE_MY_KEY,
114
  label="API-Key",
115
  )
116
+ usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display")
117
  model_select_dropdown = gr.Dropdown(
118
  label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
119
  )
 
213
  label="Temperature",
214
  )
215
 
216
+ with gr.Accordion("网络设置", open=False, visible=False):
217
  apiurlTxt = gr.Textbox(
218
  show_label=True,
219
  placeholder=f"在这里输入API地址...",
 
232
  changeProxyBtn = gr.Button("🔄 设置代理地址")
233
 
234
  gr.Markdown(description)
235
+ gr.HTML(footer.format(versions=versions_html()), elem_id="footer")
236
  chatgpt_predict_args = dict(
237
  fn=predict,
238
  inputs=[
 
270
  )
271
 
272
  transfer_input_args = dict(
273
+ fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn, cancelBtn], show_progress=True
274
  )
275
 
276
  get_usage_args = dict(
277
  fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
278
  )
279
 
280
+
281
  # Chatbot
282
  cancelBtn.click(cancel_outputing, [], [])
283
 
 
294
  )
295
  emptyBtn.click(**reset_textbox_args)
296
 
297
+ retryBtn.click(**start_outputing_args).then(
 
298
  retry,
299
  [
300
  user_api_key,
 
310
  ],
311
  [chatbot, history, status_display, token_count],
312
  show_progress=True,
313
+ ).then(**end_outputing_args)
314
  retryBtn.click(**get_usage_args)
315
 
316
  delFirstBtn.click(
 
336
  token_count,
337
  top_p,
338
  temperature,
339
+ gr.State(sum(token_count.value[-4:])),
340
  model_select_dropdown,
341
  language_select_dropdown,
342
  ],
 
347
 
348
  # ChatGPT
349
  keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
350
+ keyTxt.submit(**get_usage_args)
351
 
352
  # Template
353
  templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
 
424
  demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
425
  server_name="0.0.0.0",
426
  server_port=7860,
427
+ auth=auth_list,
428
  favicon_path="./assets/favicon.ico",
429
  )
430
  else:
 
439
  if authflag:
440
  demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
441
  share=False,
442
+ auth=auth_list,
443
  favicon_path="./assets/favicon.ico",
444
  inbrowser=True,
445
  )
assets/custom.css CHANGED
@@ -3,6 +3,21 @@
3
  --chatbot-color-dark: #121111;
4
  }
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  /* status_display */
7
  #status_display {
8
  display: flex;
@@ -22,14 +37,45 @@
22
 
23
  /* usage_display */
24
  #usage_display {
25
- height: 1em;
26
- }
27
- #usage_display p{
28
- padding: 0 1em;
 
 
 
 
 
 
 
 
 
 
29
  font-size: .85em;
30
- font-family: monospace;
31
  color: var(--body-text-color-subdued);
32
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  /* list */
34
  ol:not(.options), ul:not(.options) {
35
  padding-inline-start: 2em !important;
@@ -64,6 +110,7 @@ ol:not(.options), ul:not(.options) {
64
  background-color: var(--neutral-950) !important;
65
  }
66
  }
 
67
  /* 对话气泡 */
68
  [class *= "message"] {
69
  border-radius: var(--radius-xl) !important;
 
3
  --chatbot-color-dark: #121111;
4
  }
5
 
6
+ /* 覆盖gradio的页脚信息QAQ */
7
+ footer {
8
+ display: none !important;
9
+ }
10
+ #footer{
11
+ text-align: center;
12
+ }
13
+ #footer div{
14
+ display: inline-block;
15
+ }
16
+ #footer .versions{
17
+ font-size: 85%;
18
+ opacity: 0.85;
19
+ }
20
+
21
  /* status_display */
22
  #status_display {
23
  display: flex;
 
37
 
38
  /* usage_display */
39
  #usage_display {
40
+ position: relative;
41
+ margin: 0;
42
+ box-shadow: var(--block-shadow);
43
+ border-width: var(--block-border-width);
44
+ border-color: var(--block-border-color);
45
+ border-radius: var(--block-radius);
46
+ background: var(--block-background-fill);
47
+ width: 100%;
48
+ line-height: var(--line-sm);
49
+ min-height: 2em;
50
+ }
51
+ #usage_display p, #usage_display span {
52
+ margin: 0;
53
+ padding: .5em 1em;
54
  font-size: .85em;
 
55
  color: var(--body-text-color-subdued);
56
  }
57
+ .progress-bar {
58
+ background-color: var(--input-background-fill);;
59
+ margin: 0 1em;
60
+ height: 20px;
61
+ border-radius: 10px;
62
+ overflow: hidden;
63
+ }
64
+ .progress {
65
+ background-color: var(--block-title-background-fill);;
66
+ height: 100%;
67
+ border-radius: 10px;
68
+ text-align: right;
69
+ transition: width 0.5s ease-in-out;
70
+ }
71
+ .progress-text {
72
+ /* color: white; */
73
+ color: var(--color-accent) !important;
74
+ font-size: 1em !important;
75
+ font-weight: bold;
76
+ padding-right: 10px;
77
+ line-height: 20px;
78
+ }
79
  /* list */
80
  ol:not(.options), ul:not(.options) {
81
  padding-inline-start: 2em !important;
 
110
  background-color: var(--neutral-950) !important;
111
  }
112
  }
113
+
114
  /* 对话气泡 */
115
  [class *= "message"] {
116
  border-radius: var(--radius-xl) !important;
modules/__pycache__/chat_func.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/chat_func.cpython-39.pyc and b/modules/__pycache__/chat_func.cpython-39.pyc differ
 
modules/__pycache__/llama_func.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/llama_func.cpython-39.pyc and b/modules/__pycache__/llama_func.cpython-39.pyc differ
 
modules/__pycache__/openai_func.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/openai_func.cpython-39.pyc and b/modules/__pycache__/openai_func.cpython-39.pyc differ
 
modules/__pycache__/presets.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/presets.cpython-39.pyc and b/modules/__pycache__/presets.cpython-39.pyc differ
 
modules/__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/utils.cpython-39.pyc and b/modules/__pycache__/utils.cpython-39.pyc differ
 
modules/chat_func.py CHANGED
@@ -13,6 +13,9 @@ import colorama
13
  from duckduckgo_search import ddg
14
  import asyncio
15
  import aiohttp
 
 
 
16
 
17
  from modules.presets import *
18
  from modules.llama_func import *
@@ -58,39 +61,21 @@ def get_response(
58
  else:
59
  timeout = timeout_all
60
 
61
- # 获取环境变量中的代理设置
62
- http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
63
- https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
64
-
65
- # 如果存在代理设置,使用它们
66
- proxies = {}
67
- if http_proxy:
68
- logging.info(f"使用 HTTP 代理: {http_proxy}")
69
- proxies["http"] = http_proxy
70
- if https_proxy:
71
- logging.info(f"使用 HTTPS 代理: {https_proxy}")
72
- proxies["https"] = https_proxy
73
 
74
  # 如果有自定义的api-url,使用自定义url发送请求,否则使用默认设置发送请求
75
  if shared.state.api_url != API_URL:
76
  logging.info(f"使用自定义API URL: {shared.state.api_url}")
77
- if proxies:
78
- response = requests.post(
79
- shared.state.api_url,
80
- headers=headers,
81
- json=payload,
82
- stream=True,
83
- timeout=timeout,
84
- proxies=proxies,
85
- )
86
- else:
87
- response = requests.post(
88
- shared.state.api_url,
89
- headers=headers,
90
- json=payload,
91
- stream=True,
92
- timeout=timeout,
93
- )
94
  return response
95
 
96
 
@@ -121,13 +106,17 @@ def stream_predict(
121
  else:
122
  chatbot.append((inputs, ""))
123
  user_token_count = 0
 
 
 
 
124
  if len(all_token_counts) == 0:
125
  system_prompt_token_count = count_token(construct_system(system_prompt))
126
  user_token_count = (
127
- count_token(construct_user(inputs)) + system_prompt_token_count
128
  )
129
  else:
130
- user_token_count = count_token(construct_user(inputs))
131
  all_token_counts.append(user_token_count)
132
  logging.info(f"输入token计数: {user_token_count}")
133
  yield get_return_value()
@@ -155,6 +144,8 @@ def stream_predict(
155
  yield get_return_value()
156
  error_json_str = ""
157
 
 
 
158
  for chunk in tqdm(response.iter_lines()):
159
  if counter == 0:
160
  counter += 1
@@ -219,7 +210,10 @@ def predict_all(
219
  chatbot.append((fake_input, ""))
220
  else:
221
  chatbot.append((inputs, ""))
222
- all_token_counts.append(count_token(construct_user(inputs)))
 
 
 
223
  try:
224
  response = get_response(
225
  openai_api_key,
@@ -242,13 +236,22 @@ def predict_all(
242
  status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
243
  return chatbot, history, status_text, all_token_counts
244
  response = json.loads(response.text)
245
- content = response["choices"][0]["message"]["content"]
246
- history[-1] = construct_assistant(content)
247
- chatbot[-1] = (chatbot[-1][0], content+display_append)
248
- total_token_count = response["usage"]["total_tokens"]
249
- all_token_counts[-1] = total_token_count - sum(all_token_counts)
250
- status_text = construct_token_message(total_token_count)
251
- return chatbot, history, status_text, all_token_counts
 
 
 
 
 
 
 
 
 
252
 
253
 
254
  def predict(
@@ -268,40 +271,59 @@ def predict(
268
  should_check_token_count=True,
269
  ): # repetition_penalty, top_k
270
  logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
271
- yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
 
272
  if reply_language == "跟随问题语言(不稳定)":
273
  reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
 
 
 
274
  if files:
 
 
275
  msg = "加载索引中……(这可能需要几分钟)"
276
  logging.info(msg)
277
  yield chatbot+[(inputs, "")], history, msg, all_token_counts
278
  index = construct_index(openai_api_key, file_src=files)
279
  msg = "索引构建完成,获取回答中……"
 
280
  yield chatbot+[(inputs, "")], history, msg, all_token_counts
281
- history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot, reply_language)
282
- yield chatbot, history, status_text, all_token_counts
283
- return
284
-
285
- old_inputs = ""
286
- link_references = []
287
- if use_websearch:
 
 
 
 
 
 
 
 
 
 
 
288
  search_results = ddg(inputs, max_results=5)
289
  old_inputs = inputs
290
- web_results = []
291
  for idx, result in enumerate(search_results):
292
  logging.info(f"搜索结果{idx + 1}:{result}")
293
  domain_name = urllib3.util.parse_url(result["href"]).host
294
- web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
295
- link_references.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
296
- link_references = "\n\n" + "".join(link_references)
 
297
  inputs = (
298
  replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
299
  .replace("{query}", inputs)
300
- .replace("{web_results}", "\n\n".join(web_results))
301
  .replace("{reply_language}", reply_language )
302
  )
303
  else:
304
- link_references = ""
305
 
306
  if len(openai_api_key) != 51:
307
  status_text = standard_error_msg + no_apikey_msg
@@ -334,7 +356,7 @@ def predict(
334
  temperature,
335
  selected_model,
336
  fake_input=old_inputs,
337
- display_append=link_references
338
  )
339
  for chatbot, history, status_text, all_token_counts in iter:
340
  if shared.state.interrupted:
@@ -354,7 +376,7 @@ def predict(
354
  temperature,
355
  selected_model,
356
  fake_input=old_inputs,
357
- display_append=link_references
358
  )
359
  yield chatbot, history, status_text, all_token_counts
360
 
@@ -367,10 +389,15 @@ def predict(
367
  + colorama.Style.RESET_ALL
368
  )
369
 
 
 
 
 
 
370
  if stream:
371
- max_token = max_token_streaming
372
  else:
373
- max_token = max_token_all
374
 
375
  if sum(all_token_counts) > max_token and should_check_token_count:
376
  status_text = f"精简token中{all_token_counts}/{max_token}"
@@ -460,6 +487,7 @@ def reduce_token_size(
460
  flag = False
461
  for chatbot, history, status_text, previous_token_count in iter:
462
  num_chat = find_n(previous_token_count, max_token_count)
 
463
  if flag:
464
  chatbot = chatbot[:-1]
465
  flag = True
 
13
  from duckduckgo_search import ddg
14
  import asyncio
15
  import aiohttp
16
+ from llama_index.indices.query.vector_store import GPTVectorStoreIndexQuery
17
+ from llama_index.indices.query.schema import QueryBundle
18
+ from langchain.llms import OpenAIChat
19
 
20
  from modules.presets import *
21
  from modules.llama_func import *
 
61
  else:
62
  timeout = timeout_all
63
 
64
+ proxies = get_proxies()
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # 如果有自定义的api-url,使用自定义url发送请求,否则使用默认设置发送请求
67
  if shared.state.api_url != API_URL:
68
  logging.info(f"使用自定义API URL: {shared.state.api_url}")
69
+
70
+ response = requests.post(
71
+ shared.state.api_url,
72
+ headers=headers,
73
+ json=payload,
74
+ stream=True,
75
+ timeout=timeout,
76
+ proxies=proxies,
77
+ )
78
+
 
 
 
 
 
 
 
79
  return response
80
 
81
 
 
106
  else:
107
  chatbot.append((inputs, ""))
108
  user_token_count = 0
109
+ if fake_input is not None:
110
+ input_token_count = count_token(construct_user(fake_input))
111
+ else:
112
+ input_token_count = count_token(construct_user(inputs))
113
  if len(all_token_counts) == 0:
114
  system_prompt_token_count = count_token(construct_system(system_prompt))
115
  user_token_count = (
116
+ input_token_count + system_prompt_token_count
117
  )
118
  else:
119
+ user_token_count = input_token_count
120
  all_token_counts.append(user_token_count)
121
  logging.info(f"输入token计数: {user_token_count}")
122
  yield get_return_value()
 
144
  yield get_return_value()
145
  error_json_str = ""
146
 
147
+ if fake_input is not None:
148
+ history[-2] = construct_user(fake_input)
149
  for chunk in tqdm(response.iter_lines()):
150
  if counter == 0:
151
  counter += 1
 
210
  chatbot.append((fake_input, ""))
211
  else:
212
  chatbot.append((inputs, ""))
213
+ if fake_input is not None:
214
+ all_token_counts.append(count_token(construct_user(fake_input)))
215
+ else:
216
+ all_token_counts.append(count_token(construct_user(inputs)))
217
  try:
218
  response = get_response(
219
  openai_api_key,
 
236
  status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
237
  return chatbot, history, status_text, all_token_counts
238
  response = json.loads(response.text)
239
+ if fake_input is not None:
240
+ history[-2] = construct_user(fake_input)
241
+ try:
242
+ content = response["choices"][0]["message"]["content"]
243
+ history[-1] = construct_assistant(content)
244
+ chatbot[-1] = (chatbot[-1][0], content+display_append)
245
+ total_token_count = response["usage"]["total_tokens"]
246
+ if fake_input is not None:
247
+ all_token_counts[-1] += count_token(construct_assistant(content))
248
+ else:
249
+ all_token_counts[-1] = total_token_count - sum(all_token_counts)
250
+ status_text = construct_token_message(total_token_count)
251
+ return chatbot, history, status_text, all_token_counts
252
+ except KeyError:
253
+ status_text = standard_error_msg + str(response)
254
+ return chatbot, history, status_text, all_token_counts
255
 
256
 
257
  def predict(
 
271
  should_check_token_count=True,
272
  ): # repetition_penalty, top_k
273
  logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
274
+ if should_check_token_count:
275
+ yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
276
  if reply_language == "跟随问题语言(不稳定)":
277
  reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
278
+ old_inputs = None
279
+ display_reference = []
280
+ limited_context = False
281
  if files:
282
+ limited_context = True
283
+ old_inputs = inputs
284
  msg = "加载索引中……(这可能需要几分钟)"
285
  logging.info(msg)
286
  yield chatbot+[(inputs, "")], history, msg, all_token_counts
287
  index = construct_index(openai_api_key, file_src=files)
288
  msg = "索引构建完成,获取回答中……"
289
+ logging.info(msg)
290
  yield chatbot+[(inputs, "")], history, msg, all_token_counts
291
+ llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
292
+ prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
293
+ service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
294
+ query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
295
+ query_bundle = QueryBundle(inputs)
296
+ nodes = query_object.retrieve(query_bundle)
297
+ reference_results = [n.node.text for n in nodes]
298
+ reference_results = add_source_numbers(reference_results, use_source=False)
299
+ display_reference = add_details(reference_results)
300
+ display_reference = "\n\n" + "".join(display_reference)
301
+ inputs = (
302
+ replace_today(PROMPT_TEMPLATE)
303
+ .replace("{query_str}", inputs)
304
+ .replace("{context_str}", "\n\n".join(reference_results))
305
+ .replace("{reply_language}", reply_language )
306
+ )
307
+ elif use_websearch:
308
+ limited_context = True
309
  search_results = ddg(inputs, max_results=5)
310
  old_inputs = inputs
311
+ reference_results = []
312
  for idx, result in enumerate(search_results):
313
  logging.info(f"搜索结果{idx + 1}:{result}")
314
  domain_name = urllib3.util.parse_url(result["href"]).host
315
+ reference_results.append([result["body"], result["href"]])
316
+ display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
317
+ reference_results = add_source_numbers(reference_results)
318
+ display_reference = "\n\n" + "".join(display_reference)
319
  inputs = (
320
  replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
321
  .replace("{query}", inputs)
322
+ .replace("{web_results}", "\n\n".join(reference_results))
323
  .replace("{reply_language}", reply_language )
324
  )
325
  else:
326
+ display_reference = ""
327
 
328
  if len(openai_api_key) != 51:
329
  status_text = standard_error_msg + no_apikey_msg
 
356
  temperature,
357
  selected_model,
358
  fake_input=old_inputs,
359
+ display_append=display_reference
360
  )
361
  for chatbot, history, status_text, all_token_counts in iter:
362
  if shared.state.interrupted:
 
376
  temperature,
377
  selected_model,
378
  fake_input=old_inputs,
379
+ display_append=display_reference
380
  )
381
  yield chatbot, history, status_text, all_token_counts
382
 
 
389
  + colorama.Style.RESET_ALL
390
  )
391
 
392
+ if limited_context:
393
+ history = history[-4:]
394
+ all_token_counts = all_token_counts[-2:]
395
+ yield chatbot, history, status_text, all_token_counts
396
+
397
  if stream:
398
+ max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["streaming"]
399
  else:
400
+ max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
401
 
402
  if sum(all_token_counts) > max_token and should_check_token_count:
403
  status_text = f"精简token中{all_token_counts}/{max_token}"
 
487
  flag = False
488
  for chatbot, history, status_text, previous_token_count in iter:
489
  num_chat = find_n(previous_token_count, max_token_count)
490
+ logging.info(f"previous_token_count: {previous_token_count}, keeping {num_chat} chats")
491
  if flag:
492
  chatbot = chatbot[:-1]
493
  flag = True
modules/llama_func.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import logging
3
 
4
- from llama_index import GPTSimpleVectorIndex
5
  from llama_index import download_loader
6
  from llama_index import (
7
  Document,
@@ -11,19 +11,32 @@ from llama_index import (
11
  RefinePrompt,
12
  )
13
  from langchain.llms import OpenAI
 
14
  import colorama
 
 
15
 
16
  from modules.presets import *
17
  from modules.utils import *
18
 
19
  def get_index_name(file_src):
20
- index_name = []
21
- for file in file_src:
22
- index_name.append(os.path.basename(file.name))
23
- index_name = sorted(index_name)
24
- index_name = "".join(index_name)
25
- index_name = sha1sum(index_name)
26
- return index_name
 
 
 
 
 
 
 
 
 
 
27
 
28
  def get_documents(file_src):
29
  documents = []
@@ -33,9 +46,12 @@ def get_documents(file_src):
33
  logging.info(f"loading file: {file.name}")
34
  if os.path.splitext(file.name)[1] == ".pdf":
35
  logging.debug("Loading PDF...")
36
- CJKPDFReader = download_loader("CJKPDFReader")
37
- loader = CJKPDFReader()
38
- text_raw = loader.load_data(file=file.name)[0].text
 
 
 
39
  elif os.path.splitext(file.name)[1] == ".docx":
40
  logging.debug("Loading DOCX...")
41
  DocxReader = download_loader("DocxReader")
@@ -51,7 +67,10 @@ def get_documents(file_src):
51
  with open(file.name, "r", encoding="utf-8") as f:
52
  text_raw = f.read()
53
  text = add_space(text_raw)
 
 
54
  documents += [Document(text)]
 
55
  return documents
56
 
57
 
@@ -59,13 +78,11 @@ def construct_index(
59
  api_key,
60
  file_src,
61
  max_input_size=4096,
62
- num_outputs=1,
63
  max_chunk_overlap=20,
64
  chunk_size_limit=600,
65
  embedding_limit=None,
66
- separator=" ",
67
- num_children=10,
68
- max_keywords_per_chunk=10,
69
  ):
70
  os.environ["OPENAI_API_KEY"] = api_key
71
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
@@ -73,16 +90,9 @@ def construct_index(
73
  separator = " " if separator == "" else separator
74
 
75
  llm_predictor = LLMPredictor(
76
- llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
77
- )
78
- prompt_helper = PromptHelper(
79
- max_input_size,
80
- num_outputs,
81
- max_chunk_overlap,
82
- embedding_limit,
83
- chunk_size_limit,
84
- separator=separator,
85
  )
 
86
  index_name = get_index_name(file_src)
87
  if os.path.exists(f"./index/{index_name}.json"):
88
  logging.info("找到了缓存的索引文件,加载中……")
@@ -90,14 +100,19 @@ def construct_index(
90
  else:
91
  try:
92
  documents = get_documents(file_src)
93
- logging.debug("构建索引中……")
94
- index = GPTSimpleVectorIndex(
95
- documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
 
96
  )
 
97
  os.makedirs("./index", exist_ok=True)
98
  index.save_to_disk(f"./index/{index_name}.json")
 
99
  return index
 
100
  except Exception as e:
 
101
  print(e)
102
  return None
103
 
@@ -144,7 +159,7 @@ def ask_ai(
144
  question,
145
  prompt_tmpl,
146
  refine_tmpl,
147
- sim_k=1,
148
  temprature=0,
149
  prefix_messages=[],
150
  reply_language="中文",
@@ -154,7 +169,7 @@ def ask_ai(
154
  logging.debug("Index file found")
155
  logging.debug("Querying index...")
156
  llm_predictor = LLMPredictor(
157
- llm=OpenAI(
158
  temperature=temprature,
159
  model_name="gpt-3.5-turbo-0301",
160
  prefix_messages=prefix_messages,
@@ -166,7 +181,6 @@ def ask_ai(
166
  rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
167
  response = index.query(
168
  question,
169
- llm_predictor=llm_predictor,
170
  similarity_top_k=sim_k,
171
  text_qa_template=qa_prompt,
172
  refine_template=rf_prompt,
 
1
  import os
2
  import logging
3
 
4
+ from llama_index import GPTSimpleVectorIndex, ServiceContext
5
  from llama_index import download_loader
6
  from llama_index import (
7
  Document,
 
11
  RefinePrompt,
12
  )
13
  from langchain.llms import OpenAI
14
+ from langchain.chat_models import ChatOpenAI
15
  import colorama
16
+ import PyPDF2
17
+ from tqdm import tqdm
18
 
19
  from modules.presets import *
20
  from modules.utils import *
21
 
22
  def get_index_name(file_src):
23
+ file_paths = [x.name for x in file_src]
24
+ file_paths.sort(key=lambda x: os.path.basename(x))
25
+
26
+ md5_hash = hashlib.md5()
27
+ for file_path in file_paths:
28
+ with open(file_path, "rb") as f:
29
+ while chunk := f.read(8192):
30
+ md5_hash.update(chunk)
31
+
32
+ return md5_hash.hexdigest()
33
+
34
+ def block_split(text):
35
+ blocks = []
36
+ while len(text) > 0:
37
+ blocks.append(Document(text[:1000]))
38
+ text = text[1000:]
39
+ return blocks
40
 
41
  def get_documents(file_src):
42
  documents = []
 
46
  logging.info(f"loading file: {file.name}")
47
  if os.path.splitext(file.name)[1] == ".pdf":
48
  logging.debug("Loading PDF...")
49
+ pdftext = ""
50
+ with open(file.name, 'rb') as pdfFileObj:
51
+ pdfReader = PyPDF2.PdfReader(pdfFileObj)
52
+ for page in tqdm(pdfReader.pages):
53
+ pdftext += page.extract_text()
54
+ text_raw = pdftext
55
  elif os.path.splitext(file.name)[1] == ".docx":
56
  logging.debug("Loading DOCX...")
57
  DocxReader = download_loader("DocxReader")
 
67
  with open(file.name, "r", encoding="utf-8") as f:
68
  text_raw = f.read()
69
  text = add_space(text_raw)
70
+ # text = block_split(text)
71
+ # documents += text
72
  documents += [Document(text)]
73
+ logging.debug("Documents loaded.")
74
  return documents
75
 
76
 
 
78
  api_key,
79
  file_src,
80
  max_input_size=4096,
81
+ num_outputs=5,
82
  max_chunk_overlap=20,
83
  chunk_size_limit=600,
84
  embedding_limit=None,
85
+ separator=" "
 
 
86
  ):
87
  os.environ["OPENAI_API_KEY"] = api_key
88
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
 
90
  separator = " " if separator == "" else separator
91
 
92
  llm_predictor = LLMPredictor(
93
+ llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
 
 
 
 
 
 
 
 
94
  )
95
+ prompt_helper = PromptHelper(max_input_size = max_input_size, num_output = num_outputs, max_chunk_overlap = max_chunk_overlap, embedding_limit=embedding_limit, chunk_size_limit=600, separator=separator)
96
  index_name = get_index_name(file_src)
97
  if os.path.exists(f"./index/{index_name}.json"):
98
  logging.info("找到了缓存的索引文件,加载中……")
 
100
  else:
101
  try:
102
  documents = get_documents(file_src)
103
+ logging.info("构建索引中……")
104
+ service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
105
+ index = GPTSimpleVectorIndex.from_documents(
106
+ documents, service_context=service_context
107
  )
108
+ logging.debug("索引构建完成!")
109
  os.makedirs("./index", exist_ok=True)
110
  index.save_to_disk(f"./index/{index_name}.json")
111
+ logging.debug("索引已保存至本地!")
112
  return index
113
+
114
  except Exception as e:
115
+ logging.error("索引构建失败!", e)
116
  print(e)
117
  return None
118
 
 
159
  question,
160
  prompt_tmpl,
161
  refine_tmpl,
162
+ sim_k=5,
163
  temprature=0,
164
  prefix_messages=[],
165
  reply_language="中文",
 
169
  logging.debug("Index file found")
170
  logging.debug("Querying index...")
171
  llm_predictor = LLMPredictor(
172
+ llm=ChatOpenAI(
173
  temperature=temprature,
174
  model_name="gpt-3.5-turbo-0301",
175
  prefix_messages=prefix_messages,
 
181
  rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
182
  response = index.query(
183
  question,
 
184
  similarity_top_k=sim_k,
185
  text_qa_template=qa_prompt,
186
  refine_template=rf_prompt,
modules/openai_func.py CHANGED
@@ -1,70 +1,82 @@
1
  import requests
2
  import logging
3
- from modules.presets import timeout_all, BALANCE_API_URL,standard_error_msg,connection_timeout_prompt,error_retrieve_prompt,read_timeout_prompt
4
- from modules import shared
5
- import os
 
 
 
 
 
 
6
 
 
 
 
7
 
8
- def get_usage_response(openai_api_key):
9
  headers = {
10
  "Content-Type": "application/json",
11
- "Authorization": f"Bearer {openai_api_key}",
12
  }
13
-
14
  timeout = timeout_all
15
-
16
- # 获取环境变量中的代理设置
17
- http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
18
- https_proxy = os.environ.get(
19
- "HTTPS_PROXY") or os.environ.get("https_proxy")
20
-
21
- # 如果存在代理设置,使用它们
22
- proxies = {}
23
- if http_proxy:
24
- logging.info(f"使用 HTTP 代理: {http_proxy}")
25
- proxies["http"] = http_proxy
26
- if https_proxy:
27
- logging.info(f"使用 HTTPS 代理: {https_proxy}")
28
- proxies["https"] = https_proxy
29
-
30
- # 如果有代理,使用代理发送请求,否则使用默认设置发送请求
31
- """
32
- 暂不支持修改
33
- if shared.state.balance_api_url != BALANCE_API_URL:
34
- logging.info(f"使用自定义BALANCE API URL: {shared.state.balance_api_url}")
35
- """
36
- if proxies:
37
- response = requests.get(
38
- BALANCE_API_URL,
39
- headers=headers,
40
- timeout=timeout,
41
- proxies=proxies,
42
- )
43
  else:
44
- response = requests.get(
45
- BALANCE_API_URL,
46
- headers=headers,
47
- timeout=timeout,
48
- )
49
- return response
50
 
51
  def get_usage(openai_api_key):
52
  try:
53
- response=get_usage_response(openai_api_key=openai_api_key)
54
- logging.debug(response.json())
55
  try:
56
- balance = response.json().get("total_available") if response.json().get(
57
- "total_available") else 0
58
- total_used = response.json().get("total_used") if response.json().get(
59
- "total_used") else 0
60
  except Exception as e:
61
  logging.error(f"API使用情况解析失败:"+str(e))
62
  balance = 0
63
  total_used=0
64
- return f"**API使用情况**(已用/余额)\u3000{total_used}$ / {balance}$"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  except requests.exceptions.ConnectTimeout:
66
  status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
67
  return status_text
68
  except requests.exceptions.ReadTimeout:
69
  status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
70
  return status_text
 
 
 
 
1
  import requests
2
  import logging
3
+ from modules.presets import (
4
+ timeout_all,
5
+ USAGE_API_URL,
6
+ BALANCE_API_URL,
7
+ standard_error_msg,
8
+ connection_timeout_prompt,
9
+ error_retrieve_prompt,
10
+ read_timeout_prompt
11
+ )
12
 
13
+ from modules import shared
14
+ from modules.utils import get_proxies
15
+ import os, datetime
16
 
17
+ def get_billing_data(openai_api_key, billing_url):
18
  headers = {
19
  "Content-Type": "application/json",
20
+ "Authorization": f"Bearer {openai_api_key}"
21
  }
22
+
23
  timeout = timeout_all
24
+ proxies = get_proxies()
25
+ response = requests.get(
26
+ billing_url,
27
+ headers=headers,
28
+ timeout=timeout,
29
+ proxies=proxies,
30
+ )
31
+
32
+ if response.status_code == 200:
33
+ data = response.json()
34
+ return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  else:
36
+ raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
37
+
 
 
 
 
38
 
39
  def get_usage(openai_api_key):
40
  try:
41
+ balance_data=get_billing_data(openai_api_key, BALANCE_API_URL)
42
+ logging.debug(balance_data)
43
  try:
44
+ balance = balance_data["total_available"] if balance_data["total_available"] else 0
45
+ total_used = balance_data["total_used"] if balance_data["total_used"] else 0
46
+ usage_percent = round(total_used / (total_used+balance) * 100, 2)
 
47
  except Exception as e:
48
  logging.error(f"API使用情况解析失败:"+str(e))
49
  balance = 0
50
  total_used=0
51
+ return f"**API使用情况解析失败**"
52
+ if balance == 0:
53
+ last_day_of_month = datetime.datetime.now().strftime("%Y-%m-%d")
54
+ first_day_of_month = datetime.datetime.now().replace(day=1).strftime("%Y-%m-%d")
55
+ usage_url = f"{USAGE_API_URL}?start_date={first_day_of_month}&end_date={last_day_of_month}"
56
+ try:
57
+ usage_data = get_billing_data(openai_api_key, usage_url)
58
+ except Exception as e:
59
+ logging.error(f"获取API使用情况失败:"+str(e))
60
+ return f"**获取API使用情况失败**"
61
+ return f"**本月使用金额** \u3000 ${usage_data['total_usage'] / 100}"
62
+
63
+ # return f"**免费额度**(已用/余额)\u3000${total_used} / ${balance}"
64
+ return f"""\
65
+ <b>免费额度使用情况</b>
66
+ <div class="progress-bar">
67
+ <div class="progress" style="width: {usage_percent}%;">
68
+ <span class="progress-text">{usage_percent}%</span>
69
+ </div>
70
+ </div>
71
+ <div style="display: flex; justify-content: space-between;"><span>已用 ${total_used}</span><span>可用 ${balance}</span></div>
72
+ """
73
+
74
  except requests.exceptions.ConnectTimeout:
75
  status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
76
  return status_text
77
  except requests.exceptions.ReadTimeout:
78
  status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
79
  return status_text
80
+ except Exception as e:
81
+ logging.error(f"获取API使用情况失败:"+str(e))
82
+ return standard_error_msg + error_retrieve_prompt
modules/presets.py CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
5
  initial_prompt = "You are a helpful assistant."
6
  API_URL = "https://api.openai.com/v1/chat/completions"
7
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
 
8
  HISTORY_DIR = "history"
9
  TEMPLATES_DIR = "templates"
10
 
@@ -18,9 +19,7 @@ ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
18
  no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
19
  no_input_msg = "请输入对话内容。" # 未输入对话内容
20
 
21
- max_token_streaming = 3500 # 流式对话时的最大 token 数
22
  timeout_streaming = 10 # 流式对话时的超时时间
23
- max_token_all = 3500 # 非流式对话时的最大 token 数
24
  timeout_all = 200 # 非流式对话时的超时时间
25
  enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
26
  HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
@@ -41,6 +40,10 @@ description = """\
41
  </div>
42
  """
43
 
 
 
 
 
44
  summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
45
 
46
  MODELS = [
@@ -52,8 +55,36 @@ MODELS = [
52
  "gpt-4-32k-0314",
53
  ] # 可选的模型
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  REPLY_LANGUAGES = [
56
- "中文",
 
57
  "English",
58
  "日本語",
59
  "Español",
 
5
  initial_prompt = "You are a helpful assistant."
6
  API_URL = "https://api.openai.com/v1/chat/completions"
7
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
8
+ USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
9
  HISTORY_DIR = "history"
10
  TEMPLATES_DIR = "templates"
11
 
 
19
  no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
20
  no_input_msg = "请输入对话内容。" # 未输入对话内容
21
 
 
22
  timeout_streaming = 10 # 流式对话时的超时时间
 
23
  timeout_all = 200 # 非流式对话时的超时时间
24
  enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
25
  HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
 
40
  </div>
41
  """
42
 
43
+ footer = """\
44
+ <div class="versions">{versions}</div>
45
+ """
46
+
47
  summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
48
 
49
  MODELS = [
 
55
  "gpt-4-32k-0314",
56
  ] # 可选的模型
57
 
58
+ MODEL_SOFT_TOKEN_LIMIT = {
59
+ "gpt-3.5-turbo": {
60
+ "streaming": 3500,
61
+ "all": 3500
62
+ },
63
+ "gpt-3.5-turbo-0301": {
64
+ "streaming": 3500,
65
+ "all": 3500
66
+ },
67
+ "gpt-4": {
68
+ "streaming": 7500,
69
+ "all": 7500
70
+ },
71
+ "gpt-4-0314": {
72
+ "streaming": 7500,
73
+ "all": 7500
74
+ },
75
+ "gpt-4-32k": {
76
+ "streaming": 31000,
77
+ "all": 31000
78
+ },
79
+ "gpt-4-32k-0314": {
80
+ "streaming": 31000,
81
+ "all": 31000
82
+ }
83
+ }
84
+
85
  REPLY_LANGUAGES = [
86
+ "简体中文",
87
+ "繁體中文",
88
  "English",
89
  "日本語",
90
  "Español",
modules/utils.py CHANGED
@@ -10,6 +10,8 @@ import csv
10
  import requests
11
  import re
12
  import html
 
 
13
 
14
  import gradio as gr
15
  from pypinyin import lazy_pinyin
@@ -115,7 +117,11 @@ def convert_mdtext(md_text):
115
 
116
 
117
  def convert_asis(userinput):
118
- return f"<p style=\"white-space:pre-wrap;\">{html.escape(userinput)}</p>"+ALREADY_CONVERTED_MARK
 
 
 
 
119
 
120
  def detect_converted_mark(userinput):
121
  if userinput.endswith(ALREADY_CONVERTED_MARK):
@@ -153,6 +159,7 @@ def construct_assistant(text):
153
  def construct_token_message(token, stream=False):
154
  return f"Token 计数: {token}"
155
 
 
156
  def delete_first_conversation(history, previous_token_count):
157
  if history:
158
  del history[:2]
@@ -346,6 +353,8 @@ def change_proxy(proxy):
346
 
347
 
348
  def hide_middle_chars(s):
 
 
349
  if len(s) <= 8:
350
  return s
351
  else:
@@ -362,20 +371,14 @@ def submit_key(key):
362
  return key, msg
363
 
364
 
365
- def sha1sum(filename):
366
- sha1 = hashlib.sha1()
367
- sha1.update(filename.encode("utf-8"))
368
- return sha1.hexdigest()
369
-
370
-
371
  def replace_today(prompt):
372
  today = datetime.datetime.today().strftime("%Y-%m-%d")
373
  return prompt.replace("{current_date}", today)
374
 
375
 
376
  def get_geoip():
377
- response = requests.get("https://ipapi.co/json/", timeout=5)
378
  try:
 
379
  data = response.json()
380
  except:
381
  data = {"error": True, "reason": "连接ipapi失败"}
@@ -383,7 +386,7 @@ def get_geoip():
383
  logging.warning(f"无法获取IP地址信息。\n{data}")
384
  if data["reason"] == "RateLimited":
385
  return (
386
- f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
387
  )
388
  else:
389
  return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
@@ -427,8 +430,91 @@ def cancel_outputing():
427
  logging.info("中止输出……")
428
  shared.state.interrupt()
429
 
 
430
  def transfer_input(inputs):
431
  # 一次性返回,降低延迟
432
  textbox = reset_textbox()
433
  outputing = start_outputing()
434
- return inputs, gr.update(value="")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import requests
11
  import re
12
  import html
13
+ import sys
14
+ import subprocess
15
 
16
  import gradio as gr
17
  from pypinyin import lazy_pinyin
 
117
 
118
 
119
  def convert_asis(userinput):
120
+ return (
121
+ f'<p style="white-space:pre-wrap;">{html.escape(userinput)}</p>'
122
+ + ALREADY_CONVERTED_MARK
123
+ )
124
+
125
 
126
  def detect_converted_mark(userinput):
127
  if userinput.endswith(ALREADY_CONVERTED_MARK):
 
159
  def construct_token_message(token, stream=False):
160
  return f"Token 计数: {token}"
161
 
162
+
163
  def delete_first_conversation(history, previous_token_count):
164
  if history:
165
  del history[:2]
 
353
 
354
 
355
  def hide_middle_chars(s):
356
+ if s is None:
357
+ return ""
358
  if len(s) <= 8:
359
  return s
360
  else:
 
371
  return key, msg
372
 
373
 
 
 
 
 
 
 
374
  def replace_today(prompt):
375
  today = datetime.datetime.today().strftime("%Y-%m-%d")
376
  return prompt.replace("{current_date}", today)
377
 
378
 
379
  def get_geoip():
 
380
  try:
381
+ response = requests.get("https://ipapi.co/json/", timeout=5)
382
  data = response.json()
383
  except:
384
  data = {"error": True, "reason": "连接ipapi失败"}
 
386
  logging.warning(f"无法获取IP地址信息。\n{data}")
387
  if data["reason"] == "RateLimited":
388
  return (
389
+ f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用。"
390
  )
391
  else:
392
  return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
 
430
  logging.info("中止输出……")
431
  shared.state.interrupt()
432
 
433
+
434
  def transfer_input(inputs):
435
  # 一次性返回,降低延迟
436
  textbox = reset_textbox()
437
  outputing = start_outputing()
438
+ return (
439
+ inputs,
440
+ gr.update(value=""),
441
+ gr.Button.update(visible=True),
442
+ gr.Button.update(visible=False),
443
+ )
444
+
445
+
446
+ def get_proxies():
447
+ # 获取环境变量中的代理设置
448
+ http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
449
+ https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
450
+
451
+ # 如果存在代理设置,使用它们
452
+ proxies = {}
453
+ if http_proxy:
454
+ logging.info(f"使用 HTTP 代理: {http_proxy}")
455
+ proxies["http"] = http_proxy
456
+ if https_proxy:
457
+ logging.info(f"使用 HTTPS 代理: {https_proxy}")
458
+ proxies["https"] = https_proxy
459
+
460
+ if proxies == {}:
461
+ proxies = None
462
+
463
+ return proxies
464
+
465
+ def run(command, desc=None, errdesc=None, custom_env=None, live=False):
466
+ if desc is not None:
467
+ print(desc)
468
+ if live:
469
+ result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
470
+ if result.returncode != 0:
471
+ raise RuntimeError(f"""{errdesc or 'Error running command'}.
472
+ Command: {command}
473
+ Error code: {result.returncode}""")
474
+
475
+ return ""
476
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
477
+ if result.returncode != 0:
478
+ message = f"""{errdesc or 'Error running command'}.
479
+ Command: {command}
480
+ Error code: {result.returncode}
481
+ stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
482
+ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
483
+ """
484
+ raise RuntimeError(message)
485
+ return result.stdout.decode(encoding="utf8", errors="ignore")
486
+
487
+ def versions_html():
488
+ git = os.environ.get('GIT', "git")
489
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
490
+ try:
491
+ commit_hash = run(f"{git} rev-parse HEAD").strip()
492
+ except Exception:
493
+ commit_hash = "<none>"
494
+ if commit_hash != "<none>":
495
+ short_commit = commit_hash[0:7]
496
+ commit_info = f"<a style=\"text-decoration:none\" href=\"https://github.com/GaiZhenbiao/ChuanhuChatGPT/commit/{short_commit}\">{short_commit}</a>"
497
+ else:
498
+ commit_info = "unknown \U0001F615"
499
+ return f"""
500
+ Python: <span title="{sys.version}">{python_version}</span>
501
+  • 
502
+ Gradio: {gr.__version__}
503
+  • 
504
+ Commit: {commit_info}
505
+ """
506
+
507
+ def add_source_numbers(lst, source_name = "Source", use_source = True):
508
+ if use_source:
509
+ return [f'[{idx+1}]\t "{item[0]}"\n{source_name}: {item[1]}' for idx, item in enumerate(lst)]
510
+ else:
511
+ return [f'[{idx+1}]\t "{item}"' for idx, item in enumerate(lst)]
512
+
513
+ def add_details(lst):
514
+ nodes = []
515
+ for index, txt in enumerate(lst):
516
+ brief = txt[:25].replace("\n", "")
517
+ nodes.append(
518
+ f"<details><summary>{brief}...</summary><p>{txt}</p></details>"
519
+ )
520
+ return nodes