JohnSmith9982 committed on
Commit 7b5a1c0
1 Parent(s): cecb277

Upload 38 files

CITATION.cff ADDED
@@ -0,0 +1,20 @@
+cff-version: 1.2.0
+title: ChuanhuChatGPT
+message: >-
+  If you use this software, please cite it using these
+  metadata.
+type: software
+authors:
+  - given-names: Chuanhu
+    orcid: https://orcid.org/0000-0001-8954-8598
+  - given-names: MZhao
+    orcid: https://orcid.org/0000-0003-2298-6213
+  - given-names: Keldos
+    orcid: https://orcid.org/0009-0005-0357-272X
+repository-code: 'https://github.com/GaiZhenbiao/ChuanhuChatGPT'
+url: 'https://github.com/GaiZhenbiao/ChuanhuChatGPT'
+abstract: Provided a light and easy to use interface for ChatGPT API
+license: GPL-3.0
+commit: bd0034c37e5af6a90bd9c2f7dd073f6cd27c61af
+version: '20230405'
+date-released: '2023-04-05'
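CITATION.cff is plain YAML (cff-version 1.2.0), so the metadata added above can be read with any YAML parser. A minimal sketch, assuming PyYAML is installed and the file sits in the repository root:

```python
# Minimal sketch: read the citation metadata added above.
# Assumes PyYAML ("pip install pyyaml"); CITATION.cff is ordinary YAML.
import yaml

with open("CITATION.cff", "r", encoding="utf-8") as f:
    meta = yaml.safe_load(f)

print(meta["title"], meta["version"])               # -> ChuanhuChatGPT 20230405
print([a["given-names"] for a in meta["authors"]])  # -> ['Chuanhu', 'MZhao', 'Keldos']
```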
ChuanhuChatbot.py CHANGED
@@ -10,31 +10,32 @@ from modules.config import *
 from modules.utils import *
 from modules.presets import *
 from modules.overwrites import *
-from modules.chat_func import *
-from modules.openai_func import get_usage

 gr.Chatbot.postprocess = postprocess
 PromptHelper.compact_text_chunks = compact_text_chunks

 with open("assets/custom.css", "r", encoding="utf-8") as f:
     customCSS = f.read()

 with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     user_name = gr.State("")
-    history = gr.State([])
-    token_count = gr.State([])
     promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
-    user_api_key = gr.State(my_api_key)
     user_question = gr.State("")
-    outputing = gr.State(False)
     topic = gr.State("未命名对话历史记录")

     with gr.Row():
-        with gr.Column():
-            gr.HTML(title)
-            user_info = gr.Markdown(value="", elem_id="user_info")
-            gr.HTML('<center><a href="https://huggingface.co/spaces/JohnSmith9982/ChuanhuChatGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a></center>')
         status_display = gr.Markdown(get_geoip(), elem_id="status_display")

     # https://github.com/gradio-app/gradio/pull/3296
     def create_greeting(request: gr.Request):
@@ -50,14 +51,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     with gr.Row():
         chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
     with gr.Row():
-        with gr.Column(scale=12):
             user_input = gr.Textbox(
                 elem_id="user_input_tb",
                 show_label=False, placeholder="在这里输入"
             ).style(container=False)
-        with gr.Column(min_width=70, scale=1):
-            submitBtn = gr.Button("发送", variant="primary")
-            cancelBtn = gr.Button("取消", variant="secondary", visible=False)
     with gr.Row():
         emptyBtn = gr.Button(
             "🧹 新的对话",
@@ -65,37 +66,41 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
         retryBtn = gr.Button("🔄 重新生成")
         delFirstBtn = gr.Button("🗑️ 删除最旧对话")
         delLastBtn = gr.Button("🗑️ 删除最新对话")
-        reduceTokenBtn = gr.Button("♻️ 总结对话")

     with gr.Column():
         with gr.Column(min_width=50, scale=1):
-            with gr.Tab(label="ChatGPT"):
                 keyTxt = gr.Textbox(
                     show_label=True,
                     placeholder=f"OpenAI API-key...",
-                    value=hide_middle_chars(my_api_key),
                     type="password",
                     visible=not HIDE_MY_KEY,
                     label="API-Key",
                 )
                 if multi_api_key:
-                    usageTxt = gr.Markdown("多账号模式已开启,无需输入key,可直接开始对话", elem_id="usage_display")
                 else:
-                    usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display")
                 model_select_dropdown = gr.Dropdown(
-                    label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
                 )
-                use_streaming_checkbox = gr.Checkbox(
-                    label="实时传输回答", value=True, visible=enable_streaming_option
                 )
-                use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
                 language_select_dropdown = gr.Dropdown(
                     label="选择回复语言(针对搜索&索引功能)",
                     choices=REPLY_LANGUAGES,
                     multiselect=False,
                     value=REPLY_LANGUAGES[0],
                 )
-                index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
                 two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
                 # TODO: 公式ocr
                 # formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
@@ -105,7 +110,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                     show_label=True,
                     placeholder=f"在这里输入System Prompt...",
                     label="System prompt",
-                    value=initial_prompt,
                     lines=10,
                 ).style(container=False)
                 with gr.Accordion(label="加载Prompt模板", open=True):
@@ -161,27 +166,87 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:

             with gr.Tab(label="高级"):
                 gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
-                default_btn = gr.Button("🔙 恢复默认设置")
-
                 with gr.Accordion("参数", open=False):
-                    top_p = gr.Slider(
                         minimum=-0,
                         maximum=1.0,
                         value=1.0,
                         step=0.05,
                         interactive=True,
-                        label="Top-p",
                     )
-                    temperature = gr.Slider(
-                        minimum=-0,
                         maximum=2.0,
-                        value=1.0,
-                        step=0.1,
                         interactive=True,
-                        label="Temperature",
                     )

-                with gr.Accordion("网络设置", open=False, visible=False):
                     # 优先展示自定义的api_host
                     apihostTxt = gr.Textbox(
                         show_label=True,
@@ -199,27 +264,22 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                         lines=2,
                     )
                     changeProxyBtn = gr.Button("🔄 设置代理地址")

-    gr.Markdown(description)
-    gr.HTML(footer.format(versions=versions_html()), elem_id="footer")
     chatgpt_predict_args = dict(
         fn=predict,
         inputs=[
-            user_api_key,
-            systemPromptTxt,
-            history,
             user_question,
             chatbot,
-            token_count,
-            top_p,
-            temperature,
             use_streaming_checkbox,
-            model_select_dropdown,
             use_websearch_checkbox,
             index_files,
             language_select_dropdown,
         ],
-        outputs=[chatbot, history, status_display, token_count],
         show_progress=True,
     )

@@ -243,12 +303,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     )

     get_usage_args = dict(
-        fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
     )


     # Chatbot
-    cancelBtn.click(cancel_outputing, [], [])

     user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
     user_input.submit(**get_usage_args)
@@ -256,9 +322,12 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     submitBtn.click(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
     submitBtn.click(**get_usage_args)

     emptyBtn.click(
-        reset_state,
-        outputs=[chatbot, history, token_count, status_display],
         show_progress=True,
     )
     emptyBtn.click(**reset_textbox_args)
@@ -266,61 +335,42 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     retryBtn.click(**start_outputing_args).then(
         retry,
         [
-            user_api_key,
-            systemPromptTxt,
-            history,
             chatbot,
-            token_count,
-            top_p,
-            temperature,
             use_streaming_checkbox,
-            model_select_dropdown,
             language_select_dropdown,
         ],
-        [chatbot, history, status_display, token_count],
         show_progress=True,
     ).then(**end_outputing_args)
     retryBtn.click(**get_usage_args)

     delFirstBtn.click(
         delete_first_conversation,
-        [history, token_count],
-        [history, token_count, status_display],
     )

     delLastBtn.click(
         delete_last_conversation,
-        [chatbot, history, token_count],
-        [chatbot, history, token_count, status_display],
-        show_progress=True,
     )

-    reduceTokenBtn.click(
-        reduce_token_size,
-        [
-            user_api_key,
-            systemPromptTxt,
-            history,
-            chatbot,
-            token_count,
-            top_p,
-            temperature,
-            gr.State(sum(token_count.value[-4:])),
-            model_select_dropdown,
-            language_select_dropdown,
-        ],
-        [chatbot, history, status_display, token_count],
-        show_progress=True,
-    )
-    reduceTokenBtn.click(**get_usage_args)
-
     two_column.change(update_doc_config, [two_column], None)

-    # ChatGPT
-    keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
     keyTxt.submit(**get_usage_args)

     # Template
     templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
     templateFileSelectDropdown.change(
         load_template,
@@ -338,31 +388,33 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     # S&L
     saveHistoryBtn.click(
         save_chat_history,
-        [saveFileName, systemPromptTxt, history, chatbot, user_name],
         downloadFile,
         show_progress=True,
     )
     saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
     exportMarkdownBtn.click(
         export_markdown,
-        [saveFileName, systemPromptTxt, history, chatbot, user_name],
         downloadFile,
         show_progress=True,
     )
     historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
-    historyFileSelectDropdown.change(
-        load_chat_history,
-        [historyFileSelectDropdown, systemPromptTxt, history, chatbot, user_name],
-        [saveFileName, systemPromptTxt, history, chatbot],
-        show_progress=True,
-    )
-    downloadFile.change(
-        load_chat_history,
-        [downloadFile, systemPromptTxt, history, chatbot, user_name],
-        [saveFileName, systemPromptTxt, history, chatbot],
-    )

     # Advanced
     default_btn.click(
         reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
     )
@@ -389,35 +441,14 @@ demo.title = "川虎ChatGPT 🚀"

 if __name__ == "__main__":
     reload_javascript()
-    # if running in Docker
-    if dockerflag:
-        if authflag:
-            demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
-                server_name="0.0.0.0",
-                server_port=7860,
-                auth=auth_list,
-                favicon_path="./assets/favicon.ico",
-            )
-        else:
-            demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
-                server_name="0.0.0.0",
-                server_port=7860,
-                share=False,
-                favicon_path="./assets/favicon.ico",
-            )
-    # if not running in Docker
-    else:
-        if authflag:
-            demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
-                share=False,
-                auth=auth_list,
-                favicon_path="./assets/favicon.ico",
-                inbrowser=True,
-            )
-        else:
-            demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
-                share=False, favicon_path="./assets/favicon.ico", inbrowser=True
-            ) # 改为 share=True 可以创建公开分享链接
-    # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
-    # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
-    # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
 from modules.utils import *
 from modules.presets import *
 from modules.overwrites import *
+from modules.models import get_model

+gr.Chatbot._postprocess_chat_messages = postprocess_chat_messages
 gr.Chatbot.postprocess = postprocess
 PromptHelper.compact_text_chunks = compact_text_chunks

 with open("assets/custom.css", "r", encoding="utf-8") as f:
     customCSS = f.read()

+def create_new_model():
+    return get_model(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)[0]
+
 with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     user_name = gr.State("")
     promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
     user_question = gr.State("")
+    user_api_key = gr.State(my_api_key)
+    current_model = gr.State(create_new_model)
+
     topic = gr.State("未命名对话历史记录")

     with gr.Row():
+        gr.HTML(CHUANHU_TITLE, elem_id="app_title")
         status_display = gr.Markdown(get_geoip(), elem_id="status_display")
+    with gr.Row(elem_id="float_display"):
+        user_info = gr.Markdown(value="getting user info...", elem_id="user_info")

     # https://github.com/gradio-app/gradio/pull/3296
     def create_greeting(request: gr.Request):

     with gr.Row():
         chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
     with gr.Row():
+        with gr.Column(min_width=225, scale=12):
             user_input = gr.Textbox(
                 elem_id="user_input_tb",
                 show_label=False, placeholder="在这里输入"
             ).style(container=False)
+        with gr.Column(min_width=42, scale=1):
+            submitBtn = gr.Button(value="", variant="primary", elem_id="submit_btn")
+            cancelBtn = gr.Button(value="", variant="secondary", visible=False, elem_id="cancel_btn")
     with gr.Row():
         emptyBtn = gr.Button(
             "🧹 新的对话",

         retryBtn = gr.Button("🔄 重新生成")
         delFirstBtn = gr.Button("🗑️ 删除最旧对话")
         delLastBtn = gr.Button("🗑️ 删除最新对话")

     with gr.Column():
         with gr.Column(min_width=50, scale=1):
+            with gr.Tab(label="模型"):
                 keyTxt = gr.Textbox(
                     show_label=True,
                     placeholder=f"OpenAI API-key...",
+                    value=hide_middle_chars(user_api_key.value),
                     type="password",
                     visible=not HIDE_MY_KEY,
                     label="API-Key",
                 )
                 if multi_api_key:
+                    usageTxt = gr.Markdown("多账号模式已开启,无需输入key,可直接开始对话", elem_id="usage_display", elem_classes="insert_block")
                 else:
+                    usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display", elem_classes="insert_block")
                 model_select_dropdown = gr.Dropdown(
+                    label="选择模型", choices=MODELS, multiselect=False, value=MODELS[DEFAULT_MODEL], interactive=True
                 )
+                lora_select_dropdown = gr.Dropdown(
+                    label="选择LoRA模型", choices=[], multiselect=False, interactive=True, visible=False
                 )
+                with gr.Row():
+                    use_streaming_checkbox = gr.Checkbox(
+                        label="实时传输回答", value=True, visible=ENABLE_STREAMING_OPTION
+                    )
+                    single_turn_checkbox = gr.Checkbox(label="单轮对话", value=False)
+                    use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
                 language_select_dropdown = gr.Dropdown(
                     label="选择回复语言(针对搜索&索引功能)",
                     choices=REPLY_LANGUAGES,
                     multiselect=False,
                     value=REPLY_LANGUAGES[0],
                 )
+                index_files = gr.Files(label="上传索引文件", type="file")
                 two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
                 # TODO: 公式ocr
                 # formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))

                     show_label=True,
                     placeholder=f"在这里输入System Prompt...",
                     label="System prompt",
+                    value=INITIAL_SYSTEM_PROMPT,
                     lines=10,
                 ).style(container=False)
                 with gr.Accordion(label="加载Prompt模板", open=True):

             with gr.Tab(label="高级"):
                 gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
+                gr.HTML(APPEARANCE_SWITCHER, elem_classes="insert_block")
                 with gr.Accordion("参数", open=False):
+                    temperature_slider = gr.Slider(
+                        minimum=-0,
+                        maximum=2.0,
+                        value=1.0,
+                        step=0.1,
+                        interactive=True,
+                        label="temperature",
+                    )
+                    top_p_slider = gr.Slider(
                         minimum=-0,
                         maximum=1.0,
                         value=1.0,
                         step=0.05,
                         interactive=True,
+                        label="top-p",
                     )
+                    n_choices_slider = gr.Slider(
+                        minimum=1,
+                        maximum=10,
+                        value=1,
+                        step=1,
+                        interactive=True,
+                        label="n choices",
+                    )
+                    stop_sequence_txt = gr.Textbox(
+                        show_label=True,
+                        placeholder=f"在这里输入停止符,用英文逗号隔开...",
+                        label="stop",
+                        value="",
+                        lines=1,
+                    )
+                    max_context_length_slider = gr.Slider(
+                        minimum=1,
+                        maximum=32768,
+                        value=2000,
+                        step=1,
+                        interactive=True,
+                        label="max context",
+                    )
+                    max_generation_slider = gr.Slider(
+                        minimum=1,
+                        maximum=32768,
+                        value=1000,
+                        step=1,
+                        interactive=True,
+                        label="max generations",
+                    )
+                    presence_penalty_slider = gr.Slider(
+                        minimum=-2.0,
                         maximum=2.0,
+                        value=0.0,
+                        step=0.01,
                         interactive=True,
+                        label="presence penalty",
+                    )
+                    frequency_penalty_slider = gr.Slider(
+                        minimum=-2.0,
+                        maximum=2.0,
+                        value=0.0,
+                        step=0.01,
+                        interactive=True,
+                        label="frequency penalty",
+                    )
+                    logit_bias_txt = gr.Textbox(
+                        show_label=True,
+                        placeholder=f"word:likelihood",
+                        label="logit bias",
+                        value="",
+                        lines=1,
+                    )
+                    user_identifier_txt = gr.Textbox(
+                        show_label=True,
+                        placeholder=f"用于定位滥用行为",
+                        label="用户名",
+                        value=user_name.value,
+                        lines=1,
                     )

+                with gr.Accordion("网络设置", open=False):
                     # 优先展示自定义的api_host
                     apihostTxt = gr.Textbox(
                         show_label=True,

                         lines=2,
                     )
                     changeProxyBtn = gr.Button("🔄 设置代理地址")
+                    default_btn = gr.Button("🔙 恢复默认设置")

+    gr.Markdown(CHUANHU_DESCRIPTION)
+    gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
     chatgpt_predict_args = dict(
         fn=predict,
         inputs=[
+            current_model,
             user_question,
             chatbot,
             use_streaming_checkbox,
             use_websearch_checkbox,
             index_files,
             language_select_dropdown,
         ],
+        outputs=[chatbot, status_display],
         show_progress=True,
     )

     )

     get_usage_args = dict(
+        fn=billing_info, inputs=[current_model], outputs=[usageTxt], show_progress=False
+    )
+
+    load_history_from_file_args = dict(
+        fn=load_chat_history,
+        inputs=[current_model, historyFileSelectDropdown, chatbot, user_name],
+        outputs=[saveFileName, systemPromptTxt, chatbot]
     )


     # Chatbot
+    cancelBtn.click(interrupt, [current_model], [])

     user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
     user_input.submit(**get_usage_args)

     submitBtn.click(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
     submitBtn.click(**get_usage_args)

+    index_files.change(handle_file_upload, [current_model, index_files, chatbot], [index_files, chatbot, status_display])
+
     emptyBtn.click(
+        reset,
+        inputs=[current_model],
+        outputs=[chatbot, status_display],
         show_progress=True,
     )
     emptyBtn.click(**reset_textbox_args)

     retryBtn.click(**start_outputing_args).then(
         retry,
         [
+            current_model,
             chatbot,
             use_streaming_checkbox,
+            use_websearch_checkbox,
+            index_files,
             language_select_dropdown,
         ],
+        [chatbot, status_display],
         show_progress=True,
     ).then(**end_outputing_args)
     retryBtn.click(**get_usage_args)

     delFirstBtn.click(
         delete_first_conversation,
+        [current_model],
+        [status_display],
     )

     delLastBtn.click(
         delete_last_conversation,
+        [current_model, chatbot],
+        [chatbot, status_display],
+        show_progress=False
     )

     two_column.change(update_doc_config, [two_column], None)

+    # LLM Models
+    keyTxt.change(set_key, [current_model, keyTxt], [user_api_key, status_display]).then(**get_usage_args)
     keyTxt.submit(**get_usage_args)
+    single_turn_checkbox.change(set_single_turn, [current_model, single_turn_checkbox], None)
+    model_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, user_api_key, temperature_slider, top_p_slider, systemPromptTxt], [current_model, status_display, lora_select_dropdown], show_progress=True)
+    lora_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, user_api_key, temperature_slider, top_p_slider, systemPromptTxt], [current_model, status_display], show_progress=True)

     # Template
+    systemPromptTxt.change(set_system_prompt, [current_model, systemPromptTxt], None)
     templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
     templateFileSelectDropdown.change(
         load_template,

     # S&L
     saveHistoryBtn.click(
         save_chat_history,
+        [current_model, saveFileName, chatbot, user_name],
         downloadFile,
         show_progress=True,
     )
     saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
     exportMarkdownBtn.click(
         export_markdown,
+        [current_model, saveFileName, chatbot, user_name],
         downloadFile,
         show_progress=True,
     )
     historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
+    historyFileSelectDropdown.change(**load_history_from_file_args)
+    downloadFile.change(**load_history_from_file_args)

     # Advanced
+    max_context_length_slider.change(set_token_upper_limit, [current_model, max_context_length_slider], None)
+    temperature_slider.change(set_temperature, [current_model, temperature_slider], None)
+    top_p_slider.change(set_top_p, [current_model, top_p_slider], None)
+    n_choices_slider.change(set_n_choices, [current_model, n_choices_slider], None)
+    stop_sequence_txt.change(set_stop_sequence, [current_model, stop_sequence_txt], None)
+    max_generation_slider.change(set_max_tokens, [current_model, max_generation_slider], None)
+    presence_penalty_slider.change(set_presence_penalty, [current_model, presence_penalty_slider], None)
+    frequency_penalty_slider.change(set_frequency_penalty, [current_model, frequency_penalty_slider], None)
+    logit_bias_txt.change(set_logit_bias, [current_model, logit_bias_txt], None)
+    user_identifier_txt.change(set_user_identifier, [current_model, user_identifier_txt], None)
+
     default_btn.click(
         reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
     )

 if __name__ == "__main__":
     reload_javascript()
+    demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
+        server_name=server_name,
+        server_port=server_port,
+        share=share,
+        auth=auth_list if authflag else None,
+        favicon_path="./assets/favicon.ico",
+        inbrowser=not dockerflag,  # 禁止在docker下开启inbrowser
+    )
+    # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
+    # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
+    # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
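The structural change in this file: per-conversation state that previously lived in separate `gr.State` slots (`history`, `token_count`, `top_p`, `temperature`, ...) now lives in one model object held in `current_model = gr.State(create_new_model)`, and every control is wired to a small setter such as `set_temperature` or `set_top_p`. A stripped-down sketch of that pattern (`DummyModel` and its setter are illustrative stand-ins, not the code behind `modules.models.get_model`):

```python
# Illustrative sketch of the "model object in gr.State" pattern used above.
# DummyModel is a stand-in; the real classes come from modules.models.
import gradio as gr

class DummyModel:
    def __init__(self):
        self.temperature = 1.0
        self.history = []

def create_new_model():
    return DummyModel()  # gr.State calls this once per browser session

def set_temperature(model, value):
    model.temperature = value  # mutates this session's model in place

with gr.Blocks() as demo:
    current_model = gr.State(create_new_model)
    temperature_slider = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="temperature")
    temperature_slider.change(set_temperature, [current_model, temperature_slider], None)

if __name__ == "__main__":
    demo.launch()
```

Because the object lives in `gr.State`, each session gets its own instance, so histories and sampling parameters stay independent per user without threading a dozen values through every event handler.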
Dockerfile CHANGED
@@ -1,7 +1,9 @@
 FROM python:3.9 as builder
 RUN apt-get update && apt-get install -y build-essential
 COPY requirements.txt .
 RUN pip install --user -r requirements.txt

 FROM python:3.9
 MAINTAINER iskoldt
@@ -9,6 +11,5 @@ COPY --from=builder /root/.local /root/.local
 ENV PATH=/root/.local/bin:$PATH
 COPY . /app
 WORKDIR /app
-ENV my_api_key empty
 ENV dockerrun yes
 CMD ["python3", "-u", "ChuanhuChatbot.py", "2>&1", "|", "tee", "/var/log/application.log"]
 FROM python:3.9 as builder
 RUN apt-get update && apt-get install -y build-essential
 COPY requirements.txt .
+COPY requirements_advanced.txt .
 RUN pip install --user -r requirements.txt
+# RUN pip install --user -r requirements_advanced.txt

 FROM python:3.9
 MAINTAINER iskoldt
 ENV PATH=/root/.local/bin:$PATH
 COPY . /app
 WORKDIR /app
 ENV dockerrun yes
 CMD ["python3", "-u", "ChuanhuChatbot.py", "2>&1", "|", "tee", "/var/log/application.log"]
README.md CHANGED
@@ -1,13 +1,105 @@
----
-title: ChuanhuChatGPT
-emoji: 🐯
-colorFrom: green
-colorTo: red
-sdk: gradio
-sdk_version: 3.24.1
-app_file: ChuanhuChatbot.py
-pinned: false
-license: gpl-3.0
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+<h1 align="center">川虎 Chat 🐯 Chuanhu Chat</h1>
+<div align="center">
+  <a href="https://github.com/GaiZhenBiao/ChuanhuChatGPT">
+    <img src="https://user-images.githubusercontent.com/70903329/227087087-93b37d64-7dc3-4738-a518-c1cf05591c8a.png" alt="Logo" height="156">
+  </a>
+
+<p align="center">
+  <h3>为ChatGPT/ChatGLM/LLaMA等多种LLM提供了一个轻快好用的Web图形界面</h3>
+  <p align="center">
+    <a href="https://github.com/GaiZhenbiao/ChuanhuChatGPT/blob/main/LICENSE">
+      <img alt="Tests Passing" src="https://img.shields.io/github/license/GaiZhenbiao/ChuanhuChatGPT" />
+    </a>
+    <a href="https://gradio.app/">
+      <img alt="GitHub Contributors" src="https://img.shields.io/badge/Base-Gradio-fb7d1a?style=flat" />
+    </a>
+    <a href="https://t.me/tkdifferent">
+      <img alt="GitHub pull requests" src="https://img.shields.io/badge/Telegram-Group-blue.svg?logo=telegram" />
+    </a>
+    <p>
+      实时回复 / 无限对话 / 保存对话 / 预设Prompt集 / 联网搜索 / 根据文件回答 <br />
+      渲染LaTeX / 渲染表格 / 代码高亮 / 自动亮暗色切换 / 自适应界面 / “小而美”的体验 <br />
+      自定义api-Host / 多参数可调 / 多API Key均衡负载 / 多用户显示 / 适配GPT-4 / 支持本地部署LLM
+    </p>
+    <a href="https://www.bilibili.com/video/BV1mo4y1r7eE"><strong>视频教程</strong></a>
+    ·
+    <a href="https://www.bilibili.com/video/BV1184y1w7aP"><strong>2.0介绍视频</strong></a>
+    ||
+    <a href="https://huggingface.co/spaces/JohnSmith9982/ChuanhuChatGPT"><strong>在线体验</strong></a>
+    ·
+    <a href="https://huggingface.co/login?next=%2Fspaces%2FJohnSmith9982%2FChuanhuChatGPT%3Fduplicate%3Dtrue"><strong>一键部署</strong></a>
+  </p>
+  <p align="center">
+    <img alt="Animation Demo" src="https://user-images.githubusercontent.com/51039745/226255695-6b17ff1f-ea8d-464f-b69b-a7b6b68fffe8.gif" />
+  </p>
+</p>
+</div>
+
+## 目录
+|[使用技巧](#使用技巧)|[安装方式](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程)|[常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)| [给作者买可乐🥤](#捐款) |
+| ---- | ---- | ---- | --- |
+
+## 使用技巧
+
+- 使用System Prompt可以很有效地设定前提条件。
+- 使用Prompt模板功能时,选择Prompt模板集合文件,然后从下拉菜单中选择想要的prompt。
+- 如果回答不满意,可以使用`重新生成`按钮再试一次
+- 对于长对话,可以使用`优化Tokens`按钮减少Tokens占用。
+- 输入框支持换行,按`shift enter`即可。
+- 可以在输入框按上下箭头在输入历史之间切换
+- 部署到服务器:将程序最后一句改成`demo.launch(server_name="0.0.0.0", server_port=<你的端口号>)`。
+- 获取公共链接:将程序最后一句改成`demo.launch(share=True)`。注意程序必须在运行,才能通过公共链接访问。
+- 在Hugging Face上使用:建议在右上角 **复制Space** 再使用,这样App反应可能会快一点。
+
+## 安装方式、使用方式
+
+请查看[本项目的wiki页面](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程)。
+
+## 疑难杂症解决
+
+在遇到各种问题查阅相关信息前,您可以先尝试手动拉取本项目的最新更改并更新 gradio,然后重试。步骤为:
+
+1. 点击网页上的 `Download ZIP` 下载最新代码,或
+   ```shell
+   git pull https://github.com/GaiZhenbiao/ChuanhuChatGPT.git main -f
+   ```
+2. 尝试再次安装依赖(可能本项目引入了新的依赖)
+   ```
+   pip install -r requirements.txt
+   ```
+3. 更新gradio
+   ```
+   pip install gradio --upgrade --force-reinstall
+   ```
+
+很多时候,这样就可以解决问题。
+
+如果问题仍然存在,请查阅该页面:[常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)
+
+该页面列出了**几乎所有**您可能遇到的各种问题,包括如何配置代理,以及遇到问题后您该采取的措施,**请务必认真阅读**。
+
+## 了解更多
+
+若需了解更多信息,请查看我们的 [wiki](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki):
+
+- [想要做出贡献?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/贡献指南)
+- [项目更新情况?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/更新日志)
+- [二次开发许可?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用许可)
+- [如何引用项目?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用许可#如何引用该项目)
+
+## Starchart
+
+[![Star History Chart](https://api.star-history.com/svg?repos=GaiZhenbiao/ChuanhuChatGPT&type=Date)](https://star-history.com/#GaiZhenbiao/ChuanhuChatGPT&Date)
+
+## Contributors
+
+<a href="https://github.com/GaiZhenbiao/ChuanhuChatGPT/graphs/contributors">
+  <img src="https://contrib.rocks/image?repo=GaiZhenbiao/ChuanhuChatGPT" />
+</a>
+
+## 捐款
+
+🐯如果觉得这个软件对你有所帮助,欢迎请作者喝可乐、喝咖啡~
+
+<img width="250" alt="image" src="https://user-images.githubusercontent.com/51039745/226920291-e8ec0b0a-400f-4c20-ac13-dafac0c3aeeb.JPG">
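The deployment tips in this README come down to which keyword arguments reach `demo.launch()`. A self-contained sketch of the two variants described above (the port value is a placeholder):

```python
# Sketch of the launch variants the README tips describe.
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")

# Server deployment: listen on all interfaces at a chosen port.
demo.launch(server_name="0.0.0.0", server_port=7860)
# Public link instead: demo.launch(share=True); the generated gradio.live URL
# only works while this process keeps running.
```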
assets/custom.css CHANGED
@@ -3,14 +3,18 @@
     --chatbot-color-dark: #121111;
 }

 /* 覆盖gradio的页脚信息QAQ */
 footer {
     display: none !important;
 }
-#footer{
     text-align: center;
 }
-#footer div{
     display: inline-block;
 }
 #footer .versions{
@@ -18,16 +22,34 @@ footer {
     opacity: 0.85;
 }

-/* user_info */
 #user_info {
     white-space: nowrap;
-    margin-top: -1.3em !important;
-    padding-left: 112px !important;
 }
 #user_info p {
-    font-size: .85em;
-    font-family: monospace;
-    color: var(--body-text-color-subdued);
 }

 /* status_display */
@@ -43,14 +65,18 @@ footer {
     color: var(--body-text-color-subdued);
 }

-#chuanhu_chatbot, #status_display {
     transition: all 0.6s;
 }

 /* usage_display */
-#usage_display {
     position: relative;
     margin: 0;
     box-shadow: var(--block-shadow);
     border-width: var(--block-border-width);
     border-color: var(--block-border-color);
62
  }
63
  #usage_display p, #usage_display span {
64
  margin: 0;
65
- padding: .5em 1em;
66
  font-size: .85em;
67
  color: var(--body-text-color-subdued);
68
  }
@@ -74,7 +99,7 @@ footer {
     overflow: hidden;
 }
 .progress {
-    background-color: var(--block-title-background-fill);;
     height: 100%;
     border-radius: 10px;
     text-align: right;
@@ -88,38 +113,107 @@ footer {
     padding-right: 10px;
     line-height: 20px;
 }
 /* list */
 ol:not(.options), ul:not(.options) {
     padding-inline-start: 2em !important;
 }

-/* 亮色 */
-@media (prefers-color-scheme: light) {
     #chuanhu_chatbot {
-        background-color: var(--chatbot-color-light) !important;
-        color: #000000 !important;
     }
-    [data-testid = "bot"] {
-        background-color: #FFFFFF !important;
-    }
-    [data-testid = "user"] {
-        background-color: #95EC69 !important;
     }
 }
-/* 暗色 */
-@media (prefers-color-scheme: dark) {
     #chuanhu_chatbot {
-        background-color: var(--chatbot-color-dark) !important;
-        color: #FFFFFF !important;
     }
-    [data-testid = "bot"] {
-        background-color: #2C2C2C !important;
     }
-    [data-testid = "user"] {
-        background-color: #26B561 !important;
     }
-    body {
-        background-color: var(--neutral-950) !important;
     }
 }
 /* 对话气泡 */
     --chatbot-color-dark: #121111;
 }

+#app_title {
+    margin-top: 6px;
+    white-space: nowrap;
+}
 /* 覆盖gradio的页脚信息QAQ */
 footer {
     display: none !important;
 }
+#footer {
     text-align: center;
 }
+#footer div {
     display: inline-block;
 }
 #footer .versions{

     opacity: 0.85;
 }

+#float_display {
+    position: absolute;
+    max-height: 30px;
+}
+/* user_info */
 #user_info {
     white-space: nowrap;
+    position: absolute; left: 8em; top: .2em;
+    z-index: var(--layer-2);
+    box-shadow: var(--block-shadow);
+    border: none; border-radius: var(--block-label-radius);
+    background: var(--color-accent);
+    padding: var(--block-label-padding);
+    font-size: var(--block-label-text-size); line-height: var(--line-sm);
+    width: auto; min-height: 30px!important;
+    opacity: 1;
+    transition: opacity 0.3s ease-in-out;
+}
+#user_info .wrap {
+    opacity: 0;
 }
 #user_info p {
+    color: white;
+    font-weight: var(--block-label-text-weight);
+}
+#user_info.hideK {
+    opacity: 0;
+    transition: opacity 1s ease-in-out;
 }

 /* status_display */

     color: var(--body-text-color-subdued);
 }

+#status_display {
     transition: all 0.6s;
 }
+#chuanhu_chatbot {
+    transition: height 0.3s ease;
+}

 /* usage_display */
+.insert_block {
     position: relative;
     margin: 0;
+    padding: .5em 1em;
     box-shadow: var(--block-shadow);
     border-width: var(--block-border-width);
     border-color: var(--block-border-color);

 }
 #usage_display p, #usage_display span {
     margin: 0;
     font-size: .85em;
     color: var(--body-text-color-subdued);
 }

     overflow: hidden;
 }
 .progress {
+    background-color: var(--block-title-background-fill);
     height: 100%;
     border-radius: 10px;
     text-align: right;

     padding-right: 10px;
     line-height: 20px;
 }
+
+.apSwitch {
+    top: 2px;
+    display: inline-block;
+    height: 24px;
+    position: relative;
+    width: 48px;
+    border-radius: 12px;
+}
+.apSwitch input {
+    display: none !important;
+}
+.apSlider {
+    background-color: var(--block-label-background-fill);
+    bottom: 0;
+    cursor: pointer;
+    left: 0;
+    position: absolute;
+    right: 0;
+    top: 0;
+    transition: .4s;
+    font-size: 18px;
+    border-radius: 12px;
+}
+.apSlider::before {
+    bottom: -1.5px;
+    left: 1px;
+    position: absolute;
+    transition: .4s;
+    content: "🌞";
+}
+input:checked + .apSlider {
+    background-color: var(--block-label-background-fill);
+}
+input:checked + .apSlider::before {
+    transform: translateX(23px);
+    content:"🌚";
+}
+
+#submit_btn, #cancel_btn {
+    height: 42px !important;
+}
+#submit_btn::before {
+    content: url("data:image/svg+xml, %3Csvg width='21px' height='20px' viewBox='0 0 21 20' version='1.1' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink'%3E %3Cg id='page' stroke='none' stroke-width='1' fill='none' fill-rule='evenodd'%3E %3Cg id='send' transform='translate(0.435849, 0.088463)' fill='%23FFFFFF' fill-rule='nonzero'%3E %3Cpath d='M0.579148261,0.0428666046 C0.301105539,-0.0961547561 -0.036517765,0.122307382 0.0032026237,0.420210298 L1.4927172,18.1553639 C1.5125774,18.4334066 1.79062012,18.5922882 2.04880264,18.4929872 L8.24518329,15.8913017 L11.6412765,19.7441794 C11.8597387,19.9825018 12.2370824,19.8832008 12.3165231,19.5852979 L13.9450591,13.4882182 L19.7839562,11.0255541 C20.0619989,10.8865327 20.0818591,10.4694687 19.7839562,10.3105871 L0.579148261,0.0428666046 Z M11.6138902,17.0883151 L9.85385903,14.7195502 L0.718169621,0.618812241 L12.69945,12.9346347 L11.6138902,17.0883151 Z' id='shape'%3E%3C/path%3E %3C/g%3E %3C/g%3E %3C/svg%3E");
+    height: 21px;
+}
+#cancel_btn::before {
+    content: url("data:image/svg+xml,%3Csvg width='21px' height='21px' viewBox='0 0 21 21' version='1.1' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink'%3E %3Cg id='pg' stroke='none' stroke-width='1' fill='none' fill-rule='evenodd'%3E %3Cpath d='M10.2072007,20.088463 C11.5727865,20.088463 12.8594566,19.8259823 14.067211,19.3010209 C15.2749653,18.7760595 16.3386126,18.0538087 17.2581528,17.1342685 C18.177693,16.2147282 18.8982283,15.1527965 19.4197586,13.9484733 C19.9412889,12.7441501 20.202054,11.4557644 20.202054,10.0833163 C20.202054,8.71773046 19.9395733,7.43106036 19.4146119,6.22330603 C18.8896505,5.01555169 18.1673997,3.95018885 17.2478595,3.0272175 C16.3283192,2.10424615 15.2646719,1.3837109 14.0569176,0.865611739 C12.8491633,0.34751258 11.5624932,0.088463 10.1969073,0.088463 C8.83132146,0.088463 7.54636692,0.34751258 6.34204371,0.865611739 C5.1377205,1.3837109 4.07407321,2.10424615 3.15110186,3.0272175 C2.22813051,3.95018885 1.5058797,5.01555169 0.984349419,6.22330603 C0.46281914,7.43106036 0.202054,8.71773046 0.202054,10.0833163 C0.202054,11.4557644 0.4645347,12.7441501 0.9894961,13.9484733 C1.5144575,15.1527965 2.23670831,16.2147282 3.15624854,17.1342685 C4.07578877,18.0538087 5.1377205,18.7760595 6.34204371,19.3010209 C7.54636692,19.8259823 8.83475258,20.088463 10.2072007,20.088463 Z M10.2072007,18.2562448 C9.07493099,18.2562448 8.01471483,18.0452309 7.0265522,17.6232031 C6.03838956,17.2011753 5.17031614,16.6161693 4.42233192,15.8681851 C3.6743477,15.1202009 3.09105726,14.2521274 2.67246059,13.2639648 C2.25386392,12.2758022 2.04456558,11.215586 2.04456558,10.0833163 C2.04456558,8.95104663 2.25386392,7.89083047 2.67246059,6.90266784 C3.09105726,5.9145052 3.6743477,5.04643178 4.42233192,4.29844756 C5.17031614,3.55046334 6.036674,2.9671729 7.02140552,2.54857623 C8.00613703,2.12997956 9.06463763,1.92068122 10.1969073,1.92068122 C11.329177,1.92068122 12.3911087,2.12997956 13.3827025,2.54857623 C14.3742962,2.9671729 15.2440852,3.55046334 15.9920694,4.29844756 C16.7400537,5.04643178 17.3233441,5.9145052 17.7419408,6.90266784 C18.1605374,7.89083047 18.3698358,8.95104663 18.3698358,10.0833163 C18.3698358,11.215586 18.1605374,12.2758022 17.7419408,13.2639648 C17.3233441,14.2521274 16.7400537,15.1202009 15.9920694,15.8681851 C15.2440852,16.6161693 14.3760118,17.2011753 13.3878492,17.6232031 C12.3996865,18.0452309 11.3394704,18.2562448 10.2072007,18.2562448 Z M7.65444721,13.6242324 L12.7496608,13.6242324 C13.0584616,13.6242324 13.3003556,13.5384544 13.4753427,13.3668984 C13.6503299,13.1953424 13.7378234,12.9585951 13.7378234,12.6566565 L13.7378234,7.49968276 C13.7378234,7.19774418 13.6503299,6.96099688 13.4753427,6.78944087 C13.3003556,6.61788486 13.0584616,6.53210685 12.7496608,6.53210685 L7.65444721,6.53210685 C7.33878414,6.53210685 7.09345904,6.61788486 6.91847191,6.78944087 C6.74348478,6.96099688 6.65599121,7.19774418 6.65599121,7.49968276 L6.65599121,12.6566565 C6.65599121,12.9585951 6.74348478,13.1953424 6.91847191,13.3668984 C7.09345904,13.5384544 7.33878414,13.6242324 7.65444721,13.6242324 Z' id='shape' fill='%23FF3B30' fill-rule='nonzero'%3E%3C/path%3E %3C/g%3E %3C/svg%3E");
+    height: 21px;
+}
 /* list */
 ol:not(.options), ul:not(.options) {
     padding-inline-start: 2em !important;
 }

+/* 亮色(默认) */
+#chuanhu_chatbot {
+    background-color: var(--chatbot-color-light) !important;
+    color: #000000 !important;
+}
+[data-testid = "bot"] {
+    background-color: #FFFFFF !important;
+}
+[data-testid = "user"] {
+    background-color: #95EC69 !important;
+}
+/* 暗色 */
+.dark #chuanhu_chatbot {
+    background-color: var(--chatbot-color-dark) !important;
+    color: #FFFFFF !important;
+}
+.dark [data-testid = "bot"] {
+    background-color: #2C2C2C !important;
+}
+.dark [data-testid = "user"] {
+    background-color: #26B561 !important;
+}
+
+/* 屏幕宽度大于等于500px的设备 */
+/* update on 2023.4.8: 高度的细致调整已写入JavaScript */
+@media screen and (min-width: 500px) {
     #chuanhu_chatbot {
+        height: calc(100vh - 200px);
     }
+    #chuanhu_chatbot .wrap {
+        max-height: calc(100vh - 200px - var(--line-sm)*1rem - 2*var(--block-label-margin) );
     }
 }
+/* 屏幕宽度小于500px的设备 */
+@media screen and (max-width: 499px) {
     #chuanhu_chatbot {
+        height: calc(100vh - 140px);
     }
+    #chuanhu_chatbot .wrap {
+        max-height: calc(100vh - 140px - var(--line-sm)*1rem - 2*var(--block-label-margin) );
     }
+    [data-testid = "bot"] {
+        max-width: 98% !important;
     }
+    #app_title h1{
+        letter-spacing: -1px; font-size: 22px;
     }
 }
 /* 对话气泡 */
assets/custom.js CHANGED
@@ -1,70 +1,224 @@
 // custom javascript here
 const MAX_HISTORY_LENGTH = 32;

 var key_down_history = [];
 var currentIndex = -1;
 var user_input_ta;

 var ga = document.getElementsByTagName("gradio-app");
 var targetNode = ga[0];
-var observer = new MutationObserver(function(mutations) {
     for (var i = 0; i < mutations.length; i++) {
-        if (mutations[i].addedNodes.length) {
-            var user_input_tb = document.getElementById('user_input_tb');
-            if (user_input_tb) {
-                // 监听到user_input_tb被添加到DOM树中
-                // 这里可以编写元素加载完成后需要执行的代码
-                user_input_ta = user_input_tb.querySelector("textarea");
-                if (user_input_ta){
-                    observer.disconnect(); // 停止监听
-                    // textarea 上监听 keydown 事件
-                    user_input_ta.addEventListener("keydown", function (event) {
-                        var value = user_input_ta.value.trim();
-                        // 判断按下的是否为方向键
-                        if (event.code === 'ArrowUp' || event.code === 'ArrowDown') {
-                            // 如果按下的是方向键,且输入框中有内容,且历史记录中没有该内容,则不执行操作
-                            if(value && key_down_history.indexOf(value) === -1)
-                                return;
-                            // 对于需要响应的动作,阻止默认行为。
-                            event.preventDefault();
-                            var length = key_down_history.length;
-                            if(length === 0) {
-                                currentIndex = -1; // 如果历史记录为空,直接将当前选中的记录重置
-                                return;
-                            }
-                            if (currentIndex === -1) {
-                                currentIndex = length;
-                            }
-                            if (event.code === 'ArrowUp' && currentIndex > 0) {
-                                currentIndex--;
-                                user_input_ta.value = key_down_history[currentIndex];
-                            } else if (event.code === 'ArrowDown' && currentIndex < length - 1) {
-                                currentIndex++;
-                                user_input_ta.value = key_down_history[currentIndex];
-                            }
-                            user_input_ta.selectionStart = user_input_ta.value.length;
-                            user_input_ta.selectionEnd = user_input_ta.value.length;
-                            const input_event = new InputEvent("input", {bubbles: true, cancelable: true});
-                            user_input_ta.dispatchEvent(input_event);
-                        }else if(event.code === "Enter") {
-                            if (value) {
-                                currentIndex = -1;
-                                if(key_down_history.indexOf(value) === -1){
-                                    key_down_history.push(value);
-                                    if (key_down_history.length > MAX_HISTORY_LENGTH) {
-                                        key_down_history.shift();
-                                    }
-                                }
                             }
                         }
-                    });
-                    break;
                 }
-            }
         }
-    }
-});

-// 监听目标节点的子节点列表是否发生变化
-observer.observe(targetNode, { childList: true , subtree: true });

+
 // custom javascript here
+
 const MAX_HISTORY_LENGTH = 32;

 var key_down_history = [];
 var currentIndex = -1;
 var user_input_ta;

+var gradioContainer = null;
+var user_input_ta = null;
+var user_input_tb = null;
+var userInfoDiv = null;
+var appTitleDiv = null;
+var chatbot = null;
+var apSwitch = null;
+
 var ga = document.getElementsByTagName("gradio-app");
 var targetNode = ga[0];
+var isInIframe = (window.self !== window.top);
+
+// gradio 页面加载好了么??? 我能动你的元素了么??
+function gradioLoaded(mutations) {
     for (var i = 0; i < mutations.length; i++) {
+        if (mutations[i].addedNodes.length) {
+            gradioContainer = document.querySelector(".gradio-container");
+            user_input_tb = document.getElementById('user_input_tb');
+            userInfoDiv = document.getElementById("user_info");
+            appTitleDiv = document.getElementById("app_title");
+            chatbot = document.querySelector('#chuanhu_chatbot');
+            apSwitch = document.querySelector('.apSwitch input[type="checkbox"]');
+
+            if (gradioContainer && apSwitch) { // gradioCainter 加载出来了没?
+                adjustDarkMode();
+            }
+            if (user_input_tb) { // user_input_tb 加载出来了没?
+                selectHistory();
+            }
+            if (userInfoDiv && appTitleDiv) { // userInfoDiv 和 appTitleDiv 加载出来了没?
+                setTimeout(showOrHideUserInfo(), 2000);
+            }
+            if (chatbot) { // chatbot 加载出来了没?
+                setChatbotHeight()
+            }
+        }
+    }
+}
+
+function selectHistory() {
+    user_input_ta = user_input_tb.querySelector("textarea");
+    if (user_input_ta) {
+        observer.disconnect(); // 停止监听
+        // textarea 上监听 keydown 事件
+        user_input_ta.addEventListener("keydown", function (event) {
+            var value = user_input_ta.value.trim();
+            // 判断按下的是否为方向键
+            if (event.code === 'ArrowUp' || event.code === 'ArrowDown') {
+                // 如果按下的是方向键,且输入框中有内容,且历史记录中没有该内容,则不执行操作
+                if (value && key_down_history.indexOf(value) === -1)
+                    return;
+                // 对于需要响应的动作,阻止默认行为。
+                event.preventDefault();
+                var length = key_down_history.length;
+                if (length === 0) {
+                    currentIndex = -1; // 如果历史记录为空,直接将当前选中的记录重置
+                    return;
+                }
+                if (currentIndex === -1) {
+                    currentIndex = length;
+                }
+                if (event.code === 'ArrowUp' && currentIndex > 0) {
+                    currentIndex--;
+                    user_input_ta.value = key_down_history[currentIndex];
+                } else if (event.code === 'ArrowDown' && currentIndex < length - 1) {
+                    currentIndex++;
+                    user_input_ta.value = key_down_history[currentIndex];
+                }
+                user_input_ta.selectionStart = user_input_ta.value.length;
+                user_input_ta.selectionEnd = user_input_ta.value.length;
+                const input_event = new InputEvent("input", { bubbles: true, cancelable: true });
+                user_input_ta.dispatchEvent(input_event);
+            } else if (event.code === "Enter") {
+                if (value) {
+                    currentIndex = -1;
+                    if (key_down_history.indexOf(value) === -1) {
+                        key_down_history.push(value);
+                        if (key_down_history.length > MAX_HISTORY_LENGTH) {
+                            key_down_history.shift();
                         }
                     }
+                }
             }
+        });
+    }
+}
+
+function toggleUserInfoVisibility(shouldHide) {
+    if (userInfoDiv) {
+        if (shouldHide) {
+            userInfoDiv.classList.add("hideK");
+        } else {
+            userInfoDiv.classList.remove("hideK");
         }
+    }
+}
+function showOrHideUserInfo() {
+    var sendBtn = document.getElementById("submit_btn");
+
+    // Bind mouse/touch events to show/hide user info
+    appTitleDiv.addEventListener("mouseenter", function () {
+        toggleUserInfoVisibility(false);
+    });
+    userInfoDiv.addEventListener("mouseenter", function () {
+        toggleUserInfoVisibility(false);
+    });
+    sendBtn.addEventListener("mouseenter", function () {
+        toggleUserInfoVisibility(false);
+    });
+
+    appTitleDiv.addEventListener("mouseleave", function () {
+        toggleUserInfoVisibility(true);
+    });
+    userInfoDiv.addEventListener("mouseleave", function () {
+        toggleUserInfoVisibility(true);
+    });
+    sendBtn.addEventListener("mouseleave", function () {
+        toggleUserInfoVisibility(true);
+    });
+
+    appTitleDiv.ontouchstart = function () {
+        toggleUserInfoVisibility(false);
+    };
+    userInfoDiv.ontouchstart = function () {
+        toggleUserInfoVisibility(false);
+    };
+    sendBtn.ontouchstart = function () {
+        toggleUserInfoVisibility(false);
+    };
+
+    appTitleDiv.ontouchend = function () {
+        setTimeout(function () {
+            toggleUserInfoVisibility(true);
+        }, 3000);
+    };
+    userInfoDiv.ontouchend = function () {
+        setTimeout(function () {
+            toggleUserInfoVisibility(true);
+        }, 3000);
+    };
+    sendBtn.ontouchend = function () {
+        setTimeout(function () {
+            toggleUserInfoVisibility(true);
+        }, 3000); // Delay 1 second to hide user info
+    };
+
+    // Hide user info after 2 second
+    setTimeout(function () {
+        toggleUserInfoVisibility(true);
+    }, 2000);
+}

+function toggleDarkMode(isEnabled) {
+    if (isEnabled) {
+        gradioContainer.classList.add("dark");
+        document.body.style.setProperty("background-color", "var(--neutral-950)", "important");
+    } else {
+        gradioContainer.classList.remove("dark");
+        document.body.style.backgroundColor = "";
+    }
+}
+function adjustDarkMode() {
+    const darkModeQuery = window.matchMedia("(prefers-color-scheme: dark)");

+    // 根据当前颜色模式设置初始状态
+    apSwitch.checked = darkModeQuery.matches;
+    toggleDarkMode(darkModeQuery.matches);
+    // 监听颜色模式变化
+    darkModeQuery.addEventListener("change", (e) => {
+        apSwitch.checked = e.matches;
+        toggleDarkMode(e.matches);
+    });
+    // apSwitch = document.querySelector('.apSwitch input[type="checkbox"]');
+    apSwitch.addEventListener("change", (e) => {
+        toggleDarkMode(e.target.checked);
+    });
+}
+
+function setChatbotHeight() {
+    const screenWidth = window.innerWidth;
+    const statusDisplay = document.querySelector('#status_display');
+    const statusDisplayHeight = statusDisplay ? statusDisplay.offsetHeight : 0;
+    const wrap = chatbot.querySelector('.wrap');
+    const vh = window.innerHeight * 0.01;
+    document.documentElement.style.setProperty('--vh', `${vh}px`);
+    if (isInIframe) {
+        chatbot.style.height = `700px`;
+        wrap.style.maxHeight = `calc(700px - var(--line-sm) * 1rem - 2 * var(--block-label-margin))`
+    } else {
+        if (screenWidth <= 320) {
+            chatbot.style.height = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 150}px)`;
+            wrap.style.maxHeight = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 150}px - var(--line-sm) * 1rem - 2 * var(--block-label-margin))`;
+        } else if (screenWidth <= 499) {
+            chatbot.style.height = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 100}px)`;
+            wrap.style.maxHeight = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 100}px - var(--line-sm) * 1rem - 2 * var(--block-label-margin))`;
+        } else {
+            chatbot.style.height = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 160}px)`;
+            wrap.style.maxHeight = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 160}px - var(--line-sm) * 1rem - 2 * var(--block-label-margin))`;
+        }
+    }
+}
+
+// 监视页面内部 DOM 变动
+var observer = new MutationObserver(function (mutations) {
+    gradioLoaded(mutations);
+});
+observer.observe(targetNode, { childList: true, subtree: true });
+
+// 监视页面变化
+window.addEventListener("DOMContentLoaded", function () {
+    isInIframe = (window.self !== window.top);
+});
+window.addEventListener('resize', setChatbotHeight);
+window.addEventListener('scroll', setChatbotHeight);
+window.matchMedia("(prefers-color-scheme: dark)").addEventListener("change", adjustDarkMode);
config_example.json ADDED
@@ -0,0 +1,31 @@
+{
+    // 你的OpenAI API Key,一般必填,
+    // 若缺省填为 "openai_api_key": "" 则必须再在图形界面中填入API Key
+    "openai_api_key": "",
+    // 如果使用代理,请取消注释下面的两行,并替换代理URL
+    // "https_proxy": "http://127.0.0.1:1079",
+    // "http_proxy": "http://127.0.0.1:1079",
+    "users": [], // 用户列表,[[用户名1, 密码1], [用户名2, 密码2], ...]
+    "local_embedding": false, //是否在本地编制索引
+    "default_model": "gpt-3.5-turbo", // 默认模型
+    "advance_docs": {
+        "pdf": {
+            // 是否认为PDF是双栏的
+            "two_column": false,
+            // 是否使用OCR识别PDF中的公式
+            "formula_ocr": true
+        }
+    },
+    // 是否多个API Key轮换使用
+    "multi_api_key": false,
+    "api_key_list": [
+        "sk-xxxxxxxxxxxxxxxxxxxxxxxx1",
+        "sk-xxxxxxxxxxxxxxxxxxxxxxxx2",
+        "sk-xxxxxxxxxxxxxxxxxxxxxxxx3"
+    ],
+    // 如果使用自定义端口、自定义ip,请取消注释并替换对应内容
+    // "server_name": "0.0.0.0",
+    // "server_port": 7860,
+    // 如果要share到gradio,设置为true
+    // "share": false,
+}
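The // comments above make this file invalid strict JSON, so a plain json.load would reject it; the project pulls in commentjson (imported as cjson in modules/base_model.py below), which accepts such comments. A minimal reading sketch (the runtime filename config.json is an assumption of the sketch):

```python
# Minimal sketch: load a commented config like config_example.json above.
# commentjson accepts the // comments that the standard json module rejects.
# Reading "config.json" at the repo root is an assumption of this sketch.
import commentjson

with open("config.json", "r", encoding="utf-8") as f:
    config = commentjson.load(f)

openai_api_key = config.get("openai_api_key", "")
default_model = config.get("default_model", "gpt-3.5-turbo")
multi_api_key = config.get("multi_api_key", False)
```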
configs/ds_config_chatbot.json ADDED
@@ -0,0 +1,17 @@
+{
+    "fp16": {
+        "enabled": false
+    },
+    "bf16": {
+        "enabled": true
+    },
+    "comms_logger": {
+        "enabled": false,
+        "verbose": false,
+        "prof_all": false,
+        "debug": false
+    },
+    "steps_per_print": 20000000000000000,
+    "train_micro_batch_size_per_gpu": 1,
+    "wall_clock_breakdown": false
+}
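Nothing in the diff shown here consumes this DeepSpeed config yet; it ships alongside the new local-model support. As orientation only, a hedged sketch of how such a file is conventionally passed to deepspeed.initialize (the stand-in model and the call site are assumptions, not project code):

```python
# Hedged sketch: how a DeepSpeed JSON config like the one above is commonly
# consumed. The toy model and this call site are assumptions, not part of
# the commit.
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # stand-in for a real chat model
engine, _, _, _ = deepspeed.initialize(
    model=model,
    config="configs/ds_config_chatbot.json",  # bf16 on, fp16 off, micro-batch 1
)
```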
modules/__init__.py ADDED
File without changes
modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (172 Bytes).
modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (154 Bytes).
modules/__pycache__/base_model.cpython-311.pyc ADDED
Binary file (26.7 kB).
modules/__pycache__/base_model.cpython-39.pyc ADDED
Binary file (15.8 kB).
modules/__pycache__/config.cpython-311.pyc ADDED
Binary file (7.87 kB).
modules/__pycache__/config.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/config.cpython-39.pyc and b/modules/__pycache__/config.cpython-39.pyc differ
modules/__pycache__/llama_func.cpython-311.pyc ADDED
Binary file (9.28 kB).
modules/__pycache__/models.cpython-311.pyc ADDED
Binary file (30.6 kB).
modules/base_model.py ADDED
@@ -0,0 +1,547 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, List
3
+
4
+ import logging
5
+ import json
6
+ import commentjson as cjson
7
+ import os
8
+ import sys
9
+ import requests
10
+ import urllib3
11
+ import traceback
12
+
13
+ from tqdm import tqdm
14
+ import colorama
15
+ from duckduckgo_search import ddg
16
+ import asyncio
17
+ import aiohttp
18
+ from enum import Enum
19
+
20
+ from .presets import *
21
+ from .llama_func import *
22
+ from .utils import *
23
+ from . import shared
24
+ from .config import retrieve_proxy
25
+
26
+
27
+ class ModelType(Enum):
28
+ Unknown = -1
29
+ OpenAI = 0
30
+ ChatGLM = 1
31
+ LLaMA = 2
32
+ XMBot = 3
33
+
34
+ @classmethod
35
+ def get_type(cls, model_name: str):
36
+ model_type = None
37
+ model_name_lower = model_name.lower()
38
+ if "gpt" in model_name_lower:
39
+ model_type = ModelType.OpenAI
40
+ elif "chatglm" in model_name_lower:
41
+ model_type = ModelType.ChatGLM
42
+ elif "llama" in model_name_lower or "alpaca" in model_name_lower:
43
+ model_type = ModelType.LLaMA
44
+ elif "xmbot" in model_name_lower:
45
+ model_type = ModelType.XMBot
46
+ else:
47
+ model_type = ModelType.Unknown
48
+ return model_type
49
+
50
+
51
+ class BaseLLMModel:
52
+ def __init__(
53
+ self,
54
+ model_name,
55
+ system_prompt="",
56
+ temperature=1.0,
57
+ top_p=1.0,
58
+ n_choices=1,
59
+ stop=None,
60
+ max_generation_token=None,
61
+ presence_penalty=0,
62
+ frequency_penalty=0,
63
+ logit_bias=None,
64
+ user="",
65
+ ) -> None:
66
+ self.history = []
67
+ self.all_token_counts = []
68
+ self.model_name = model_name
69
+ self.model_type = ModelType.get_type(model_name)
70
+ try:
71
+ self.token_upper_limit = MODEL_TOKEN_LIMIT[model_name]
72
+ except KeyError:
73
+ self.token_upper_limit = DEFAULT_TOKEN_LIMIT
74
+ self.interrupted = False
75
+ self.system_prompt = system_prompt
76
+ self.api_key = None
77
+ self.need_api_key = False
78
+ self.single_turn = False
79
+
80
+ self.temperature = temperature
81
+ self.top_p = top_p
82
+ self.n_choices = n_choices
83
+ self.stop_sequence = stop
84
+ self.max_generation_token = None
85
+ self.presence_penalty = presence_penalty
86
+ self.frequency_penalty = frequency_penalty
87
+ self.logit_bias = logit_bias
88
+ self.user_identifier = user
89
+
90
+ def get_answer_stream_iter(self):
91
+ """stream predict, need to be implemented
92
+ conversations are stored in self.history, with the most recent question, in OpenAI format
93
+ should return a generator, each time give the next word (str) in the answer
94
+ """
95
+ logging.warning("stream predict not implemented, using at once predict instead")
96
+ response, _ = self.get_answer_at_once()
97
+ yield response
98
+
99
+ def get_answer_at_once(self):
100
+ """predict at once, need to be implemented
101
+ conversations are stored in self.history, with the most recent question, in OpenAI format
102
+ Should return:
103
+ the answer (str)
104
+ total token count (int)
105
+ """
106
+ logging.warning("at once predict not implemented, using stream predict instead")
107
+ response_iter = self.get_answer_stream_iter()
108
+ count = 0
109
+ for response in response_iter:
110
+ count += 1
111
+ return response, sum(self.all_token_counts) + count
112
+
113
+ def billing_info(self):
114
+ """get billing infomation, inplement if needed"""
115
+ logging.warning("billing info not implemented, using default")
116
+ return BILLING_NOT_APPLICABLE_MSG
117
+
118
+ def count_token(self, user_input):
119
+ """get token count from input, implement if needed"""
120
+ logging.warning("token count not implemented, using default")
121
+ return len(user_input)
122
+
123
+ def stream_next_chatbot(self, inputs, chatbot, fake_input=None, display_append=""):
124
+ def get_return_value():
125
+ return chatbot, status_text
126
+
127
+ status_text = "开始实时传输回答……"
128
+ if fake_input:
129
+ chatbot.append((fake_input, ""))
130
+ else:
131
+ chatbot.append((inputs, ""))
132
+
133
+ user_token_count = self.count_token(inputs)
134
+ self.all_token_counts.append(user_token_count)
135
+ logging.debug(f"输入token计数: {user_token_count}")
136
+
137
+ stream_iter = self.get_answer_stream_iter()
138
+
139
+ for partial_text in stream_iter:
140
+ chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
141
+ self.all_token_counts[-1] += 1
142
+ status_text = self.token_message()
143
+ yield get_return_value()
144
+ if self.interrupted:
145
+ self.recover()
146
+ break
147
+ self.history.append(construct_assistant(partial_text))
148
+
149
+ def next_chatbot_at_once(self, inputs, chatbot, fake_input=None, display_append=""):
150
+ if fake_input:
151
+ chatbot.append((fake_input, ""))
152
+ else:
153
+ chatbot.append((inputs, ""))
154
+ if fake_input is not None:
155
+ user_token_count = self.count_token(fake_input)
156
+ else:
157
+ user_token_count = self.count_token(inputs)
158
+ self.all_token_counts.append(user_token_count)
159
+ ai_reply, total_token_count = self.get_answer_at_once()
160
+ self.history.append(construct_assistant(ai_reply))
161
+ if fake_input is not None:
162
+ self.history[-2] = construct_user(fake_input)
163
+ chatbot[-1] = (chatbot[-1][0], ai_reply + display_append)
164
+ if fake_input is not None:
165
+ self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
166
+ else:
167
+ self.all_token_counts[-1] = total_token_count - sum(self.all_token_counts)
168
+ status_text = self.token_message()
169
+ return chatbot, status_text
170
+
171
+ def handle_file_upload(self, files, chatbot):
172
+ """if the model accepts multi modal input, implement this function"""
173
+ status = gr.Markdown.update()
174
+ if files:
175
+ construct_index(self.api_key, file_src=files)
176
+ status = "索引构建完成"
177
+ return gr.Files.update(), chatbot, status
178
+
179
+ def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
180
+ fake_inputs = None
181
+ display_append = []
182
+ limited_context = False
183
+ fake_inputs = real_inputs
184
+ if files:
185
+ from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
186
+ from llama_index.indices.query.schema import QueryBundle
187
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
188
+ from langchain.chat_models import ChatOpenAI
189
+ from llama_index import (
190
+ GPTSimpleVectorIndex,
191
+ ServiceContext,
192
+ LangchainEmbedding,
193
+ OpenAIEmbedding,
194
+ )
195
+ limited_context = True
196
+ msg = "加载索引中……"
197
+ logging.info(msg)
198
+ # yield chatbot + [(inputs, "")], msg
199
+ index = construct_index(self.api_key, file_src=files)
200
+ assert index is not None, "获取索引失败"
201
+ msg = "索引获取成功,生成回答中……"
202
+ logging.info(msg)
203
+ if local_embedding or self.model_type != ModelType.OpenAI:
204
+ embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
205
+ else:
206
+ embed_model = OpenAIEmbedding()
207
+ # yield chatbot + [(inputs, "")], msg
208
+ with retrieve_proxy():
209
+ prompt_helper = PromptHelper(
210
+ max_input_size=4096,
211
+ num_output=5,
212
+ max_chunk_overlap=20,
213
+ chunk_size_limit=600,
214
+ )
215
+ from llama_index import ServiceContext
216
+
217
+ service_context = ServiceContext.from_defaults(
218
+ prompt_helper=prompt_helper, embed_model=embed_model
219
+ )
220
+ query_object = GPTVectorStoreIndexQuery(
221
+ index.index_struct,
222
+ service_context=service_context,
223
+ similarity_top_k=5,
224
+ vector_store=index._vector_store,
225
+ docstore=index._docstore,
226
+ )
227
+ query_bundle = QueryBundle(real_inputs)
228
+ nodes = query_object.retrieve(query_bundle)
229
+ reference_results = [n.node.text for n in nodes]
230
+ reference_results = add_source_numbers(reference_results, use_source=False)
231
+ display_append = add_details(reference_results)
232
+ display_append = "\n\n" + "".join(display_append)
233
+ real_inputs = (
234
+ replace_today(PROMPT_TEMPLATE)
235
+ .replace("{query_str}", real_inputs)
236
+ .replace("{context_str}", "\n\n".join(reference_results))
237
+ .replace("{reply_language}", reply_language)
238
+ )
239
+ elif use_websearch:
240
+ limited_context = True
241
+ search_results = ddg(real_inputs, max_results=5)
242
+ reference_results = []
243
+ for idx, result in enumerate(search_results):
244
+ logging.debug(f"搜索结果{idx + 1}:{result}")
245
+ domain_name = urllib3.util.parse_url(result["href"]).host
246
+ reference_results.append([result["body"], result["href"]])
247
+ display_append.append(
248
+ f"{idx+1}. [{domain_name}]({result['href']})\n"
249
+ )
250
+ reference_results = add_source_numbers(reference_results)
251
+ display_append = "\n\n" + "".join(display_append)
252
+ real_inputs = (
253
+ replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
254
+ .replace("{query}", real_inputs)
255
+ .replace("{web_results}", "\n\n".join(reference_results))
256
+ .replace("{reply_language}", reply_language)
257
+ )
258
+ else:
259
+ display_append = ""
260
+ return limited_context, fake_inputs, display_append, real_inputs, chatbot
261
+
262
+ def predict(
263
+ self,
264
+ inputs,
265
+ chatbot,
266
+ stream=False,
267
+ use_websearch=False,
268
+ files=None,
269
+ reply_language="中文",
270
+ should_check_token_count=True,
271
+ ): # repetition_penalty, top_k
272
+
273
+ status_text = "开始生成回答……"
274
+ logging.info(
275
+ "输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
276
+ )
277
+ if should_check_token_count:
278
+ yield chatbot + [(inputs, "")], status_text
279
+ if reply_language == "跟随问题语言(不稳定)":
280
+ reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
281
+
282
+ limited_context, fake_inputs, display_append, inputs, chatbot = self.prepare_inputs(real_inputs=inputs, use_websearch=use_websearch, files=files, reply_language=reply_language, chatbot=chatbot)
283
+ yield chatbot + [(fake_inputs, "")], status_text
284
+
285
+ if (
286
+ self.need_api_key
287
+ and self.api_key is None
288
+ and not shared.state.multi_api_key
289
+ ):
290
+ status_text = STANDARD_ERROR_MSG + NO_APIKEY_MSG
291
+ logging.info(status_text)
292
+ chatbot.append((inputs, ""))
293
+ if len(self.history) == 0:
294
+ self.history.append(construct_user(inputs))
295
+ self.history.append("")
296
+ self.all_token_counts.append(0)
297
+ else:
298
+ self.history[-2] = construct_user(inputs)
299
+ yield chatbot + [(inputs, "")], status_text
300
+ return
301
+ elif len(inputs.strip()) == 0:
302
+ status_text = STANDARD_ERROR_MSG + NO_INPUT_MSG
303
+ logging.info(status_text)
304
+ yield chatbot + [(inputs, "")], status_text
305
+ return
306
+
307
+ if self.single_turn:
308
+ self.history = []
309
+ self.all_token_counts = []
310
+ self.history.append(construct_user(inputs))
311
+
312
+ try:
313
+ if stream:
314
+ logging.debug("使用流式传输")
315
+ stream_iter = self.stream_next_chatbot(
316
+ inputs,
317
+ chatbot,
318
+ fake_input=fake_inputs,
319
+ display_append=display_append,
320
+ )
321
+ for chatbot, status_text in stream_iter:
322
+ yield chatbot, status_text
323
+ else:
324
+ logging.debug("不使用流式传输")
325
+ chatbot, status_text = self.next_chatbot_at_once(
326
+ inputs,
327
+ chatbot,
328
+ fake_input=fake_inputs,
329
+ display_append=display_append,
330
+ )
331
+ yield chatbot, status_text
332
+ except Exception as e:
333
+ traceback.print_exc()
334
+ status_text = STANDARD_ERROR_MSG + str(e)
335
+ yield chatbot, status_text
336
+
337
+ if len(self.history) > 1 and self.history[-1]["content"] != inputs:
338
+ logging.info(
339
+ "回答为:"
340
+ + colorama.Fore.BLUE
341
+ + f"{self.history[-1]['content']}"
342
+ + colorama.Style.RESET_ALL
343
+ )
344
+
345
+ if limited_context:
346
+ # self.history = self.history[-4:]
347
+ # self.all_token_counts = self.all_token_counts[-2:]
348
+ self.history = []
349
+ self.all_token_counts = []
350
+
351
+ max_token = self.token_upper_limit - TOKEN_OFFSET
352
+
353
+ if sum(self.all_token_counts) > max_token and should_check_token_count:
354
+ count = 0
355
+ while (
356
+ sum(self.all_token_counts)
357
+ > self.token_upper_limit * REDUCE_TOKEN_FACTOR
358
+ and sum(self.all_token_counts) > 0
359
+ ):
360
+ count += 1
361
+ del self.all_token_counts[0]
362
+ del self.history[:2]
363
+ status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
364
+ logging.info(status_text)
365
+ yield chatbot, status_text
366
+
367
+ def retry(
368
+ self,
369
+ chatbot,
370
+ stream=False,
371
+ use_websearch=False,
372
+ files=None,
373
+ reply_language="中文",
374
+ ):
375
+ logging.debug("重试中……")
376
+ if len(self.history) == 0:
377
+ yield chatbot, f"{STANDARD_ERROR_MSG}上下文是空的"
378
+ return
379
+
380
+ inputs = self.history[-2]["content"]
381
+ del self.history[-2:]
382
+ self.all_token_counts.pop()
383
+ predict_iter = self.predict(
384
+ inputs,
385
+ chatbot,
386
+ stream=stream,
387
+ use_websearch=use_websearch,
388
+ files=files,
389
+ reply_language=reply_language,
390
+ )
391
+ for x in predict_iter:
392
+ yield x
393
+ logging.debug("重试完毕")
394
+
395
+ # def reduce_token_size(self, chatbot):
396
+ # logging.info("开始减少token数量……")
397
+ # chatbot, status_text = self.next_chatbot_at_once(
398
+ # summarize_prompt,
399
+ # chatbot
400
+ # )
401
+ # max_token_count = self.token_upper_limit * REDUCE_TOKEN_FACTOR
402
+ # num_chat = find_n(self.all_token_counts, max_token_count)
403
+ # logging.info(f"previous_token_count: {self.all_token_counts}, keeping {num_chat} chats")
404
+ # chatbot = chatbot[:-1]
405
+ # self.history = self.history[-2*num_chat:] if num_chat > 0 else []
406
+ # self.all_token_counts = self.all_token_counts[-num_chat:] if num_chat > 0 else []
407
+ # msg = f"保留了最近{num_chat}轮对话"
408
+ # logging.info(msg)
409
+ # logging.info("减少token数量完毕")
410
+ # return chatbot, msg + "," + self.token_message(self.all_token_counts if len(self.all_token_counts) > 0 else [0])
411
+
412
+ def interrupt(self):
413
+ self.interrupted = True
414
+
415
+ def recover(self):
416
+ self.interrupted = False
417
+
418
+ def set_token_upper_limit(self, new_upper_limit):
419
+ self.token_upper_limit = new_upper_limit
420
+ print(f"token上限设置为{new_upper_limit}")
421
+
422
+ def set_temperature(self, new_temperature):
423
+ self.temperature = new_temperature
424
+
425
+ def set_top_p(self, new_top_p):
426
+ self.top_p = new_top_p
427
+
428
+ def set_n_choices(self, new_n_choices):
429
+ self.n_choices = new_n_choices
430
+
431
+ def set_stop_sequence(self, new_stop_sequence: str):
432
+ new_stop_sequence = new_stop_sequence.split(",")
433
+ self.stop_sequence = new_stop_sequence
434
+
435
+ def set_max_tokens(self, new_max_tokens):
436
+ self.max_generation_token = new_max_tokens
437
+
438
+ def set_presence_penalty(self, new_presence_penalty):
439
+ self.presence_penalty = new_presence_penalty
440
+
441
+ def set_frequency_penalty(self, new_frequency_penalty):
442
+ self.frequency_penalty = new_frequency_penalty
443
+
444
+ def set_logit_bias(self, logit_bias):
445
+ logit_bias = logit_bias.split()
446
+ bias_map = {}
447
+ encoding = tiktoken.get_encoding("cl100k_base")
448
+ for line in logit_bias:
449
+ word, bias_amount = line.split(":")
450
+ if word:
451
+ for token in encoding.encode(word):
452
+ bias_map[token] = float(bias_amount)
453
+ self.logit_bias = bias_map
454
+
455
+ def set_user_identifier(self, new_user_identifier):
456
+ self.user_identifier = new_user_identifier
457
+
458
+ def set_system_prompt(self, new_system_prompt):
459
+ self.system_prompt = new_system_prompt
460
+
461
+ def set_key(self, new_access_key):
462
+ self.api_key = new_access_key.strip()
463
+ msg = f"API密钥更改为了{hide_middle_chars(self.api_key)}"
464
+ logging.info(msg)
465
+ return new_access_key, msg
466
+
467
+ def set_single_turn(self, new_single_turn):
468
+ self.single_turn = new_single_turn
469
+
470
+ def reset(self):
471
+ self.history = []
472
+ self.all_token_counts = []
473
+ self.interrupted = False
474
+ return [], self.token_message([0])
475
+
476
+ def delete_first_conversation(self):
477
+ if self.history:
478
+ del self.history[:2]
479
+ del self.all_token_counts[0]
480
+ return self.token_message()
481
+
482
+ def delete_last_conversation(self, chatbot):
483
+ if len(chatbot) > 0 and STANDARD_ERROR_MSG in chatbot[-1][1]:
484
+ msg = "由于包含报错信息,只删除chatbot记录"
485
+ chatbot.pop()
486
+ return chatbot, self.history
487
+ if len(self.history) > 0:
488
+ self.history.pop()
489
+ self.history.pop()
490
+ if len(chatbot) > 0:
491
+ msg = "删除了一组chatbot对话"
492
+ chatbot.pop()
493
+ if len(self.all_token_counts) > 0:
494
+ msg = "删除了一组对话的token计数记录"
495
+ self.all_token_counts.pop()
496
+ msg = "删除了一组对话"
497
+ return chatbot, msg
498
+
499
+ def token_message(self, token_lst=None):
500
+ if token_lst is None:
501
+ token_lst = self.all_token_counts
502
+ token_sum = 0
503
+ for i in range(len(token_lst)):
504
+ token_sum += sum(token_lst[: i + 1])
505
+ return f"Token 计数: {sum(token_lst)},本次对话累计消耗了 {token_sum} tokens"
506
+
507
+ def save_chat_history(self, filename, chatbot, user_name):
508
+ if filename == "":
509
+ return
510
+ if not filename.endswith(".json"):
511
+ filename += ".json"
512
+ return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
513
+
514
+ def export_markdown(self, filename, chatbot, user_name):
515
+ if filename == "":
516
+ return
517
+ if not filename.endswith(".md"):
518
+ filename += ".md"
519
+ return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
520
+
521
+ def load_chat_history(self, filename, chatbot, user_name):
522
+ logging.debug(f"{user_name} 加载对话历史中……")
523
+ if not isinstance(filename, str):
524
+ filename = filename.name
525
+ try:
526
+ with open(os.path.join(HISTORY_DIR, user_name, filename), "r", encoding="utf-8") as f:
527
+ json_s = json.load(f)
528
+ try:
529
+ if type(json_s["history"][0]) == str:
530
+ logging.info("历史记录格式为旧版,正在转换……")
531
+ new_history = []
532
+ for index, item in enumerate(json_s["history"]):
533
+ if index % 2 == 0:
534
+ new_history.append(construct_user(item))
535
+ else:
536
+ new_history.append(construct_assistant(item))
537
+ json_s["history"] = new_history
538
+ logging.info(new_history)
539
+ except (KeyError, IndexError, TypeError):
540
+ # 没有对话历史
541
+ pass
542
+ logging.debug(f"{user_name} 加载对话历史完毕")
543
+ self.history = json_s["history"]
544
+ return filename, json_s["system"], json_s["chatbot"]
545
+ except FileNotFoundError:
546
+ logging.warning(f"{user_name} 没有找到对话历史文件,不执行任何操作")
547
+ return filename, self.system_prompt, chatbot
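Editor's note: `get_answer_at_once` and `get_answer_stream_iter` are the only methods a new backend has to override in the `BaseLLMModel` added above, and the streaming hook yields the partial answer accumulated so far, not single tokens. A minimal sketch of a custom client under that assumption; `EchoClient` and its canned reply are hypothetical names, not part of this commit:

from modules.base_model import BaseLLMModel

class EchoClient(BaseLLMModel):
    def __init__(self):
        super().__init__(model_name="echo")  # unknown names fall back to DEFAULT_TOKEN_LIMIT

    def get_answer_at_once(self):
        # self.history ends with the latest user message in OpenAI format
        question = self.history[-1]["content"]
        answer = "你说的是:" + question
        return answer, len(answer)

    def get_answer_stream_iter(self):
        # yield the accumulated partial answer each time, as predict() expects
        answer, _ = self.get_answer_at_once()
        for i in range(1, len(answer) + 1):
            yield answer[:i]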
modules/config.py CHANGED
@@ -3,9 +3,10 @@ from contextlib import contextmanager
3
  import os
4
  import logging
5
  import sys
6
- import json
7
 
8
  from . import shared
 
9
 
10
 
11
  __all__ = [
@@ -18,6 +19,9 @@ __all__ = [
18
  "advance_docs",
19
  "update_doc_config",
20
  "multi_api_key",
21
  ]
22
 
23
  # 添加一个统一的config文件,避免文件过多造成的疑惑(优先级最低)
@@ -28,6 +32,30 @@ if os.path.exists("config.json"):
28
  else:
29
  config = {}
30
 
31
  ## 处理docker if we are running in Docker
32
  dockerflag = config.get("dockerflag", False)
33
  if os.environ.get("dockerrun") == "yes":
@@ -54,35 +82,6 @@ api_host = os.environ.get("api_host", config.get("api_host", ""))
54
  if api_host:
55
  shared.state.set_api_host(api_host)
56
 
57
- if dockerflag:
58
- if my_api_key == "empty":
59
- logging.error("Please give a api key!")
60
- sys.exit(1)
61
- # auth
62
- username = os.environ.get("USERNAME")
63
- password = os.environ.get("PASSWORD")
64
- if not (isinstance(username, type(None)) or isinstance(password, type(None))):
65
- auth_list.append((os.environ.get("USERNAME"), os.environ.get("PASSWORD")))
66
- authflag = True
67
- else:
68
- if (
69
- not my_api_key
70
- and os.path.exists("api_key.txt")
71
- and os.path.getsize("api_key.txt")
72
- ):
73
- with open("api_key.txt", "r") as f:
74
- my_api_key = f.read().strip()
75
- if os.path.exists("auth.json"):
76
- authflag = True
77
- with open("auth.json", "r", encoding='utf-8') as f:
78
- auth = json.load(f)
79
- for _ in auth:
80
- if auth[_]["username"] and auth[_]["password"]:
81
- auth_list.append((auth[_]["username"], auth[_]["password"]))
82
- else:
83
- logging.error("请检查auth.json文件中的用户名和密码!")
84
- sys.exit(1)
85
-
86
  @contextmanager
87
  def retrieve_openai_api(api_key = None):
88
  old_api_key = os.environ.get("OPENAI_API_KEY", "")
@@ -111,6 +110,8 @@ https_proxy = os.environ.get("HTTPS_PROXY", https_proxy)
111
  os.environ["HTTP_PROXY"] = ""
112
  os.environ["HTTPS_PROXY"] = ""
113
 
 
 
114
  @contextmanager
115
  def retrieve_proxy(proxy=None):
116
  """
@@ -137,9 +138,29 @@ advance_docs = defaultdict(lambda: defaultdict(dict))
137
  advance_docs.update(config.get("advance_docs", {}))
138
  def update_doc_config(two_column_pdf):
139
  global advance_docs
140
- if two_column_pdf:
141
- advance_docs["pdf"]["two_column"] = True
142
  else:
143
- advance_docs["pdf"]["two_column"] = False
144
 
145
- logging.info(f"更新后的文件参数为:{advance_docs}")
 
3
  import os
4
  import logging
5
  import sys
6
+ import commentjson as json
7
 
8
  from . import shared
9
+ from . import presets
10
 
11
 
12
  __all__ = [
 
19
  "advance_docs",
20
  "update_doc_config",
21
  "multi_api_key",
22
+ "server_name",
23
+ "server_port",
24
+ "share",
25
  ]
26
 
27
  # 添加一个统一的config文件,避免文件过多造成的疑惑(优先级最低)
 
32
  else:
33
  config = {}
34
 
35
+ if os.path.exists("api_key.txt"):
36
+ logging.info("检测到api_key.txt文件,正在进行迁移...")
37
+ with open("api_key.txt", "r") as f:
38
+ config["openai_api_key"] = f.read().strip()
39
+ os.rename("api_key.txt", "api_key(deprecated).txt")
40
+ with open("config.json", "w", encoding='utf-8') as f:
41
+ json.dump(config, f, indent=4)
42
+
43
+ if os.path.exists("auth.json"):
44
+ logging.info("检测到auth.json文件,正在进行迁移...")
45
+ auth_list = []
46
+ with open("auth.json", "r", encoding='utf-8') as f:
47
+ auth = json.load(f)
48
+ for _ in auth:
49
+ if auth[_]["username"] and auth[_]["password"]:
50
+ auth_list.append((auth[_]["username"], auth[_]["password"]))
51
+ else:
52
+ logging.error("请检查auth.json文件中的用户名和密码!")
53
+ sys.exit(1)
54
+ config["users"] = auth_list
55
+ os.rename("auth.json", "auth(deprecated).json")
56
+ with open("config.json", "w", encoding='utf-8') as f:
57
+ json.dump(config, f, indent=4)
58
+
59
  ## 处理docker if we are running in Docker
60
  dockerflag = config.get("dockerflag", False)
61
  if os.environ.get("dockerrun") == "yes":
 
82
  if api_host:
83
  shared.state.set_api_host(api_host)
84
 
85
  @contextmanager
86
  def retrieve_openai_api(api_key = None):
87
  old_api_key = os.environ.get("OPENAI_API_KEY", "")
 
110
  os.environ["HTTP_PROXY"] = ""
111
  os.environ["HTTPS_PROXY"] = ""
112
 
113
+ local_embedding = config.get("local_embedding", False) # 是否使用本地embedding
114
+
115
  @contextmanager
116
  def retrieve_proxy(proxy=None):
117
  """
 
138
  advance_docs.update(config.get("advance_docs", {}))
139
  def update_doc_config(two_column_pdf):
140
  global advance_docs
141
+ advance_docs["pdf"]["two_column"] = two_column_pdf
142
+
143
+ logging.info(f"更新后的文件参数为:{advance_docs}")
144
+
145
+ ## 处理gradio.launch参数
146
+ server_name = config.get("server_name", None)
147
+ server_port = config.get("server_port", None)
148
+ if server_name is None:
149
+ if dockerflag:
150
+ server_name = "0.0.0.0"
151
  else:
152
+ server_name = "127.0.0.1"
153
+ if server_port is None:
154
+ if dockerflag:
155
+ server_port = 7860
156
+
157
+ assert server_port is None or isinstance(server_port, int), "要求port设置为int类型"
158
+
159
+ # 设置默认model
160
+ default_model = config.get("default_model", "")
161
+ try:
162
+ presets.DEFAULT_MODEL = presets.MODELS.index(default_model)
163
+ except ValueError:
164
+ pass
165
 
166
+ share = config.get("share", False)
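Editor's note: after the migration logic above runs, a single config.json carries everything that api_key.txt and auth.json used to hold, plus the new launch parameters. A sketch of what such a file could contain, expressed as the dict the module would write back with commentjson; every value is illustrative only:

import commentjson as json

example_config = {
    "openai_api_key": "sk-xxxxxxx",       # migrated from api_key.txt
    "users": [["user1", "password1"]],    # migrated from auth.json
    "local_embedding": False,             # use HuggingFace embeddings locally
    "server_name": "0.0.0.0",             # gradio launch address
    "server_port": 7860,                  # must be an int (see the assert above)
    "share": False,
    "default_model": "gpt-3.5-turbo",     # must appear in presets.MODELS to take effect
}

with open("config.example.json", "w", encoding="utf-8") as f:
    json.dump(example_config, f, indent=4)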
modules/llama_func.py CHANGED
@@ -15,6 +15,8 @@ from tqdm import tqdm
15
 
16
  from modules.presets import *
17
  from modules.utils import *
 
 
18
 
19
  def get_index_name(file_src):
20
  file_paths = [x.name for x in file_src]
@@ -28,6 +30,7 @@ def get_index_name(file_src):
28
 
29
  return md5_hash.hexdigest()
30
 
 
31
  def block_split(text):
32
  blocks = []
33
  while len(text) > 0:
@@ -35,6 +38,7 @@ def block_split(text):
35
  text = text[1000:]
36
  return blocks
37
 
 
38
  def get_documents(file_src):
39
  documents = []
40
  logging.debug("Loading documents...")
@@ -44,37 +48,45 @@ def get_documents(file_src):
44
  filename = os.path.basename(filepath)
45
  file_type = os.path.splitext(filepath)[1]
46
  logging.info(f"loading file: {filename}")
47
- if file_type == ".pdf":
48
- logging.debug("Loading PDF...")
49
- try:
50
- from modules.pdf_func import parse_pdf
51
- from modules.config import advance_docs
52
- two_column = advance_docs["pdf"].get("two_column", False)
53
- pdftext = parse_pdf(filepath, two_column).text
54
- except:
55
- pdftext = ""
56
- with open(filepath, 'rb') as pdfFileObj:
57
- pdfReader = PyPDF2.PdfReader(pdfFileObj)
58
- for page in tqdm(pdfReader.pages):
59
- pdftext += page.extract_text()
60
- text_raw = pdftext
61
- elif file_type == ".docx":
62
- logging.debug("Loading Word...")
63
- DocxReader = download_loader("DocxReader")
64
- loader = DocxReader()
65
- text_raw = loader.load_data(file=filepath)[0].text
66
- elif file_type == ".epub":
67
- logging.debug("Loading EPUB...")
68
- EpubReader = download_loader("EpubReader")
69
- loader = EpubReader()
70
- text_raw = loader.load_data(file=filepath)[0].text
71
- elif file_type == ".xlsx":
72
- logging.debug("Loading Excel...")
73
- text_raw = excel_to_string(filepath)
74
- else:
75
- logging.debug("Loading text file...")
76
- with open(filepath, "r", encoding="utf-8") as f:
77
- text_raw = f.read()
78
  text = add_space(text_raw)
79
  # text = block_split(text)
80
  # documents += text
@@ -84,27 +96,36 @@ def get_documents(file_src):
84
 
85
 
86
  def construct_index(
87
- api_key,
88
- file_src,
89
- max_input_size=4096,
90
- num_outputs=5,
91
- max_chunk_overlap=20,
92
- chunk_size_limit=600,
93
- embedding_limit=None,
94
- separator=" "
95
  ):
96
  from langchain.chat_models import ChatOpenAI
97
- from llama_index import GPTSimpleVectorIndex, ServiceContext
 
98
 
99
- os.environ["OPENAI_API_KEY"] = api_key
100
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
101
  embedding_limit = None if embedding_limit == 0 else embedding_limit
102
  separator = " " if separator == "" else separator
103
 
104
- llm_predictor = LLMPredictor(
105
- llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
106
  )
107
- prompt_helper = PromptHelper(max_input_size = max_input_size, num_output = num_outputs, max_chunk_overlap = max_chunk_overlap, embedding_limit=embedding_limit, chunk_size_limit=600, separator=separator)
108
  index_name = get_index_name(file_src)
109
  if os.path.exists(f"./index/{index_name}.json"):
110
  logging.info("找到了缓存的索引文件,加载中……")
@@ -112,11 +133,19 @@ def construct_index(
112
  else:
113
  try:
114
  documents = get_documents(file_src)
115
  logging.info("构建索引中……")
116
  with retrieve_proxy():
117
- service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
118
  index = GPTSimpleVectorIndex.from_documents(
119
- documents, service_context=service_context
120
  )
121
  logging.debug("索引构建完成!")
122
  os.makedirs("./index", exist_ok=True)
 
15
 
16
  from modules.presets import *
17
  from modules.utils import *
18
+ from modules.config import local_embedding
19
+
20
 
21
  def get_index_name(file_src):
22
  file_paths = [x.name for x in file_src]
 
30
 
31
  return md5_hash.hexdigest()
32
 
33
+
34
  def block_split(text):
35
  blocks = []
36
  while len(text) > 0:
 
38
  text = text[1000:]
39
  return blocks
40
 
41
+
42
  def get_documents(file_src):
43
  documents = []
44
  logging.debug("Loading documents...")
 
48
  filename = os.path.basename(filepath)
49
  file_type = os.path.splitext(filepath)[1]
50
  logging.info(f"loading file: {filename}")
51
+ try:
52
+ if file_type == ".pdf":
53
+ logging.debug("Loading PDF...")
54
+ try:
55
+ from modules.pdf_func import parse_pdf
56
+ from modules.config import advance_docs
57
+
58
+ two_column = advance_docs["pdf"].get("two_column", False)
59
+ pdftext = parse_pdf(filepath, two_column).text
60
+ except:
61
+ pdftext = ""
62
+ with open(filepath, "rb") as pdfFileObj:
63
+ pdfReader = PyPDF2.PdfReader(pdfFileObj)
64
+ for page in tqdm(pdfReader.pages):
65
+ pdftext += page.extract_text()
66
+ text_raw = pdftext
67
+ elif file_type == ".docx":
68
+ logging.debug("Loading Word...")
69
+ DocxReader = download_loader("DocxReader")
70
+ loader = DocxReader()
71
+ text_raw = loader.load_data(file=filepath)[0].text
72
+ elif file_type == ".epub":
73
+ logging.debug("Loading EPUB...")
74
+ EpubReader = download_loader("EpubReader")
75
+ loader = EpubReader()
76
+ text_raw = loader.load_data(file=filepath)[0].text
77
+ elif file_type == ".xlsx":
78
+ logging.debug("Loading Excel...")
79
+ text_list = excel_to_string(filepath)
80
+ for elem in text_list:
81
+ documents.append(Document(elem))
82
+ continue
83
+ else:
84
+ logging.debug("Loading text file...")
85
+ with open(filepath, "r", encoding="utf-8") as f:
86
+ text_raw = f.read()
87
+ except Exception as e:
88
+ logging.error(f"Error loading file: {filename}")
89
+ pass
90
  text = add_space(text_raw)
91
  # text = block_split(text)
92
  # documents += text
 
96
 
97
 
98
  def construct_index(
99
+ api_key,
100
+ file_src,
101
+ max_input_size=4096,
102
+ num_outputs=5,
103
+ max_chunk_overlap=20,
104
+ chunk_size_limit=600,
105
+ embedding_limit=None,
106
+ separator=" ",
107
  ):
108
  from langchain.chat_models import ChatOpenAI
109
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
110
+ from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding
111
 
112
+ if api_key:
113
+ os.environ["OPENAI_API_KEY"] = api_key
114
+ else:
115
+ # 由于一个依赖的愚蠢的设计,这里必须要有一个API KEY
116
+ os.environ["OPENAI_API_KEY"] = "sk-xxxxxxx"
117
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
118
  embedding_limit = None if embedding_limit == 0 else embedding_limit
119
  separator = " " if separator == "" else separator
120
 
121
+ prompt_helper = PromptHelper(
122
+ max_input_size=max_input_size,
123
+ num_output=num_outputs,
124
+ max_chunk_overlap=max_chunk_overlap,
125
+ embedding_limit=embedding_limit,
126
+ chunk_size_limit=600,
127
+ separator=separator,
128
  )
 
129
  index_name = get_index_name(file_src)
130
  if os.path.exists(f"./index/{index_name}.json"):
131
  logging.info("找到了缓存的索引文件,加载中……")
 
133
  else:
134
  try:
135
  documents = get_documents(file_src)
136
+ if local_embedding:
137
+ embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
138
+ else:
139
+ embed_model = OpenAIEmbedding()
140
  logging.info("构建索引中……")
141
  with retrieve_proxy():
142
+ service_context = ServiceContext.from_defaults(
143
+ prompt_helper=prompt_helper,
144
+ chunk_size_limit=chunk_size_limit,
145
+ embed_model=embed_model,
146
+ )
147
  index = GPTSimpleVectorIndex.from_documents(
148
+ documents, service_context=service_context
149
  )
150
  logging.debug("索引构建完成!")
151
  os.makedirs("./index", exist_ok=True)
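Editor's note: with `local_embedding` switched on, the rewritten construct_index above can build a document index without a real OpenAI key (a placeholder key is injected only to satisfy a dependency). A hedged usage sketch; `files` stands in for the file objects gradio passes in (anything exposing a `.name` path attribute):

from modules.llama_func import construct_index

index = construct_index(api_key="", file_src=files)
if index is not None:
    print("index built and cached under ./index/")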
modules/models.py ADDED
@@ -0,0 +1,586 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, List
3
+
4
+ import logging
5
+ import json
6
+ import commentjson as cjson
7
+ import os
8
+ import sys
9
+ import requests
10
+ import urllib3
11
+ import platform
12
+
13
+ from tqdm import tqdm
14
+ import colorama
15
+ from duckduckgo_search import ddg
16
+ import asyncio
17
+ import aiohttp
18
+ from enum import Enum
19
+ import uuid
20
+
21
+ from .presets import *
22
+ from .llama_func import *
23
+ from .utils import *
24
+ from . import shared
25
+ from .config import retrieve_proxy
26
+ from modules import config
27
+ from .base_model import BaseLLMModel, ModelType
28
+
29
+
30
+ class OpenAIClient(BaseLLMModel):
31
+ def __init__(
32
+ self,
33
+ model_name,
34
+ api_key,
35
+ system_prompt=INITIAL_SYSTEM_PROMPT,
36
+ temperature=1.0,
37
+ top_p=1.0,
38
+ ) -> None:
39
+ super().__init__(
40
+ model_name=model_name,
41
+ temperature=temperature,
42
+ top_p=top_p,
43
+ system_prompt=system_prompt,
44
+ )
45
+ self.api_key = api_key
46
+ self.need_api_key = True
47
+ self._refresh_header()
48
+
49
+ def get_answer_stream_iter(self):
50
+ response = self._get_response(stream=True)
51
+ if response is not None:
52
+ iter = self._decode_chat_response(response)
53
+ partial_text = ""
54
+ for i in iter:
55
+ partial_text += i
56
+ yield partial_text
57
+ else:
58
+ yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
59
+
60
+ def get_answer_at_once(self):
61
+ response = self._get_response()
62
+ response = json.loads(response.text)
63
+ content = response["choices"][0]["message"]["content"]
64
+ total_token_count = response["usage"]["total_tokens"]
65
+ return content, total_token_count
66
+
67
+ def count_token(self, user_input):
68
+ input_token_count = count_token(construct_user(user_input))
69
+ if self.system_prompt is not None and len(self.all_token_counts) == 0:
70
+ system_prompt_token_count = count_token(
71
+ construct_system(self.system_prompt)
72
+ )
73
+ return input_token_count + system_prompt_token_count
74
+ return input_token_count
75
+
76
+ def billing_info(self):
77
+ try:
78
+ curr_time = datetime.datetime.now()
79
+ last_day_of_month = get_last_day_of_month(
80
+ curr_time).strftime("%Y-%m-%d")
81
+ first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
82
+ usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
83
+ try:
84
+ usage_data = self._get_billing_data(usage_url)
85
+ except Exception as e:
86
+ logging.error(f"获取API使用情况失败:" + str(e))
87
+ return f"**获取API使用情况失败**"
88
+ rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
89
+ return f"**本月使用金额** \u3000 ${rounded_usage}"
90
+ except requests.exceptions.ConnectTimeout:
91
+ status_text = (
92
+ STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
93
+ )
94
+ return status_text
95
+ except requests.exceptions.ReadTimeout:
96
+ status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
97
+ return status_text
98
+ except Exception as e:
99
+ logging.error(f"获取API使用情况失败:" + str(e))
100
+ return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
101
+
102
+ def set_token_upper_limit(self, new_upper_limit):
103
+ pass
104
+
105
+ def set_key(self, new_access_key):
106
+ self.api_key = new_access_key.strip()
107
+ self._refresh_header()
108
+ msg = f"API密钥更改为了{hide_middle_chars(self.api_key)}"
109
+ logging.info(msg)
110
+ return msg
111
+
112
+ @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
113
+ def _get_response(self, stream=False):
114
+ openai_api_key = self.api_key
115
+ system_prompt = self.system_prompt
116
+ history = self.history
117
+ logging.debug(colorama.Fore.YELLOW +
118
+ f"{history}" + colorama.Fore.RESET)
119
+ headers = {
120
+ "Content-Type": "application/json",
121
+ "Authorization": f"Bearer {openai_api_key}",
122
+ }
123
+
124
+ if system_prompt is not None:
125
+ history = [construct_system(system_prompt), *history]
126
+
127
+ payload = {
128
+ "model": self.model_name,
129
+ "messages": history,
130
+ "temperature": self.temperature,
131
+ "top_p": self.top_p,
132
+ "n": self.n_choices,
133
+ "stream": stream,
134
+ "presence_penalty": self.presence_penalty,
135
+ "frequency_penalty": self.frequency_penalty,
136
+ }
137
+
138
+ if self.max_generation_token is not None:
139
+ payload["max_tokens"] = self.max_generation_token
140
+ if self.stop_sequence is not None:
141
+ payload["stop"] = self.stop_sequence
142
+ if self.logit_bias is not None:
143
+ payload["logit_bias"] = self.logit_bias
144
+ if self.user_identifier is not None:
145
+ payload["user"] = self.user_identifier
146
+
147
+ if stream:
148
+ timeout = TIMEOUT_STREAMING
149
+ else:
150
+ timeout = TIMEOUT_ALL
151
+
152
+ # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
153
+ if shared.state.completion_url != COMPLETION_URL:
154
+ logging.info(f"使用自定义API URL: {shared.state.completion_url}")
155
+
156
+ with retrieve_proxy():
157
+ try:
158
+ response = requests.post(
159
+ shared.state.completion_url,
160
+ headers=headers,
161
+ json=payload,
162
+ stream=stream,
163
+ timeout=timeout,
164
+ )
165
+ except:
166
+ return None
167
+ return response
168
+
169
+ def _refresh_header(self):
170
+ self.headers = {
171
+ "Content-Type": "application/json",
172
+ "Authorization": f"Bearer {self.api_key}",
173
+ }
174
+
175
+ def _get_billing_data(self, billing_url):
176
+ with retrieve_proxy():
177
+ response = requests.get(
178
+ billing_url,
179
+ headers=self.headers,
180
+ timeout=TIMEOUT_ALL,
181
+ )
182
+
183
+ if response.status_code == 200:
184
+ data = response.json()
185
+ return data
186
+ else:
187
+ raise Exception(
188
+ f"API request failed with status code {response.status_code}: {response.text}"
189
+ )
190
+
191
+ def _decode_chat_response(self, response):
192
+ error_msg = ""
193
+ for chunk in response.iter_lines():
194
+ if chunk:
195
+ chunk = chunk.decode()
196
+ chunk_length = len(chunk)
197
+ try:
198
+ chunk = json.loads(chunk[6:])
199
+ except json.JSONDecodeError:
200
+ print(f"JSON解析错误,收到的内容: {chunk}")
201
+ error_msg += chunk
202
+ continue
203
+ if chunk_length > 6 and "delta" in chunk["choices"][0]:
204
+ if chunk["choices"][0]["finish_reason"] == "stop":
205
+ break
206
+ try:
207
+ yield chunk["choices"][0]["delta"]["content"]
208
+ except Exception as e:
209
+ # logging.error(f"Error: {e}")
210
+ continue
211
+ if error_msg:
212
+ raise Exception(error_msg)
213
+
214
+
215
+ class ChatGLM_Client(BaseLLMModel):
216
+ def __init__(self, model_name) -> None:
217
+ super().__init__(model_name=model_name)
218
+ from transformers import AutoTokenizer, AutoModel
219
+ import torch
220
+ global CHATGLM_TOKENIZER, CHATGLM_MODEL
221
+ if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
222
+ system_name = platform.system()
223
+ model_path = None
224
+ if os.path.exists("models"):
225
+ model_dirs = os.listdir("models")
226
+ if model_name in model_dirs:
227
+ model_path = f"models/{model_name}"
228
+ if model_path is not None:
229
+ model_source = model_path
230
+ else:
231
+ model_source = f"THUDM/{model_name}"
232
+ CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
233
+ model_source, trust_remote_code=True
234
+ )
235
+ quantified = False
236
+ if "int4" in model_name:
237
+ quantified = True
238
+ if quantified:
239
+ model = AutoModel.from_pretrained(
240
+ model_source, trust_remote_code=True
241
+ ).half()
242
+ else:
243
+ model = AutoModel.from_pretrained(
244
+ model_source, trust_remote_code=True
245
+ ).half()
246
+ if torch.cuda.is_available():
247
+ # run on CUDA
248
+ logging.info("CUDA is available, using CUDA")
249
+ model = model.cuda()
250
+ # mps加速:仅在 macOS 上、模型已本地下载且未量化时使用
251
+ elif system_name == "Darwin" and model_path is not None and not quantified:
252
+ logging.info("Running on macOS, using MPS")
253
+ # running on macOS and model already downloaded
254
+ model = model.to("mps")
255
+ else:
256
+ logging.info("GPU is not available, using CPU")
257
+ model = model.eval()
258
+ CHATGLM_MODEL = model
259
+
260
+ def _get_glm_style_input(self):
261
+ history = [x["content"] for x in self.history]
262
+ query = history.pop()
263
+ logging.debug(colorama.Fore.YELLOW +
264
+ f"{history}" + colorama.Fore.RESET)
265
+ assert (
266
+ len(history) % 2 == 0
267
+ ), f"History should be even length. current history is: {history}"
268
+ history = [[history[i], history[i + 1]]
269
+ for i in range(0, len(history), 2)]
270
+ return history, query
271
+
272
+ def get_answer_at_once(self):
273
+ history, query = self._get_glm_style_input()
274
+ response, _ = CHATGLM_MODEL.chat(
275
+ CHATGLM_TOKENIZER, query, history=history)
276
+ return response, len(response)
277
+
278
+ def get_answer_stream_iter(self):
279
+ history, query = self._get_glm_style_input()
280
+ for response, history in CHATGLM_MODEL.stream_chat(
281
+ CHATGLM_TOKENIZER,
282
+ query,
283
+ history,
284
+ max_length=self.token_upper_limit,
285
+ top_p=self.top_p,
286
+ temperature=self.temperature,
287
+ ):
288
+ yield response
289
+
290
+
291
+ class LLaMA_Client(BaseLLMModel):
292
+ def __init__(
293
+ self,
294
+ model_name,
295
+ lora_path=None,
296
+ ) -> None:
297
+ super().__init__(model_name=model_name)
298
+ from lmflow.datasets.dataset import Dataset
299
+ from lmflow.pipeline.auto_pipeline import AutoPipeline
300
+ from lmflow.models.auto_model import AutoModel
301
+ from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
302
+
303
+ self.max_generation_token = 1000
304
+ self.end_string = "\n\n"
305
+ # We don't need input data
306
+ data_args = DatasetArguments(dataset_path=None)
307
+ self.dataset = Dataset(data_args)
308
+ self.system_prompt = ""
309
+
310
+ global LLAMA_MODEL, LLAMA_INFERENCER
311
+ if LLAMA_MODEL is None or LLAMA_INFERENCER is None:
312
+ model_path = None
313
+ if os.path.exists("models"):
314
+ model_dirs = os.listdir("models")
315
+ if model_name in model_dirs:
316
+ model_path = f"models/{model_name}"
317
+ if model_path is not None:
318
+ model_source = model_path
319
+ else:
320
+ model_source = f"decapoda-research/{model_name}"
321
+ # raise Exception(f"models目录下没有这个模型: {model_name}")
322
+ if lora_path is not None:
323
+ lora_path = f"lora/{lora_path}"
324
+ model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None,
325
+ use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
326
+ pipeline_args = InferencerArguments(
327
+ local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
328
+
329
+ with open(pipeline_args.deepspeed, "r") as f:
330
+ ds_config = json.load(f)
331
+ LLAMA_MODEL = AutoModel.get_model(
332
+ model_args,
333
+ tune_strategy="none",
334
+ ds_config=ds_config,
335
+ )
336
+ LLAMA_INFERENCER = AutoPipeline.get_pipeline(
337
+ pipeline_name="inferencer",
338
+ model_args=model_args,
339
+ data_args=data_args,
340
+ pipeline_args=pipeline_args,
341
+ )
342
+ # Chats
343
+ # model_name = model_args.model_name_or_path
344
+ # if model_args.lora_model_path is not None:
345
+ # model_name += f" + {model_args.lora_model_path}"
346
+
347
+ # context = (
348
+ # "You are a helpful assistant who follows the given instructions"
349
+ # " unconditionally."
350
+ # )
351
+
352
+ def _get_llama_style_input(self):
353
+ history = []
354
+ instruction = ""
355
+ if self.system_prompt:
356
+ instruction = (f"Instruction: {self.system_prompt}\n")
357
+ for x in self.history:
358
+ if x["role"] == "user":
359
+ history.append(f"{instruction}Input: {x['content']}")
360
+ else:
361
+ history.append(f"Output: {x['content']}")
362
+ context = "\n\n".join(history)
363
+ context += "\n\nOutput: "
364
+ return context
365
+
366
+ def get_answer_at_once(self):
367
+ context = self._get_llama_style_input()
368
+
369
+ input_dataset = self.dataset.from_dict(
370
+ {"type": "text_only", "instances": [{"text": context}]}
371
+ )
372
+
373
+ output_dataset = LLAMA_INFERENCER.inference(
374
+ model=LLAMA_MODEL,
375
+ dataset=input_dataset,
376
+ max_new_tokens=self.max_generation_token,
377
+ temperature=self.temperature,
378
+ )
379
+
380
+ response = output_dataset.to_dict()["instances"][0]["text"]
381
+ return response, len(response)
382
+
383
+ def get_answer_stream_iter(self):
384
+ context = self._get_llama_style_input()
385
+ partial_text = ""
386
+ step = 1
387
+ for _ in range(0, self.max_generation_token, step):
388
+ input_dataset = self.dataset.from_dict(
389
+ {"type": "text_only", "instances": [
390
+ {"text": context + partial_text}]}
391
+ )
392
+ output_dataset = LLAMA_INFERENCER.inference(
393
+ model=LLAMA_MODEL,
394
+ dataset=input_dataset,
395
+ max_new_tokens=step,
396
+ temperature=self.temperature,
397
+ )
398
+ response = output_dataset.to_dict()["instances"][0]["text"]
399
+ if response == "" or response == self.end_string:
400
+ break
401
+ partial_text += response
402
+ yield partial_text
403
+
404
+
405
+ class XMBot_Client(BaseLLMModel):
406
+ def __init__(self, api_key):
407
+ super().__init__(model_name="xmbot")
408
+ self.api_key = api_key
409
+ self.session_id = None
410
+ self.reset()
411
+ self.image_bytes = None
412
+ self.image_path = None
413
+ self.xm_history = []
414
+ self.url = "https://xmbot.net/web"
415
+
416
+ def reset(self):
417
+ self.session_id = str(uuid.uuid4())
418
+ return [], "已重置"
419
+
420
+ def try_read_image(self, filepath):
421
+ import base64
422
+
423
+ def is_image_file(filepath):
424
+ # 判断文件是否为图片
425
+ valid_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
426
+ file_extension = os.path.splitext(filepath)[1].lower()
427
+ return file_extension in valid_image_extensions
428
+
429
+ def read_image_as_bytes(filepath):
430
+ # 读取图片文件并返回比特流
431
+ with open(filepath, "rb") as f:
432
+ image_bytes = f.read()
433
+ return image_bytes
434
+
435
+ if is_image_file(filepath):
436
+ logging.info(f"读取图片文件: {filepath}")
437
+ image_bytes = read_image_as_bytes(filepath)
438
+ base64_encoded_image = base64.b64encode(image_bytes).decode()
439
+ self.image_bytes = base64_encoded_image
440
+ self.image_path = filepath
441
+ else:
442
+ self.image_bytes = None
443
+ self.image_path = None
444
+
445
+ def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
446
+ fake_inputs = real_inputs
447
+ display_append = ""
448
+ limited_context = False
449
+ return limited_context, fake_inputs, display_append, real_inputs, chatbot
450
+
451
+ def handle_file_upload(self, files, chatbot):
452
+ """if the model accepts multi modal input, implement this function"""
453
+ if files:
454
+ for file in files:
455
+ if file.name:
456
+ logging.info(f"尝试读取图像: {file.name}")
457
+ self.try_read_image(file.name)
458
+ if self.image_path is not None:
459
+ chatbot = chatbot + [((self.image_path,), None)]
460
+ if self.image_bytes is not None:
461
+ logging.info("使用图片作为输入")
462
+ conv_id = str(uuid.uuid4())
463
+ data = {
464
+ "user_id": self.api_key,
465
+ "session_id": self.session_id,
466
+ "uuid": conv_id,
467
+ "data_type": "imgbase64",
468
+ "data": self.image_bytes
469
+ }
470
+ response = requests.post(self.url, json=data)
471
+ response = json.loads(response.text)
472
+ logging.info(f"图片回复: {response['data']}")
473
+ return None, chatbot, None
474
+
475
+ def get_answer_at_once(self):
476
+ question = self.history[-1]["content"]
477
+ conv_id = str(uuid.uuid4())
478
+ data = {
479
+ "user_id": self.api_key,
480
+ "session_id": self.session_id,
481
+ "uuid": conv_id,
482
+ "data_type": "text",
483
+ "data": question
484
+ }
485
+ response = requests.post(self.url, json=data)
486
+ response = json.loads(response.text)
487
+ return response["data"], len(response["data"])
488
+
489
+
490
+
491
+
492
+ def get_model(
493
+ model_name,
494
+ lora_model_path=None,
495
+ access_key=None,
496
+ temperature=None,
497
+ top_p=None,
498
+ system_prompt=None,
499
+ ) -> BaseLLMModel:
500
+ msg = f"模型设置为了: {model_name}"
501
+ model_type = ModelType.get_type(model_name)
502
+ lora_selector_visibility = False
503
+ lora_choices = []
504
+ dont_change_lora_selector = False
505
+ if model_type != ModelType.OpenAI:
506
+ config.local_embedding = True
507
+ # del current_model.model
508
+ model = None
509
+ try:
510
+ if model_type == ModelType.OpenAI:
511
+ logging.info(f"正在加载OpenAI模型: {model_name}")
512
+ model = OpenAIClient(
513
+ model_name=model_name,
514
+ api_key=access_key,
515
+ system_prompt=system_prompt,
516
+ temperature=temperature,
517
+ top_p=top_p,
518
+ )
519
+ elif model_type == ModelType.ChatGLM:
520
+ logging.info(f"正在加载ChatGLM模型: {model_name}")
521
+ model = ChatGLM_Client(model_name)
522
+ elif model_type == ModelType.LLaMA and lora_model_path == "":
523
+ msg = f"现在请为 {model_name} 选择LoRA模型"
524
+ logging.info(msg)
525
+ lora_selector_visibility = True
526
+ if os.path.isdir("lora"):
527
+ lora_choices = get_file_names(
528
+ "lora", plain=True, filetypes=[""])
529
+ lora_choices = ["No LoRA"] + lora_choices
530
+ elif model_type == ModelType.LLaMA and lora_model_path != "":
531
+ logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
532
+ dont_change_lora_selector = True
533
+ if lora_model_path == "No LoRA":
534
+ lora_model_path = None
535
+ msg += " + No LoRA"
536
+ else:
537
+ msg += f" + {lora_model_path}"
538
+ model = LLaMA_Client(model_name, lora_model_path)
539
+ elif model_type == ModelType.XMBot:
540
+ model = XMBot_Client(api_key=access_key)
541
+ elif model_type == ModelType.Unknown:
542
+ raise ValueError(f"未知模型: {model_name}")
543
+ logging.info(msg)
544
+ except Exception as e:
545
+ logging.error(e)
546
+ msg = f"{STANDARD_ERROR_MSG}: {e}"
547
+ if dont_change_lora_selector:
548
+ return model, msg
549
+ else:
550
+ return model, msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
551
+
552
+
553
+ if __name__ == "__main__":
554
+ with open("config.json", "r") as f:
555
+ openai_api_key = cjson.load(f)["openai_api_key"]
556
+ # set logging level to debug
557
+ logging.basicConfig(level=logging.DEBUG)
558
+ # client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
559
+ client = get_model(model_name="chatglm-6b-int4")
560
+ chatbot = []
561
+ stream = False
562
+ # 测试账单功能
563
+ logging.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
564
+ logging.info(client.billing_info())
565
+ # 测试问答
566
+ logging.info(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET)
567
+ question = "巴黎是中国的首都吗?"
568
+ for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
569
+ logging.info(i)
570
+ logging.info(f"测试问答后history : {client.history}")
571
+ # 测试记忆力
572
+ logging.info(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET)
573
+ question = "我刚刚问了你什么问题?"
574
+ for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
575
+ logging.info(i)
576
+ logging.info(f"测试记忆力后history : {client.history}")
577
+ # 测试重试功能
578
+ logging.info(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET)
579
+ for i in client.retry(chatbot=chatbot, stream=stream):
580
+ logging.info(i)
581
+ logging.info(f"重试后history : {client.history}")
582
+ # # 测试总结功能
583
+ # print(colorama.Back.GREEN + "测试总结功能" + colorama.Back.RESET)
584
+ # chatbot, msg = client.reduce_token_size(chatbot=chatbot)
585
+ # print(chatbot, msg)
586
+ # print(f"总结后history: {client.history}")
modules/overwrites.py CHANGED
@@ -4,6 +4,7 @@ import logging
4
  from llama_index import Prompt
5
  from typing import List, Tuple
6
  import mdtex2html
 
7
 
8
  from modules.presets import *
9
  from modules.llama_func import *
@@ -20,23 +21,60 @@ def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[st
20
 
21
 
22
  def postprocess(
23
- self, y: List[Tuple[str | None, str | None]]
24
- ) -> List[Tuple[str | None, str | None]]:
25
- """
26
- Parameters:
27
- y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
28
- Returns:
29
- List of tuples representing the message and response. Each message and response will be a string of HTML.
30
- """
31
- if y is None or y == []:
32
- return []
33
- user, bot = y[-1]
34
- if not detect_converted_mark(user):
35
- user = convert_asis(user)
36
- if not detect_converted_mark(bot):
37
- bot = convert_mdtext(bot)
38
- y[-1] = (user, bot)
39
- return y
 
40
 
41
  with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r", encoding="utf-8") as f2:
42
  customJS = f.read()
 
4
  from llama_index import Prompt
5
  from typing import List, Tuple
6
  import mdtex2html
7
+ from gradio_client import utils as client_utils
8
 
9
  from modules.presets import *
10
  from modules.llama_func import *
 
21
 
22
 
23
  def postprocess(
24
+ self,
25
+ y: List[List[str | Tuple[str] | Tuple[str, str] | None] | Tuple],
26
+ ) -> List[List[str | Dict | None]]:
27
+ """
28
+ Parameters:
29
+ y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
30
+ Returns:
31
+ List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed.
32
+ """
33
+ if y is None:
34
+ return []
35
+ processed_messages = []
36
+ for message_pair in y:
37
+ assert isinstance(
38
+ message_pair, (tuple, list)
39
+ ), f"Expected a list of lists or list of tuples. Received: {message_pair}"
40
+ assert (
41
+ len(message_pair) == 2
42
+ ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
43
+
44
+ processed_messages.append(
45
+ [
46
+ self._postprocess_chat_messages(message_pair[0], "user"),
47
+ self._postprocess_chat_messages(message_pair[1], "bot"),
48
+ ]
49
+ )
50
+ return processed_messages
51
+
52
+ def postprocess_chat_messages(
53
+ self, chat_message: str | Tuple | List | None, message_type: str
54
+ ) -> str | Dict | None:
55
+ if chat_message is None:
56
+ return None
57
+ elif isinstance(chat_message, (tuple, list)):
58
+ filepath = chat_message[0]
59
+ mime_type = client_utils.get_mimetype(filepath)
60
+ filepath = self.make_temp_copy_if_needed(filepath)
61
+ return {
62
+ "name": filepath,
63
+ "mime_type": mime_type,
64
+ "alt_text": chat_message[1] if len(chat_message) > 1 else None,
65
+ "data": None, # These last two fields are filled in by the frontend
66
+ "is_file": True,
67
+ }
68
+ elif isinstance(chat_message, str):
69
+ if message_type == "bot":
70
+ if not detect_converted_mark(chat_message):
71
+ chat_message = convert_mdtext(chat_message)
72
+ elif message_type == "user":
73
+ if not detect_converted_mark(chat_message):
74
+ chat_message = convert_asis(chat_message)
75
+ return chat_message
76
+ else:
77
+ raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
78
 
79
  with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r", encoding="utf-8") as f2:
80
  customJS = f.read()
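Editor's note: the patched postprocess above lets each chatbot message be either a Markdown string or a (filepath,) tuple pointing at a media file, which is how XMBot_Client appends uploaded images. A small sketch of both shapes in one history list; the filename is made up:

chatbot = [
    (("cat.png",), None),                      # user sent an image, no bot reply yet
    ("你好", "Hello! How can I help you?"),     # plain Markdown text pair
]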
modules/presets.py CHANGED
@@ -1,89 +1,122 @@
1
  # -*- coding:utf-8 -*-
2
- import gradio as gr
3
  from pathlib import Path
4
 
5
  # ChatGPT 设置
6
- initial_prompt = "You are a helpful assistant."
7
  API_HOST = "api.openai.com"
8
  COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
9
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
10
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
11
  HISTORY_DIR = Path("history")
 
12
  TEMPLATES_DIR = "templates"
13
 
14
  # 错误信息
15
- standard_error_msg = "☹️发生了错误:" # 错误信息的标准前缀
16
- error_retrieve_prompt = "请检查网络连接,或者API-Key是否有效。" # 获取对话时发生错误
17
- connection_timeout_prompt = "连接超时,无法获取对话。" # 连接超时
18
- read_timeout_prompt = "读取超时,无法获取对话。" # 读取超时
19
- proxy_error_prompt = "代理错误,无法获取对话。" # 代理错误
20
- ssl_error_prompt = "SSL错误,无法获取对话。" # SSL 错误
21
- no_apikey_msg = "API key长度不是51位,请检查是否输入正确。" # API key 长度不足 51 位
22
- no_input_msg = "请输入对话内容。" # 未输入对话内容
23
-
24
- timeout_streaming = 10 # 流式对话时的超时时间
25
- timeout_all = 200 # 非流式对话时的超时时间
26
- enable_streaming_option = True # 是否启用选择选择是否实时显示回答的勾选框
 
 
27
  HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
28
  CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
29
 
30
  SIM_K = 5
31
  INDEX_QUERY_TEMPRATURE = 1.0
32
 
33
- title = """<h1 align="left" style="min-width:200px; margin-top:6px; white-space: nowrap;">川虎ChatGPT 🚀</h1>"""
34
- description = """\
35
  <div align="center" style="margin:16px 0">
36
 
37
  由Bilibili [土川虎虎虎](https://space.bilibili.com/29125536) 和 [明昭MZhao](https://space.bilibili.com/24807452)开发
38
 
39
- 访问川虎ChatGPT的 [GitHub项目](https://github.com/GaiZhenbiao/ChuanhuChatGPT) 下载最新版脚本
40
 
41
- 此App使用 `gpt-3.5-turbo` 大语言模型
42
  </div>
43
  """
44
 
45
- footer = """\
46
- <div class="versions">{versions}</div>
47
  """
48
 
49
- summarize_prompt = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
50
 
51
- MODELS = [
52
  "gpt-3.5-turbo",
53
  "gpt-3.5-turbo-0301",
54
  "gpt-4",
55
  "gpt-4-0314",
56
  "gpt-4-32k",
57
  "gpt-4-32k-0314",
58
- ] # 可选的模型
59
-
60
- MODEL_SOFT_TOKEN_LIMIT = {
61
- "gpt-3.5-turbo": {
62
- "streaming": 3500,
63
- "all": 3500
64
- },
65
- "gpt-3.5-turbo-0301": {
66
- "streaming": 3500,
67
- "all": 3500
68
- },
69
- "gpt-4": {
70
- "streaming": 7500,
71
- "all": 7500
72
- },
73
- "gpt-4-0314": {
74
- "streaming": 7500,
75
- "all": 7500
76
- },
77
- "gpt-4-32k": {
78
- "streaming": 31000,
79
- "all": 31000
80
- },
81
- "gpt-4-32k-0314": {
82
- "streaming": 31000,
83
- "all": 31000
84
- }
85
  }
86
 
87
  REPLY_LANGUAGES = [
88
  "简体中文",
89
  "繁體中文",
 
1
  # -*- coding:utf-8 -*-
2
+ import os
3
  from pathlib import Path
4
 
5
+ import gradio as gr
6
+
7
+ CHATGLM_MODEL = None
8
+ CHATGLM_TOKENIZER = None
9
+ LLAMA_MODEL = None
10
+ LLAMA_INFERENCER = None
11
+
12
  # ChatGPT 设置
13
+ INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
14
  API_HOST = "api.openai.com"
15
  COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
16
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
17
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
18
  HISTORY_DIR = Path("history")
19
+ HISTORY_DIR = "history"
20
  TEMPLATES_DIR = "templates"
21
 
22
  # 错误信息
23
+ STANDARD_ERROR_MSG = "☹️发生了错误:" # 错误信息的标准前缀
24
+ GENERAL_ERROR_MSG = "获取对话时发生错误,请查看后台日志"
25
+ ERROR_RETRIEVE_MSG = "请检查网络连接,或者API-Key是否有效。"
26
+ CONNECTION_TIMEOUT_MSG = "连接超时,无法获取对话。" # 连接超时
27
+ READ_TIMEOUT_MSG = "读取超时,无法获取对话。" # 读取超时
28
+ PROXY_ERROR_MSG = "代理错误,无法获取对话。" # 代理错误
29
+ SSL_ERROR_PROMPT = "SSL错误,无法获取对话。" # SSL 错误
30
+ NO_APIKEY_MSG = "API key为空,请检查是否输入正确。" # API key 长度不足 51 位
31
+ NO_INPUT_MSG = "请输入对话内容。" # 未输入对话内容
32
+ BILLING_NOT_APPLICABLE_MSG = "账单信息不适用" # 本地运行的模型返回的账单信息
33
+
34
+ TIMEOUT_STREAMING = 60 # 流式对话时的超时时间
35
+ TIMEOUT_ALL = 200 # 非流式对话时的超时时间
36
+ ENABLE_STREAMING_OPTION = True # 是否启用选择选择是否实时显示回答的勾选框
37
  HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
38
  CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
39
 
40
  SIM_K = 5
41
  INDEX_QUERY_TEMPRATURE = 1.0
42
 
43
+ CHUANHU_TITLE = """<h1 align="left">川虎Chat 🚀</h1>"""
44
+ CHUANHU_DESCRIPTION = """\
45
  <div align="center" style="margin:16px 0">
46
 
47
  由Bilibili [土川虎虎虎](https://space.bilibili.com/29125536) 和 [明昭MZhao](https://space.bilibili.com/24807452)开发
48
 
49
+ 访问川虎Chat的 [GitHub项目](https://github.com/GaiZhenbiao/ChuanhuChatGPT) 下载最新版脚本
50
 
 
51
  </div>
52
  """
53
 
54
+ FOOTER = """<div class="versions">{versions}</div>"""
55
+
56
+ APPEARANCE_SWITCHER = """
57
+ <div style="display: flex; justify-content: space-between;">
58
+ <span style="margin-top: 4px !important;">切换亮暗色主题</span>
59
+ <span><label class="apSwitch" for="checkbox">
60
+ <input type="checkbox" id="checkbox">
61
+ <div class="apSlider"></div>
62
+ </label></span>
63
+ </div>
64
  """
65
 
66
+ SUMMARIZE_PROMPT = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
67
 
68
+ ONLINE_MODELS = [
69
  "gpt-3.5-turbo",
70
  "gpt-3.5-turbo-0301",
71
  "gpt-4",
72
  "gpt-4-0314",
73
  "gpt-4-32k",
74
  "gpt-4-32k-0314",
75
+ "xmbot",
76
+ ]
77
+
78
+ LOCAL_MODELS = [
79
+ "chatglm-6b",
80
+ "chatglm-6b-int4",
81
+ "chatglm-6b-int4-qe",
82
+ "llama-7b-hf",
83
+ "llama-7b-hf-int4",
84
+ "llama-7b-hf-int8",
85
+ "llama-13b-hf",
86
+ "llama-13b-hf-int4",
87
+ "llama-30b-hf",
88
+ "llama-30b-hf-int4",
89
+ "llama-65b-hf"
90
+ ]
91
+
92
+ if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
93
+ MODELS = ONLINE_MODELS
94
+ else:
95
+ MODELS = ONLINE_MODELS + LOCAL_MODELS
96
+
97
+ DEFAULT_MODEL = 0
98
+
99
+ os.makedirs("models", exist_ok=True)
100
+ os.makedirs("lora", exist_ok=True)
101
+ os.makedirs("history", exist_ok=True)
102
+ for dir_name in os.listdir("models"):
103
+ if os.path.isdir(os.path.join("models", dir_name)):
104
+ if dir_name not in MODELS:
105
+ MODELS.append(dir_name)
106
+
107
+ MODEL_TOKEN_LIMIT = {
108
+ "gpt-3.5-turbo": 4096,
109
+ "gpt-3.5-turbo-0301": 4096,
110
+ "gpt-4": 8192,
111
+ "gpt-4-0314": 8192,
112
+ "gpt-4-32k": 32768,
113
+ "gpt-4-32k-0314": 32768
114
  }
115
 
116
+ TOKEN_OFFSET = 1000 # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
117
+ DEFAULT_TOKEN_LIMIT = 3000 # 默认的token上限
118
+ REDUCE_TOKEN_FACTOR = 0.5 # 与模型token上限想乘,得到目标token数。减少token占用时,将token占用减少到目标token数以下。
119
+
120
  REPLY_LANGUAGES = [
121
  "简体中文",
122
  "繁體中文",
modules/shared.py CHANGED
@@ -41,11 +41,11 @@ class State:
     def switching_api_key(self, func):
         if not hasattr(self, "api_key_queue"):
             return func
-
+
         def wrapped(*args, **kwargs):
             api_key = self.api_key_queue.get()
-            args = list(args)[1:]
-            ret = func(api_key, *args, **kwargs)
+            args[0].api_key = api_key
+            ret = func(*args, **kwargs)
             self.api_key_queue.put(api_key)
             return ret
 
modules/utils.py CHANGED
@@ -34,6 +34,85 @@ if TYPE_CHECKING:
     headers: List[str]
     data: List[List[str | int | bool]]
 
+def predict(current_model, *args):
+    iter = current_model.predict(*args)
+    for i in iter:
+        yield i
+
+def billing_info(current_model):
+    return current_model.billing_info()
+
+def set_key(current_model, *args):
+    return current_model.set_key(*args)
+
+def load_chat_history(current_model, *args):
+    return current_model.load_chat_history(*args)
+
+def interrupt(current_model, *args):
+    return current_model.interrupt(*args)
+
+def reset(current_model, *args):
+    return current_model.reset(*args)
+
+def retry(current_model, *args):
+    iter = current_model.retry(*args)
+    for i in iter:
+        yield i
+
+def delete_first_conversation(current_model, *args):
+    return current_model.delete_first_conversation(*args)
+
+def delete_last_conversation(current_model, *args):
+    return current_model.delete_last_conversation(*args)
+
+def set_system_prompt(current_model, *args):
+    return current_model.set_system_prompt(*args)
+
+def save_chat_history(current_model, *args):
+    return current_model.save_chat_history(*args)
+
+def export_markdown(current_model, *args):
+    return current_model.export_markdown(*args)
+
+def load_chat_history(current_model, *args):
+    return current_model.load_chat_history(*args)
+
+def set_token_upper_limit(current_model, *args):
+    return current_model.set_token_upper_limit(*args)
+
+def set_temperature(current_model, *args):
+    current_model.set_temperature(*args)
+
+def set_top_p(current_model, *args):
+    current_model.set_top_p(*args)
+
+def set_n_choices(current_model, *args):
+    current_model.set_n_choices(*args)
+
+def set_stop_sequence(current_model, *args):
+    current_model.set_stop_sequence(*args)
+
+def set_max_tokens(current_model, *args):
+    current_model.set_max_tokens(*args)
+
+def set_presence_penalty(current_model, *args):
+    current_model.set_presence_penalty(*args)
+
+def set_frequency_penalty(current_model, *args):
+    current_model.set_frequency_penalty(*args)
+
+def set_logit_bias(current_model, *args):
+    current_model.set_logit_bias(*args)
+
+def set_user_identifier(current_model, *args):
+    current_model.set_user_identifier(*args)
+
+def set_single_turn(current_model, *args):
+    current_model.set_single_turn(*args)
+
+def handle_file_upload(current_model, *args):
+    return current_model.handle_file_upload(*args)
+
 
 def count_token(message):
     encoding = tiktoken.get_encoding("cl100k_base")
@@ -121,10 +200,13 @@ def convert_asis(userinput):
 
 
 def detect_converted_mark(userinput):
-    if userinput.endswith(ALREADY_CONVERTED_MARK):
+    try:
+        if userinput.endswith(ALREADY_CONVERTED_MARK):
+            return True
+        else:
+            return False
+    except:
         return True
-    else:
-        return False
 
 
 def detect_language(code):
@@ -153,107 +235,22 @@ def construct_assistant(text):
     return construct_text("assistant", text)
 
 
-def construct_token_message(tokens: List[int]):
-    token_sum = 0
-    for i in range(len(tokens)):
-        token_sum += sum(tokens[: i + 1])
-    return f"Token 计数: {sum(tokens)},本次对话累计消耗了 {token_sum} tokens"
-
-
-def delete_first_conversation(history, previous_token_count):
-    if history:
-        del history[:2]
-        del previous_token_count[0]
-    return (
-        history,
-        previous_token_count,
-        construct_token_message(previous_token_count),
-    )
-
-
-def delete_last_conversation(chatbot, history, previous_token_count):
-    if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
-        logging.info("由于包含报错信息,只删除chatbot记录")
-        chatbot.pop()
-        return chatbot, history
-    if len(history) > 0:
-        logging.info("删除了一组对话历史")
-        history.pop()
-        history.pop()
-    if len(chatbot) > 0:
-        logging.info("删除了一组chatbot对话")
-        chatbot.pop()
-    if len(previous_token_count) > 0:
-        logging.info("删除了一组对话的token计数记录")
-        previous_token_count.pop()
-    return (
-        chatbot,
-        history,
-        previous_token_count,
-        construct_token_message(previous_token_count),
-    )
-
-
 def save_file(filename, system, history, chatbot, user_name):
-    logging.info(f"{user_name} 保存对话历史中……")
-    os.makedirs(HISTORY_DIR / user_name, exist_ok=True)
+    logging.debug(f"{user_name} 保存对话历史中……")
+    os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True)
     if filename.endswith(".json"):
         json_s = {"system": system, "history": history, "chatbot": chatbot}
         print(json_s)
-        with open(os.path.join(HISTORY_DIR / user_name, filename), "w") as f:
+        with open(os.path.join(HISTORY_DIR, user_name, filename), "w") as f:
             json.dump(json_s, f)
     elif filename.endswith(".md"):
         md_s = f"system: \n- {system} \n"
         for data in history:
            md_s += f"\n{data['role']}: \n- {data['content']} \n"
-        with open(os.path.join(HISTORY_DIR / user_name, filename), "w", encoding="utf8") as f:
+        with open(os.path.join(HISTORY_DIR, user_name, filename), "w", encoding="utf8") as f:
             f.write(md_s)
-    logging.info(f"{user_name} 保存对话历史完毕")
-    return os.path.join(HISTORY_DIR / user_name, filename)
-
-
-def save_chat_history(filename, system, history, chatbot, user_name):
-    if filename == "":
-        return
-    if not filename.endswith(".json"):
-        filename += ".json"
-    return save_file(filename, system, history, chatbot, user_name)
-
-
-def export_markdown(filename, system, history, chatbot, user_name):
-    if filename == "":
-        return
-    if not filename.endswith(".md"):
-        filename += ".md"
-    return save_file(filename, system, history, chatbot, user_name)
-
-
-def load_chat_history(filename, system, history, chatbot, user_name):
-    logging.info(f"{user_name} 加载对话历史中……")
-    if type(filename) != str:
-        filename = filename.name
-    try:
-        with open(os.path.join(HISTORY_DIR / user_name, filename), "r") as f:
-            json_s = json.load(f)
-        try:
-            if type(json_s["history"][0]) == str:
-                logging.info("历史记录格式为旧版,正在转换……")
-                new_history = []
-                for index, item in enumerate(json_s["history"]):
-                    if index % 2 == 0:
-                        new_history.append(construct_user(item))
-                    else:
-                        new_history.append(construct_assistant(item))
-                json_s["history"] = new_history
-                logging.info(new_history)
-        except:
-            # no chat history yet
-            pass
-        logging.info(f"{user_name} 加载对话历史完毕")
-        return filename, json_s["system"], json_s["history"], json_s["chatbot"]
-    except FileNotFoundError:
-        logging.info(f"{user_name} 没有找到对话历史文件,不执行任何操作")
-        return filename, system, history, chatbot
+    logging.debug(f"{user_name} 保存对话历史完毕")
+    return os.path.join(HISTORY_DIR, user_name, filename)
 
 
 def sorted_by_pinyin(list):
@@ -261,7 +258,7 @@ def sorted_by_pinyin(list):
 
 
 def get_file_names(dir, plain=False, filetypes=[".json"]):
-    logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
+    logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
     files = []
     try:
         for type in filetypes:
@@ -279,14 +276,13 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
 
 
 def get_history_names(plain=False, user_name=""):
-    logging.info(f"从用户 {user_name} 中获取历史记录文件名列表")
-    return get_file_names(HISTORY_DIR / user_name, plain)
+    logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
+    return get_file_names(os.path.join(HISTORY_DIR, user_name), plain)
 
 
 def load_template(filename, mode=0):
-    logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
+    logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
     lines = []
-    logging.info("Loading template...")
     if filename.endswith(".json"):
         with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
             lines = json.load(f)
@@ -310,23 +306,18 @@ def load_template(filename, mode=0):
 
 
 def get_template_names(plain=False):
-    logging.info("获取模板文件名列表")
+    logging.debug("获取模板文件名列表")
     return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
 
 
 def get_template_content(templates, selection, original_system_prompt):
-    logging.info(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
+    logging.debug(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
     try:
         return templates[selection]
     except:
         return original_system_prompt
 
 
-def reset_state():
-    logging.info("重置状态")
-    return [], [], [], construct_token_message([0])
-
-
 def reset_textbox():
     logging.debug("重置文本框")
     return gr.update(value="")
@@ -388,7 +379,7 @@ def get_geoip():
     logging.warning(f"无法获取IP地址信息。\n{data}")
     if data["reason"] == "RateLimited":
         return (
-            f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用。"
+            f"您的IP区域:未知。"
         )
     else:
         return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
@@ -418,7 +409,7 @@ def find_n(lst, max_num):
 
 def start_outputing():
     logging.debug("显示取消按钮,隐藏发送按钮")
-    return gr.Button.update(visible=True), gr.Button.update(visible=False)
+    return gr.Button.update(visible=False), gr.Button.update(visible=True)
 
 
 def end_outputing():
@@ -440,8 +431,8 @@ def transfer_input(inputs):
     return (
         inputs,
         gr.update(value=""),
-        gr.Button.update(visible=True),
         gr.Button.update(visible=False),
+        gr.Button.update(visible=True),
     )
 
 
@@ -504,15 +495,15 @@ def add_details(lst):
     return nodes
 
 
-def sheet_to_string(sheet):
-    result = ""
+def sheet_to_string(sheet, sheet_name = None):
+    result = []
     for index, row in sheet.iterrows():
         row_string = ""
         for column in sheet.columns:
             row_string += f"{column}: {row[column]}, "
         row_string = row_string.rstrip(", ")
         row_string += "."
-        result += row_string + "\n"
+        result.append(row_string)
     return result
 
 def excel_to_string(file_path):
@@ -520,17 +511,23 @@ def excel_to_string(file_path):
     excel_file = pd.read_excel(file_path, engine='openpyxl', sheet_name=None)
 
     # initialize the result
-    result = ""
+    result = []
 
     # iterate over every worksheet
     for sheet_name, sheet_data in excel_file.items():
-        # append the sheet name to the result string
-        result += f"Sheet: {sheet_name}\n"
 
         # process the current worksheet and append it to the result
-        result += sheet_to_string(sheet_data)
+        result += sheet_to_string(sheet_data, sheet_name=sheet_name)
 
-        # add a separator between worksheets
-        result += "\n" + ("-" * 20) + "\n\n"
 
     return result
+
+def get_last_day_of_month(any_day):
+    # The day 28 exists in every month. 4 days later, it's always next month
+    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
+    # subtracting the number of the current day brings us back one month
+    return next_month - datetime.timedelta(days=next_month.day)
+
+def get_model_source(model_name, alternative_source):
+    if model_name == "gpt2-medium":
+        return "https://huggingface.co/gpt2-medium"
requirements.txt CHANGED
@@ -13,3 +13,4 @@ markdown
 PyPDF2
 pdfplumber
 pandas
+commentjson
requirements_advanced.txt ADDED
@@ -0,0 +1,7 @@
+transformers
+torch
+icetk
+protobuf==3.19.0
+git+https://github.com/OptimalScale/LMFlow.git
+cpm-kernels
+sentence_transformers
run_Linux.sh CHANGED
@@ -1,10 +1,10 @@
 #!/bin/bash
 
 # Get the directory the script lives in
-script_dir=$(dirname "$0")
+script_dir=$(dirname "$(readlink -f "$0")")
 
 # Change the working directory to the script's directory
-cd "$script_dir"
+cd "$script_dir" || exit
 
 # Check whether the Git repository has updates
 git remote update
@@ -23,3 +23,9 @@ if ! git status -uno | grep 'up to date' > /dev/null; then
     # Restart the server
     nohup python3 ChuanhuChatbot.py &
 fi
+
+# Check whether ChuanhuChatbot.py is running
+if ! pgrep -f ChuanhuChatbot.py > /dev/null; then
+    # If it is not running, start the server
+    nohup python3 ChuanhuChatbot.py &
+fi
run_macOS.command CHANGED
@@ -1,10 +1,10 @@
 #!/bin/bash
 
 # Get the directory the script lives in
-script_dir=$(dirname "$0")
+script_dir=$(dirname "$(readlink -f "$0")")
 
 # Change the working directory to the script's directory
-cd "$script_dir"
+cd "$script_dir" || exit
 
 # Check whether the Git repository has updates
 git remote update
@@ -23,3 +23,9 @@ if ! git status -uno | grep 'up to date' > /dev/null; then
     # Restart the server
     nohup python3 ChuanhuChatbot.py &
 fi
+
+# Check whether ChuanhuChatbot.py is running
+if ! pgrep -f ChuanhuChatbot.py > /dev/null; then
+    # If it is not running, start the server
+    nohup python3 ChuanhuChatbot.py &
+fi