JohnSmith9982 commited on
Commit
fa4087a
·
1 Parent(s): c9c16d7

Upload 37 files

Browse files
ChuanhuChatbot.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ import os
3
+ import logging
4
+ import sys
5
+
6
+ import gradio as gr
7
+
8
+ from modules import config
9
+ from modules.config import *
10
+ from modules.utils import *
11
+ from modules.presets import *
12
+ from modules.overwrites import *
13
+ from modules.chat_func import *
14
+ from modules.openai_func import get_usage
15
+
16
+ gr.Chatbot.postprocess = postprocess
17
+ PromptHelper.compact_text_chunks = compact_text_chunks
18
+
19
+ with open("assets/custom.css", "r", encoding="utf-8") as f:
20
+ customCSS = f.read()
21
+
22
+ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
23
+ user_name = gr.State("")
24
+ history = gr.State([])
25
+ token_count = gr.State([])
26
+ promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
27
+ user_api_key = gr.State(my_api_key)
28
+ user_question = gr.State("")
29
+ outputing = gr.State(False)
30
+ topic = gr.State("未命名对话历史记录")
31
+
32
+ with gr.Row():
33
+ with gr.Column():
34
+ gr.HTML(title)
35
+ user_info = gr.Markdown(value="", elem_id="user_info")
36
+ gr.HTML('<center><a href="https://huggingface.co/spaces/JohnSmith9982/ChuanhuChatGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a></center>')
37
+ status_display = gr.Markdown(get_geoip(), elem_id="status_display")
38
+
39
+ # https://github.com/gradio-app/gradio/pull/3296
40
+ def create_greeting(request: gr.Request):
41
+ if hasattr(request, "username") and request.username: # is not None or is not ""
42
+ logging.info(f"Get User Name: {request.username}")
43
+ return gr.Markdown.update(value=f"User: {request.username}"), request.username
44
+ else:
45
+ return gr.Markdown.update(value=f"User: default", visible=False), ""
46
+ demo.load(create_greeting, inputs=None, outputs=[user_info, user_name])
47
+
48
+ with gr.Row().style(equal_height=True):
49
+ with gr.Column(scale=5):
50
+ with gr.Row():
51
+ chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
52
+ with gr.Row():
53
+ with gr.Column(scale=12):
54
+ user_input = gr.Textbox(
55
+ elem_id="user_input_tb",
56
+ show_label=False, placeholder="在这里输入"
57
+ ).style(container=False)
58
+ with gr.Column(min_width=70, scale=1):
59
+ submitBtn = gr.Button("发送", variant="primary")
60
+ cancelBtn = gr.Button("取消", variant="secondary", visible=False)
61
+ with gr.Row():
62
+ emptyBtn = gr.Button(
63
+ "🧹 新的对话",
64
+ )
65
+ retryBtn = gr.Button("🔄 重新生成")
66
+ delFirstBtn = gr.Button("🗑️ 删除最旧对话")
67
+ delLastBtn = gr.Button("🗑️ 删除最新对话")
68
+ reduceTokenBtn = gr.Button("♻️ 总结对话")
69
+
70
+ with gr.Column():
71
+ with gr.Column(min_width=50, scale=1):
72
+ with gr.Tab(label="ChatGPT"):
73
+ keyTxt = gr.Textbox(
74
+ show_label=True,
75
+ placeholder=f"OpenAI API-key...",
76
+ value=hide_middle_chars(my_api_key),
77
+ type="password",
78
+ visible=not HIDE_MY_KEY,
79
+ label="API-Key",
80
+ )
81
+ if multi_api_key:
82
+ usageTxt = gr.Markdown("多账号模式已开启,无需输入key,可直接开始对话", elem_id="usage_display")
83
+ else:
84
+ usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display")
85
+ model_select_dropdown = gr.Dropdown(
86
+ label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
87
+ )
88
+ use_streaming_checkbox = gr.Checkbox(
89
+ label="实时传输回答", value=True, visible=enable_streaming_option
90
+ )
91
+ use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
92
+ language_select_dropdown = gr.Dropdown(
93
+ label="选择回复语言(针对搜索&索引功能)",
94
+ choices=REPLY_LANGUAGES,
95
+ multiselect=False,
96
+ value=REPLY_LANGUAGES[0],
97
+ )
98
+ index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
99
+ two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
100
+ # TODO: 公式ocr
101
+ # formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
102
+
103
+ with gr.Tab(label="Prompt"):
104
+ systemPromptTxt = gr.Textbox(
105
+ show_label=True,
106
+ placeholder=f"在这里输入System Prompt...",
107
+ label="System prompt",
108
+ value=initial_prompt,
109
+ lines=10,
110
+ ).style(container=False)
111
+ with gr.Accordion(label="加载Prompt模板", open=True):
112
+ with gr.Column():
113
+ with gr.Row():
114
+ with gr.Column(scale=6):
115
+ templateFileSelectDropdown = gr.Dropdown(
116
+ label="选择Prompt模板集合文件",
117
+ choices=get_template_names(plain=True),
118
+ multiselect=False,
119
+ value=get_template_names(plain=True)[0],
120
+ ).style(container=False)
121
+ with gr.Column(scale=1):
122
+ templateRefreshBtn = gr.Button("🔄 刷新")
123
+ with gr.Row():
124
+ with gr.Column():
125
+ templateSelectDropdown = gr.Dropdown(
126
+ label="从Prompt模板中加载",
127
+ choices=load_template(
128
+ get_template_names(plain=True)[0], mode=1
129
+ ),
130
+ multiselect=False,
131
+ ).style(container=False)
132
+
133
+ with gr.Tab(label="保存/加载"):
134
+ with gr.Accordion(label="保存/加载对话历史记录", open=True):
135
+ with gr.Column():
136
+ with gr.Row():
137
+ with gr.Column(scale=6):
138
+ historyFileSelectDropdown = gr.Dropdown(
139
+ label="从列表中加载对话",
140
+ choices=get_history_names(plain=True),
141
+ multiselect=False,
142
+ value=get_history_names(plain=True)[0],
143
+ )
144
+ with gr.Column(scale=1):
145
+ historyRefreshBtn = gr.Button("🔄 刷新")
146
+ with gr.Row():
147
+ with gr.Column(scale=6):
148
+ saveFileName = gr.Textbox(
149
+ show_label=True,
150
+ placeholder=f"设置文件名: 默认为.json,可选为.md",
151
+ label="设置保存文件名",
152
+ value="对话历史记录",
153
+ ).style(container=True)
154
+ with gr.Column(scale=1):
155
+ saveHistoryBtn = gr.Button("💾 保存对话")
156
+ exportMarkdownBtn = gr.Button("📝 导出为Markdown")
157
+ gr.Markdown("默认保存于history文件夹")
158
+ with gr.Row():
159
+ with gr.Column():
160
+ downloadFile = gr.File(interactive=True)
161
+
162
+ with gr.Tab(label="高级"):
163
+ gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
164
+ default_btn = gr.Button("🔙 恢复默认设置")
165
+
166
+ with gr.Accordion("参数", open=False):
167
+ top_p = gr.Slider(
168
+ minimum=-0,
169
+ maximum=1.0,
170
+ value=1.0,
171
+ step=0.05,
172
+ interactive=True,
173
+ label="Top-p",
174
+ )
175
+ temperature = gr.Slider(
176
+ minimum=-0,
177
+ maximum=2.0,
178
+ value=1.0,
179
+ step=0.1,
180
+ interactive=True,
181
+ label="Temperature",
182
+ )
183
+
184
+ with gr.Accordion("网络设置", open=False, visible=False):
185
+ # 优先展示自定义的api_host
186
+ apihostTxt = gr.Textbox(
187
+ show_label=True,
188
+ placeholder=f"在这里输入API-Host...",
189
+ label="API-Host",
190
+ value=config.api_host or shared.API_HOST,
191
+ lines=1,
192
+ )
193
+ changeAPIURLBtn = gr.Button("🔄 切换API地址")
194
+ proxyTxt = gr.Textbox(
195
+ show_label=True,
196
+ placeholder=f"在这里输入代理地址...",
197
+ label="代理地址(示例:http://127.0.0.1:10809)",
198
+ value="",
199
+ lines=2,
200
+ )
201
+ changeProxyBtn = gr.Button("🔄 设置代理地址")
202
+
203
+ gr.Markdown(description)
204
+ gr.HTML(footer.format(versions=versions_html()), elem_id="footer")
205
+ chatgpt_predict_args = dict(
206
+ fn=predict,
207
+ inputs=[
208
+ user_api_key,
209
+ systemPromptTxt,
210
+ history,
211
+ user_question,
212
+ chatbot,
213
+ token_count,
214
+ top_p,
215
+ temperature,
216
+ use_streaming_checkbox,
217
+ model_select_dropdown,
218
+ use_websearch_checkbox,
219
+ index_files,
220
+ language_select_dropdown,
221
+ ],
222
+ outputs=[chatbot, history, status_display, token_count],
223
+ show_progress=True,
224
+ )
225
+
226
+ start_outputing_args = dict(
227
+ fn=start_outputing,
228
+ inputs=[],
229
+ outputs=[submitBtn, cancelBtn],
230
+ show_progress=True,
231
+ )
232
+
233
+ end_outputing_args = dict(
234
+ fn=end_outputing, inputs=[], outputs=[submitBtn, cancelBtn]
235
+ )
236
+
237
+ reset_textbox_args = dict(
238
+ fn=reset_textbox, inputs=[], outputs=[user_input]
239
+ )
240
+
241
+ transfer_input_args = dict(
242
+ fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn, cancelBtn], show_progress=True
243
+ )
244
+
245
+ get_usage_args = dict(
246
+ fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
247
+ )
248
+
249
+
250
+ # Chatbot
251
+ cancelBtn.click(cancel_outputing, [], [])
252
+
253
+ user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
254
+ user_input.submit(**get_usage_args)
255
+
256
+ submitBtn.click(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
257
+ submitBtn.click(**get_usage_args)
258
+
259
+ emptyBtn.click(
260
+ reset_state,
261
+ outputs=[chatbot, history, token_count, status_display],
262
+ show_progress=True,
263
+ )
264
+ emptyBtn.click(**reset_textbox_args)
265
+
266
+ retryBtn.click(**start_outputing_args).then(
267
+ retry,
268
+ [
269
+ user_api_key,
270
+ systemPromptTxt,
271
+ history,
272
+ chatbot,
273
+ token_count,
274
+ top_p,
275
+ temperature,
276
+ use_streaming_checkbox,
277
+ model_select_dropdown,
278
+ language_select_dropdown,
279
+ ],
280
+ [chatbot, history, status_display, token_count],
281
+ show_progress=True,
282
+ ).then(**end_outputing_args)
283
+ retryBtn.click(**get_usage_args)
284
+
285
+ delFirstBtn.click(
286
+ delete_first_conversation,
287
+ [history, token_count],
288
+ [history, token_count, status_display],
289
+ )
290
+
291
+ delLastBtn.click(
292
+ delete_last_conversation,
293
+ [chatbot, history, token_count],
294
+ [chatbot, history, token_count, status_display],
295
+ show_progress=True,
296
+ )
297
+
298
+ reduceTokenBtn.click(
299
+ reduce_token_size,
300
+ [
301
+ user_api_key,
302
+ systemPromptTxt,
303
+ history,
304
+ chatbot,
305
+ token_count,
306
+ top_p,
307
+ temperature,
308
+ gr.State(sum(token_count.value[-4:])),
309
+ model_select_dropdown,
310
+ language_select_dropdown,
311
+ ],
312
+ [chatbot, history, status_display, token_count],
313
+ show_progress=True,
314
+ )
315
+ reduceTokenBtn.click(**get_usage_args)
316
+
317
+ two_column.change(update_doc_config, [two_column], None)
318
+
319
+ # ChatGPT
320
+ keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
321
+ keyTxt.submit(**get_usage_args)
322
+
323
+ # Template
324
+ templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
325
+ templateFileSelectDropdown.change(
326
+ load_template,
327
+ [templateFileSelectDropdown],
328
+ [promptTemplates, templateSelectDropdown],
329
+ show_progress=True,
330
+ )
331
+ templateSelectDropdown.change(
332
+ get_template_content,
333
+ [promptTemplates, templateSelectDropdown, systemPromptTxt],
334
+ [systemPromptTxt],
335
+ show_progress=True,
336
+ )
337
+
338
+ # S&L
339
+ saveHistoryBtn.click(
340
+ save_chat_history,
341
+ [saveFileName, systemPromptTxt, history, chatbot, user_name],
342
+ downloadFile,
343
+ show_progress=True,
344
+ )
345
+ saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
346
+ exportMarkdownBtn.click(
347
+ export_markdown,
348
+ [saveFileName, systemPromptTxt, history, chatbot, user_name],
349
+ downloadFile,
350
+ show_progress=True,
351
+ )
352
+ historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
353
+ historyFileSelectDropdown.change(
354
+ load_chat_history,
355
+ [historyFileSelectDropdown, systemPromptTxt, history, chatbot, user_name],
356
+ [saveFileName, systemPromptTxt, history, chatbot],
357
+ show_progress=True,
358
+ )
359
+ downloadFile.change(
360
+ load_chat_history,
361
+ [downloadFile, systemPromptTxt, history, chatbot, user_name],
362
+ [saveFileName, systemPromptTxt, history, chatbot],
363
+ )
364
+
365
+ # Advanced
366
+ default_btn.click(
367
+ reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
368
+ )
369
+ changeAPIURLBtn.click(
370
+ change_api_host,
371
+ [apihostTxt],
372
+ [status_display],
373
+ show_progress=True,
374
+ )
375
+ changeProxyBtn.click(
376
+ change_proxy,
377
+ [proxyTxt],
378
+ [status_display],
379
+ show_progress=True,
380
+ )
381
+
382
+ logging.info(
383
+ colorama.Back.GREEN
384
+ + "\n川虎的温馨提示:访问 http://localhost:7860 查看界面"
385
+ + colorama.Style.RESET_ALL
386
+ )
387
+ # 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
388
+ demo.title = "川虎ChatGPT 🚀"
389
+
390
+ if __name__ == "__main__":
391
+ reload_javascript()
392
+ # if running in Docker
393
+ if dockerflag:
394
+ if authflag:
395
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
396
+ server_name="0.0.0.0",
397
+ server_port=7860,
398
+ auth=auth_list,
399
+ favicon_path="./assets/favicon.ico",
400
+ )
401
+ else:
402
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
403
+ server_name="0.0.0.0",
404
+ server_port=7860,
405
+ share=False,
406
+ favicon_path="./assets/favicon.ico",
407
+ )
408
+ # if not running in Docker
409
+ else:
410
+ if authflag:
411
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
412
+ share=False,
413
+ auth=auth_list,
414
+ favicon_path="./assets/favicon.ico",
415
+ inbrowser=True,
416
+ )
417
+ else:
418
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
419
+ share=False, favicon_path="./assets/favicon.ico", inbrowser=True
420
+ ) # 改为 share=True 可以创建公开分享链接
421
+ # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
422
+ # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
423
+ # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
README.md CHANGED
@@ -4,8 +4,8 @@ emoji: 🐯
4
  colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.23.0
8
- app_file: app.py
9
  pinned: false
10
  license: gpl-3.0
11
  ---
 
4
  colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.24.1
8
+ app_file: ChuanhuChatbot.py
9
  pinned: false
10
  license: gpl-3.0
11
  ---
assets/custom.css CHANGED
@@ -18,10 +18,22 @@ footer {
18
  opacity: 0.85;
19
  }
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  /* status_display */
22
  #status_display {
23
  display: flex;
24
- min-height: 2.5em;
25
  align-items: flex-end;
26
  justify-content: flex-end;
27
  }
@@ -110,7 +122,6 @@ ol:not(.options), ul:not(.options) {
110
  background-color: var(--neutral-950) !important;
111
  }
112
  }
113
-
114
  /* 对话气泡 */
115
  [class *= "message"] {
116
  border-radius: var(--radius-xl) !important;
 
18
  opacity: 0.85;
19
  }
20
 
21
+ /* user_info */
22
+ #user_info {
23
+ white-space: nowrap;
24
+ margin-top: -1.3em !important;
25
+ padding-left: 112px !important;
26
+ }
27
+ #user_info p {
28
+ font-size: .85em;
29
+ font-family: monospace;
30
+ color: var(--body-text-color-subdued);
31
+ }
32
+
33
  /* status_display */
34
  #status_display {
35
  display: flex;
36
+ min-height: 2em;
37
  align-items: flex-end;
38
  justify-content: flex-end;
39
  }
 
122
  background-color: var(--neutral-950) !important;
123
  }
124
  }
 
125
  /* 对话气泡 */
126
  [class *= "message"] {
127
  border-radius: var(--radius-xl) !important;
assets/custom.js CHANGED
@@ -1 +1,70 @@
1
- // custom javascript here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // custom javascript here
2
+ const MAX_HISTORY_LENGTH = 32;
3
+
4
+ var key_down_history = [];
5
+ var currentIndex = -1;
6
+ var user_input_ta;
7
+
8
+ var ga = document.getElementsByTagName("gradio-app");
9
+ var targetNode = ga[0];
10
+ var observer = new MutationObserver(function(mutations) {
11
+ for (var i = 0; i < mutations.length; i++) {
12
+ if (mutations[i].addedNodes.length) {
13
+ var user_input_tb = document.getElementById('user_input_tb');
14
+ if (user_input_tb) {
15
+ // 监听到user_input_tb被添加到DOM树中
16
+ // 这里可以编写元素加载完成后需要执行的代码
17
+ user_input_ta = user_input_tb.querySelector("textarea");
18
+ if (user_input_ta){
19
+ observer.disconnect(); // 停止监听
20
+ // 在 textarea 上监听 keydown 事件
21
+ user_input_ta.addEventListener("keydown", function (event) {
22
+ var value = user_input_ta.value.trim();
23
+ // 判断按下的是否为方向键
24
+ if (event.code === 'ArrowUp' || event.code === 'ArrowDown') {
25
+ // 如果按下的是方向键,且输入框中有内容,且历史记录中没有该内容,则不执行操作
26
+ if(value && key_down_history.indexOf(value) === -1)
27
+ return;
28
+ // 对于需要响应的动作,阻止默认行为。
29
+ event.preventDefault();
30
+ var length = key_down_history.length;
31
+ if(length === 0) {
32
+ currentIndex = -1; // 如果历史记录为空,直接将当前选中的记录重置
33
+ return;
34
+ }
35
+ if (currentIndex === -1) {
36
+ currentIndex = length;
37
+ }
38
+ if (event.code === 'ArrowUp' && currentIndex > 0) {
39
+ currentIndex--;
40
+ user_input_ta.value = key_down_history[currentIndex];
41
+ } else if (event.code === 'ArrowDown' && currentIndex < length - 1) {
42
+ currentIndex++;
43
+ user_input_ta.value = key_down_history[currentIndex];
44
+ }
45
+ user_input_ta.selectionStart = user_input_ta.value.length;
46
+ user_input_ta.selectionEnd = user_input_ta.value.length;
47
+ const input_event = new InputEvent("input", {bubbles: true, cancelable: true});
48
+ user_input_ta.dispatchEvent(input_event);
49
+ }else if(event.code === "Enter") {
50
+ if (value) {
51
+ currentIndex = -1;
52
+ if(key_down_history.indexOf(value) === -1){
53
+ key_down_history.push(value);
54
+ if (key_down_history.length > MAX_HISTORY_LENGTH) {
55
+ key_down_history.shift();
56
+ }
57
+ }
58
+ }
59
+ }
60
+ });
61
+ break;
62
+ }
63
+ }
64
+ }
65
+ }
66
+ });
67
+
68
+ // 监听目标节点的子节点列表是否发生变化
69
+ observer.observe(targetNode, { childList: true , subtree: true });
70
+
modules/__pycache__/chat_func.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/chat_func.cpython-39.pyc and b/modules/__pycache__/chat_func.cpython-39.pyc differ
 
modules/__pycache__/config.cpython-39.pyc ADDED
Binary file (3.18 kB). View file
 
modules/__pycache__/llama_func.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/llama_func.cpython-39.pyc and b/modules/__pycache__/llama_func.cpython-39.pyc differ
 
modules/__pycache__/openai_func.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/openai_func.cpython-39.pyc and b/modules/__pycache__/openai_func.cpython-39.pyc differ
 
modules/__pycache__/overwrites.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/overwrites.cpython-39.pyc and b/modules/__pycache__/overwrites.cpython-39.pyc differ
 
modules/__pycache__/pdf_func.cpython-39.pyc ADDED
Binary file (6.13 kB). View file
 
modules/__pycache__/presets.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/presets.cpython-39.pyc and b/modules/__pycache__/presets.cpython-39.pyc differ
 
modules/__pycache__/proxy_func.cpython-39.pyc ADDED
Binary file (718 Bytes). View file
 
modules/__pycache__/shared.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/shared.cpython-39.pyc and b/modules/__pycache__/shared.cpython-39.pyc differ
 
modules/__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/modules/__pycache__/utils.cpython-39.pyc and b/modules/__pycache__/utils.cpython-39.pyc differ
 
modules/chat_func.py CHANGED
@@ -13,14 +13,13 @@ import colorama
13
  from duckduckgo_search import ddg
14
  import asyncio
15
  import aiohttp
16
- from llama_index.indices.query.vector_store import GPTVectorStoreIndexQuery
17
- from llama_index.indices.query.schema import QueryBundle
18
- from langchain.llms import OpenAIChat
19
 
20
  from modules.presets import *
21
  from modules.llama_func import *
22
  from modules.utils import *
23
- import modules.shared as shared
 
24
 
25
  # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
26
 
@@ -36,6 +35,7 @@ initial_prompt = "You are a helpful assistant."
36
  HISTORY_DIR = "history"
37
  TEMPLATES_DIR = "templates"
38
 
 
39
  def get_response(
40
  openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
41
  ):
@@ -61,20 +61,19 @@ def get_response(
61
  else:
62
  timeout = timeout_all
63
 
64
- proxies = get_proxies()
65
 
66
- # 如果有自定义的api-url,使用自定义url发送请求,否则使用默认设置发送请求
67
- if shared.state.api_url != API_URL:
68
- logging.info(f"使用自定义API URL: {shared.state.api_url}")
69
 
70
- response = requests.post(
71
- shared.state.api_url,
72
- headers=headers,
73
- json=payload,
74
- stream=True,
75
- timeout=timeout,
76
- proxies=proxies,
77
- )
78
 
79
  return response
80
 
@@ -146,7 +145,7 @@ def stream_predict(
146
 
147
  if fake_input is not None:
148
  history[-2] = construct_user(fake_input)
149
- for chunk in response.iter_lines():
150
  if counter == 0:
151
  counter += 1
152
  continue
@@ -166,9 +165,7 @@ def stream_predict(
166
  # decode each line as response data is in bytes
167
  if chunklength > 6 and "delta" in chunk["choices"][0]:
168
  finish_reason = chunk["choices"][0]["finish_reason"]
169
- status_text = construct_token_message(
170
- sum(all_token_counts), stream=True
171
- )
172
  if finish_reason == "stop":
173
  yield get_return_value()
174
  break
@@ -253,14 +250,6 @@ def predict_all(
253
  status_text = standard_error_msg + str(response)
254
  return chatbot, history, status_text, all_token_counts
255
 
256
- def is_repeated_string(s):
257
- n = len(s)
258
- for i in range(1, n // 2 + 1):
259
- if n % i == 0:
260
- sub = s[:i]
261
- if sub * (n // i) == s:
262
- return True
263
- return False
264
 
265
  def predict(
266
  openai_api_key,
@@ -278,11 +267,12 @@ def predict(
278
  reply_language="中文",
279
  should_check_token_count=True,
280
  ): # repetition_penalty, top_k
 
 
 
 
 
281
  logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
282
- if is_repeated_string(inputs):
283
- print("================== 有人来浪费了 ======================")
284
- yield chatbot+[(inputs, "🖕️🖕️🖕️🖕️🖕️看不起你")], history, "🖕️🖕️🖕️🖕️🖕️🖕️", all_token_counts
285
- return
286
  if should_check_token_count:
287
  yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
288
  if reply_language == "跟随问题语言(不稳定)":
@@ -300,12 +290,14 @@ def predict(
300
  msg = "索引构建完成,获取回答中……"
301
  logging.info(msg)
302
  yield chatbot+[(inputs, "")], history, msg, all_token_counts
303
- llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
304
- prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
305
- service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
306
- query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
307
- query_bundle = QueryBundle(inputs)
308
- nodes = query_object.retrieve(query_bundle)
 
 
309
  reference_results = [n.node.text for n in nodes]
310
  reference_results = add_source_numbers(reference_results, use_source=False)
311
  display_reference = add_details(reference_results)
@@ -337,7 +329,7 @@ def predict(
337
  else:
338
  display_reference = ""
339
 
340
- if len(openai_api_key) != 51:
341
  status_text = standard_error_msg + no_apikey_msg
342
  logging.info(status_text)
343
  chatbot.append((inputs, ""))
@@ -412,23 +404,15 @@ def predict(
412
  max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
413
 
414
  if sum(all_token_counts) > max_token and should_check_token_count:
415
- status_text = f"精简token中{all_token_counts}/{max_token}"
 
 
 
 
 
416
  logging.info(status_text)
 
417
  yield chatbot, history, status_text, all_token_counts
418
- iter = reduce_token_size(
419
- openai_api_key,
420
- system_prompt,
421
- history,
422
- chatbot,
423
- all_token_counts,
424
- top_p,
425
- temperature,
426
- max_token//2,
427
- selected_model=selected_model,
428
- )
429
- for chatbot, history, status_text, all_token_counts in iter:
430
- status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
431
- yield chatbot, history, status_text, all_token_counts
432
 
433
 
434
  def retry(
@@ -507,7 +491,7 @@ def reduce_token_size(
507
  token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
508
  msg = f"保留了最近{num_chat}轮对话"
509
  yield chatbot, history, msg + "," + construct_token_message(
510
- sum(token_count) if len(token_count) > 0 else 0,
511
  ), token_count
512
  logging.info(msg)
513
  logging.info("减少token数量完毕")
 
13
  from duckduckgo_search import ddg
14
  import asyncio
15
  import aiohttp
16
+
 
 
17
 
18
  from modules.presets import *
19
  from modules.llama_func import *
20
  from modules.utils import *
21
+ from . import shared
22
+ from modules.config import retrieve_proxy
23
 
24
  # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
25
 
 
35
  HISTORY_DIR = "history"
36
  TEMPLATES_DIR = "templates"
37
 
38
+ @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
39
  def get_response(
40
  openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
41
  ):
 
61
  else:
62
  timeout = timeout_all
63
 
 
64
 
65
+ # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
66
+ if shared.state.completion_url != COMPLETION_URL:
67
+ logging.info(f"使用自定义API URL: {shared.state.completion_url}")
68
 
69
+ with retrieve_proxy():
70
+ response = requests.post(
71
+ shared.state.completion_url,
72
+ headers=headers,
73
+ json=payload,
74
+ stream=True,
75
+ timeout=timeout,
76
+ )
77
 
78
  return response
79
 
 
145
 
146
  if fake_input is not None:
147
  history[-2] = construct_user(fake_input)
148
+ for chunk in tqdm(response.iter_lines()):
149
  if counter == 0:
150
  counter += 1
151
  continue
 
165
  # decode each line as response data is in bytes
166
  if chunklength > 6 and "delta" in chunk["choices"][0]:
167
  finish_reason = chunk["choices"][0]["finish_reason"]
168
+ status_text = construct_token_message(all_token_counts)
 
 
169
  if finish_reason == "stop":
170
  yield get_return_value()
171
  break
 
250
  status_text = standard_error_msg + str(response)
251
  return chatbot, history, status_text, all_token_counts
252
 
 
 
 
 
 
 
 
 
253
 
254
  def predict(
255
  openai_api_key,
 
267
  reply_language="中文",
268
  should_check_token_count=True,
269
  ): # repetition_penalty, top_k
270
+ from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
271
+ from llama_index.indices.query.schema import QueryBundle
272
+ from langchain.llms import OpenAIChat
273
+
274
+
275
  logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
 
 
 
 
276
  if should_check_token_count:
277
  yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
278
  if reply_language == "跟随问题语言(不稳定)":
 
290
  msg = "索引构建完成,获取回答中……"
291
  logging.info(msg)
292
  yield chatbot+[(inputs, "")], history, msg, all_token_counts
293
+ with retrieve_proxy():
294
+ llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
295
+ prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
296
+ from llama_index import ServiceContext
297
+ service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
298
+ query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
299
+ query_bundle = QueryBundle(inputs)
300
+ nodes = query_object.retrieve(query_bundle)
301
  reference_results = [n.node.text for n in nodes]
302
  reference_results = add_source_numbers(reference_results, use_source=False)
303
  display_reference = add_details(reference_results)
 
329
  else:
330
  display_reference = ""
331
 
332
+ if len(openai_api_key) == 0 and not shared.state.multi_api_key:
333
  status_text = standard_error_msg + no_apikey_msg
334
  logging.info(status_text)
335
  chatbot.append((inputs, ""))
 
404
  max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
405
 
406
  if sum(all_token_counts) > max_token and should_check_token_count:
407
+ print(all_token_counts)
408
+ count = 0
409
+ while sum(all_token_counts) > max_token - 500 and sum(all_token_counts) > 0:
410
+ count += 1
411
+ del all_token_counts[0]
412
+ del history[:2]
413
  logging.info(status_text)
414
+ status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
415
  yield chatbot, history, status_text, all_token_counts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
 
417
 
418
  def retry(
 
491
  token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
492
  msg = f"保留了最近{num_chat}轮对话"
493
  yield chatbot, history, msg + "," + construct_token_message(
494
+ token_count if len(token_count) > 0 else [0],
495
  ), token_count
496
  logging.info(msg)
497
  logging.info("减少token数量完毕")
modules/config.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from contextlib import contextmanager
3
+ import os
4
+ import logging
5
+ import sys
6
+ import json
7
+
8
+ from . import shared
9
+
10
+
11
+ __all__ = [
12
+ "my_api_key",
13
+ "authflag",
14
+ "auth_list",
15
+ "dockerflag",
16
+ "retrieve_proxy",
17
+ "log_level",
18
+ "advance_docs",
19
+ "update_doc_config",
20
+ "multi_api_key",
21
+ ]
22
+
23
+ # 添加一个统一的config文件,避免文件过多造成的疑惑(优先级最低)
24
+ # 同时,也可以为后续支持自定义功能提供config的帮助
25
+ if os.path.exists("config.json"):
26
+ with open("config.json", "r", encoding='utf-8') as f:
27
+ config = json.load(f)
28
+ else:
29
+ config = {}
30
+
31
+ ## 处理docker if we are running in Docker
32
+ dockerflag = config.get("dockerflag", False)
33
+ if os.environ.get("dockerrun") == "yes":
34
+ dockerflag = True
35
+
36
+ ## 处理 api-key 以及 允许的用户列表
37
+ my_api_key = config.get("openai_api_key", "") # 在这里输入你的 API 密钥
38
+ my_api_key = os.environ.get("my_api_key", my_api_key)
39
+
40
+ ## 多账户机制
41
+ multi_api_key = config.get("multi_api_key", False) # 是否开启多账户机制
42
+ if multi_api_key:
43
+ api_key_list = config.get("api_key_list", [])
44
+ if len(api_key_list) == 0:
45
+ logging.error("多账号模式已开启,但api_key_list为空,请检查config.json")
46
+ sys.exit(1)
47
+ shared.state.set_api_key_queue(api_key_list)
48
+
49
+ auth_list = config.get("users", []) # 实际上是使用者的列表
50
+ authflag = len(auth_list) > 0 # 是否开启认证的状态值,改为判断auth_list长度
51
+
52
+ # 处理自定义的api_host,优先读环境变量的配置,如果存在则自动装配
53
+ api_host = os.environ.get("api_host", config.get("api_host", ""))
54
+ if api_host:
55
+ shared.state.set_api_host(api_host)
56
+
57
+ if dockerflag:
58
+ if my_api_key == "empty":
59
+ logging.error("Please give a api key!")
60
+ sys.exit(1)
61
+ # auth
62
+ username = os.environ.get("USERNAME")
63
+ password = os.environ.get("PASSWORD")
64
+ if not (isinstance(username, type(None)) or isinstance(password, type(None))):
65
+ auth_list.append((os.environ.get("USERNAME"), os.environ.get("PASSWORD")))
66
+ authflag = True
67
+ else:
68
+ if (
69
+ not my_api_key
70
+ and os.path.exists("api_key.txt")
71
+ and os.path.getsize("api_key.txt")
72
+ ):
73
+ with open("api_key.txt", "r") as f:
74
+ my_api_key = f.read().strip()
75
+ if os.path.exists("auth.json"):
76
+ authflag = True
77
+ with open("auth.json", "r", encoding='utf-8') as f:
78
+ auth = json.load(f)
79
+ for _ in auth:
80
+ if auth[_]["username"] and auth[_]["password"]:
81
+ auth_list.append((auth[_]["username"], auth[_]["password"]))
82
+ else:
83
+ logging.error("请检查auth.json文件中的用户名和密码!")
84
+ sys.exit(1)
85
+
86
+ @contextmanager
87
+ def retrieve_openai_api(api_key = None):
88
+ old_api_key = os.environ.get("OPENAI_API_KEY", "")
89
+ if api_key is None:
90
+ os.environ["OPENAI_API_KEY"] = my_api_key
91
+ yield my_api_key
92
+ else:
93
+ os.environ["OPENAI_API_KEY"] = api_key
94
+ yield api_key
95
+ os.environ["OPENAI_API_KEY"] = old_api_key
96
+
97
+ ## 处理log
98
+ log_level = config.get("log_level", "INFO")
99
+ logging.basicConfig(
100
+ level=log_level,
101
+ format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
102
+ )
103
+
104
+ ## 处理代理:
105
+ http_proxy = config.get("http_proxy", "")
106
+ https_proxy = config.get("https_proxy", "")
107
+ http_proxy = os.environ.get("HTTP_PROXY", http_proxy)
108
+ https_proxy = os.environ.get("HTTPS_PROXY", https_proxy)
109
+
110
+ # 重置系统变量,在不需要设置的时候不设置环境变量,以免引起全局代理报错
111
+ os.environ["HTTP_PROXY"] = ""
112
+ os.environ["HTTPS_PROXY"] = ""
113
+
114
+ @contextmanager
115
+ def retrieve_proxy(proxy=None):
116
+ """
117
+ 1, 如果proxy = NONE,设置环境变量,并返回最新设置的代理
118
+ 2,如果proxy != NONE,更新当前的代理配置,但是不更新环境变量
119
+ """
120
+ global http_proxy, https_proxy
121
+ if proxy is not None:
122
+ http_proxy = proxy
123
+ https_proxy = proxy
124
+ yield http_proxy, https_proxy
125
+ else:
126
+ old_var = os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"]
127
+ os.environ["HTTP_PROXY"] = http_proxy
128
+ os.environ["HTTPS_PROXY"] = https_proxy
129
+ yield http_proxy, https_proxy # return new proxy
130
+
131
+ # return old proxy
132
+ os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var
133
+
134
+
135
+ ## 处理advance docs
136
+ advance_docs = defaultdict(lambda: defaultdict(dict))
137
+ advance_docs.update(config.get("advance_docs", {}))
138
+ def update_doc_config(two_column_pdf):
139
+ global advance_docs
140
+ if two_column_pdf:
141
+ advance_docs["pdf"]["two_column"] = True
142
+ else:
143
+ advance_docs["pdf"]["two_column"] = False
144
+
145
+ logging.info(f"更新后的文件参数为:{advance_docs}")
modules/llama_func.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import logging
3
 
4
- from llama_index import GPTSimpleVectorIndex, ServiceContext
5
  from llama_index import download_loader
6
  from llama_index import (
7
  Document,
@@ -10,8 +9,6 @@ from llama_index import (
10
  QuestionAnswerPrompt,
11
  RefinePrompt,
12
  )
13
- from langchain.llms import OpenAI
14
- from langchain.chat_models import ChatOpenAI
15
  import colorama
16
  import PyPDF2
17
  from tqdm import tqdm
@@ -43,28 +40,40 @@ def get_documents(file_src):
43
  logging.debug("Loading documents...")
44
  logging.debug(f"file_src: {file_src}")
45
  for file in file_src:
46
- logging.info(f"loading file: {file.name}")
47
- if os.path.splitext(file.name)[1] == ".pdf":
 
 
 
48
  logging.debug("Loading PDF...")
49
- pdftext = ""
50
- with open(file.name, 'rb') as pdfFileObj:
51
- pdfReader = PyPDF2.PdfReader(pdfFileObj)
52
- for page in tqdm(pdfReader.pages):
53
- pdftext += page.extract_text()
 
 
 
 
 
 
54
  text_raw = pdftext
55
- elif os.path.splitext(file.name)[1] == ".docx":
56
- logging.debug("Loading DOCX...")
57
  DocxReader = download_loader("DocxReader")
58
  loader = DocxReader()
59
- text_raw = loader.load_data(file=file.name)[0].text
60
- elif os.path.splitext(file.name)[1] == ".epub":
61
  logging.debug("Loading EPUB...")
62
  EpubReader = download_loader("EpubReader")
63
  loader = EpubReader()
64
- text_raw = loader.load_data(file=file.name)[0].text
 
 
 
65
  else:
66
  logging.debug("Loading text file...")
67
- with open(file.name, "r", encoding="utf-8") as f:
68
  text_raw = f.read()
69
  text = add_space(text_raw)
70
  # text = block_split(text)
@@ -84,6 +93,9 @@ def construct_index(
84
  embedding_limit=None,
85
  separator=" "
86
  ):
 
 
 
87
  os.environ["OPENAI_API_KEY"] = api_key
88
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
89
  embedding_limit = None if embedding_limit == 0 else embedding_limit
@@ -101,10 +113,11 @@ def construct_index(
101
  try:
102
  documents = get_documents(file_src)
103
  logging.info("构建索引中……")
104
- service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
105
- index = GPTSimpleVectorIndex.from_documents(
106
- documents, service_context=service_context
107
- )
 
108
  logging.debug("索引构建完成!")
109
  os.makedirs("./index", exist_ok=True)
110
  index.save_to_disk(f"./index/{index_name}.json")
@@ -117,97 +130,6 @@ def construct_index(
117
  return None
118
 
119
 
120
- def chat_ai(
121
- api_key,
122
- index,
123
- question,
124
- context,
125
- chatbot,
126
- reply_language,
127
- ):
128
- os.environ["OPENAI_API_KEY"] = api_key
129
-
130
- logging.info(f"Question: {question}")
131
-
132
- response, chatbot_display, status_text = ask_ai(
133
- api_key,
134
- index,
135
- question,
136
- replace_today(PROMPT_TEMPLATE),
137
- REFINE_TEMPLATE,
138
- SIM_K,
139
- INDEX_QUERY_TEMPRATURE,
140
- context,
141
- reply_language,
142
- )
143
- if response is None:
144
- status_text = "查询失败,请换个问法试试"
145
- return context, chatbot
146
- response = response
147
-
148
- context.append({"role": "user", "content": question})
149
- context.append({"role": "assistant", "content": response})
150
- chatbot.append((question, chatbot_display))
151
-
152
- os.environ["OPENAI_API_KEY"] = ""
153
- return context, chatbot, status_text
154
-
155
-
156
- def ask_ai(
157
- api_key,
158
- index,
159
- question,
160
- prompt_tmpl,
161
- refine_tmpl,
162
- sim_k=5,
163
- temprature=0,
164
- prefix_messages=[],
165
- reply_language="中文",
166
- ):
167
- os.environ["OPENAI_API_KEY"] = api_key
168
-
169
- logging.debug("Index file found")
170
- logging.debug("Querying index...")
171
- llm_predictor = LLMPredictor(
172
- llm=ChatOpenAI(
173
- temperature=temprature,
174
- model_name="gpt-3.5-turbo-0301",
175
- prefix_messages=prefix_messages,
176
- )
177
- )
178
-
179
- response = None # Initialize response variable to avoid UnboundLocalError
180
- qa_prompt = QuestionAnswerPrompt(prompt_tmpl.replace("{reply_language}", reply_language))
181
- rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
182
- response = index.query(
183
- question,
184
- similarity_top_k=sim_k,
185
- text_qa_template=qa_prompt,
186
- refine_template=rf_prompt,
187
- response_mode="compact",
188
- )
189
-
190
- if response is not None:
191
- logging.info(f"Response: {response}")
192
- ret_text = response.response
193
- nodes = []
194
- for index, node in enumerate(response.source_nodes):
195
- brief = node.source_text[:25].replace("\n", "")
196
- nodes.append(
197
- f"<details><summary>[{index + 1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
198
- )
199
- new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
200
- logging.info(
201
- f"Response: {colorama.Fore.BLUE}{ret_text}{colorama.Style.RESET_ALL}"
202
- )
203
- os.environ["OPENAI_API_KEY"] = ""
204
- return ret_text, new_response, f"查询消耗了{llm_predictor.last_token_usage} tokens"
205
- else:
206
- logging.warning("No response found, returning None")
207
- os.environ["OPENAI_API_KEY"] = ""
208
- return None
209
-
210
-
211
  def add_space(text):
212
  punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
213
  for cn_punc, en_punc in punctuations.items():
 
1
  import os
2
  import logging
3
 
 
4
  from llama_index import download_loader
5
  from llama_index import (
6
  Document,
 
9
  QuestionAnswerPrompt,
10
  RefinePrompt,
11
  )
 
 
12
  import colorama
13
  import PyPDF2
14
  from tqdm import tqdm
 
40
  logging.debug("Loading documents...")
41
  logging.debug(f"file_src: {file_src}")
42
  for file in file_src:
43
+ filepath = file.name
44
+ filename = os.path.basename(filepath)
45
+ file_type = os.path.splitext(filepath)[1]
46
+ logging.info(f"loading file: {filename}")
47
+ if file_type == ".pdf":
48
  logging.debug("Loading PDF...")
49
+ try:
50
+ from modules.pdf_func import parse_pdf
51
+ from modules.config import advance_docs
52
+ two_column = advance_docs["pdf"].get("two_column", False)
53
+ pdftext = parse_pdf(filepath, two_column).text
54
+ except:
55
+ pdftext = ""
56
+ with open(filepath, 'rb') as pdfFileObj:
57
+ pdfReader = PyPDF2.PdfReader(pdfFileObj)
58
+ for page in tqdm(pdfReader.pages):
59
+ pdftext += page.extract_text()
60
  text_raw = pdftext
61
+ elif file_type == ".docx":
62
+ logging.debug("Loading Word...")
63
  DocxReader = download_loader("DocxReader")
64
  loader = DocxReader()
65
+ text_raw = loader.load_data(file=filepath)[0].text
66
+ elif file_type == ".epub":
67
  logging.debug("Loading EPUB...")
68
  EpubReader = download_loader("EpubReader")
69
  loader = EpubReader()
70
+ text_raw = loader.load_data(file=filepath)[0].text
71
+ elif file_type == ".xlsx":
72
+ logging.debug("Loading Excel...")
73
+ text_raw = excel_to_string(filepath)
74
  else:
75
  logging.debug("Loading text file...")
76
+ with open(filepath, "r", encoding="utf-8") as f:
77
  text_raw = f.read()
78
  text = add_space(text_raw)
79
  # text = block_split(text)
 
93
  embedding_limit=None,
94
  separator=" "
95
  ):
96
+ from langchain.chat_models import ChatOpenAI
97
+ from llama_index import GPTSimpleVectorIndex, ServiceContext
98
+
99
  os.environ["OPENAI_API_KEY"] = api_key
100
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
101
  embedding_limit = None if embedding_limit == 0 else embedding_limit
 
113
  try:
114
  documents = get_documents(file_src)
115
  logging.info("构建索引中……")
116
+ with retrieve_proxy():
117
+ service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
118
+ index = GPTSimpleVectorIndex.from_documents(
119
+ documents, service_context=service_context
120
+ )
121
  logging.debug("索引构建完成!")
122
  os.makedirs("./index", exist_ok=True)
123
  index.save_to_disk(f"./index/{index_name}.json")
 
130
  return None
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def add_space(text):
134
  punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
135
  for cn_punc, en_punc in punctuations.items():
modules/openai_func.py CHANGED
@@ -10,8 +10,8 @@ from modules.presets import (
10
  read_timeout_prompt
11
  )
12
 
13
- from modules import shared
14
- from modules.utils import get_proxies
15
  import os, datetime
16
 
17
  def get_billing_data(openai_api_key, billing_url):
@@ -19,58 +19,35 @@ def get_billing_data(openai_api_key, billing_url):
19
  "Content-Type": "application/json",
20
  "Authorization": f"Bearer {openai_api_key}"
21
  }
22
-
23
  timeout = timeout_all
24
- proxies = get_proxies()
25
- response = requests.get(
26
- billing_url,
27
- headers=headers,
28
- timeout=timeout,
29
- proxies=proxies,
30
- )
31
-
32
  if response.status_code == 200:
33
  data = response.json()
34
  return data
35
  else:
36
  raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
37
-
38
 
39
  def get_usage(openai_api_key):
40
  try:
41
- balance_data=get_billing_data(openai_api_key, BALANCE_API_URL)
42
- logging.debug(balance_data)
 
 
43
  try:
44
- balance = balance_data["total_available"] if balance_data["total_available"] else 0
45
- total_used = balance_data["total_used"] if balance_data["total_used"] else 0
46
- usage_percent = round(total_used / (total_used+balance) * 100, 2)
47
  except Exception as e:
48
- logging.error(f"API使用情况解析失败:"+str(e))
49
- balance = 0
50
- total_used=0
51
- return f"**API使用情况解析失败**"
52
- if balance == 0:
53
- last_day_of_month = datetime.datetime.now().strftime("%Y-%m-%d")
54
- first_day_of_month = datetime.datetime.now().replace(day=1).strftime("%Y-%m-%d")
55
- usage_url = f"{USAGE_API_URL}?start_date={first_day_of_month}&end_date={last_day_of_month}"
56
- try:
57
- usage_data = get_billing_data(openai_api_key, usage_url)
58
- except Exception as e:
59
- logging.error(f"获取API使用情况失败:"+str(e))
60
- return f"**获取API使用情况失败**"
61
- return f"**本月使用金额** \u3000 ${usage_data['total_usage'] / 100}"
62
-
63
- # return f"**免费额度**(已用/余额)\u3000${total_used} / ${balance}"
64
- return f"""\
65
- <b>免费额度使用情况</b>
66
- <div class="progress-bar">
67
- <div class="progress" style="width: {usage_percent}%;">
68
- <span class="progress-text">{usage_percent}%</span>
69
- </div>
70
- </div>
71
- <div style="display: flex; justify-content: space-between;"><span>已用 ${total_used}</span><span>可用 ${balance}</span></div>
72
- """
73
-
74
  except requests.exceptions.ConnectTimeout:
75
  status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
76
  return status_text
@@ -80,3 +57,9 @@ def get_usage(openai_api_key):
80
  except Exception as e:
81
  logging.error(f"获取API使用情况失败:"+str(e))
82
  return standard_error_msg + error_retrieve_prompt
 
 
 
 
 
 
 
10
  read_timeout_prompt
11
  )
12
 
13
+ from . import shared
14
+ from modules.config import retrieve_proxy
15
  import os, datetime
16
 
17
  def get_billing_data(openai_api_key, billing_url):
 
19
  "Content-Type": "application/json",
20
  "Authorization": f"Bearer {openai_api_key}"
21
  }
22
+
23
  timeout = timeout_all
24
+ with retrieve_proxy():
25
+ response = requests.get(
26
+ billing_url,
27
+ headers=headers,
28
+ timeout=timeout,
29
+ )
30
+
 
31
  if response.status_code == 200:
32
  data = response.json()
33
  return data
34
  else:
35
  raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
36
+
37
 
38
  def get_usage(openai_api_key):
39
  try:
40
+ curr_time = datetime.datetime.now()
41
+ last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
42
+ first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
43
+ usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
44
  try:
45
+ usage_data = get_billing_data(openai_api_key, usage_url)
 
 
46
  except Exception as e:
47
+ logging.error(f"获取API使用情况失败:"+str(e))
48
+ return f"**获取API使用情况失败**"
49
+ rounded_usage = "{:.5f}".format(usage_data['total_usage']/100)
50
+ return f"**本月使用金额** \u3000 ${rounded_usage}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  except requests.exceptions.ConnectTimeout:
52
  status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
53
  return status_text
 
57
  except Exception as e:
58
  logging.error(f"获取API使用情况失败:"+str(e))
59
  return standard_error_msg + error_retrieve_prompt
60
+
61
+ def get_last_day_of_month(any_day):
62
+ # The day 28 exists in every month. 4 days later, it's always next month
63
+ next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
64
+ # subtracting the number of the current day brings us back one month
65
+ return next_month - datetime.timedelta(days=next_month.day)
modules/pdf_func.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+ import pdfplumber
3
+ import logging
4
+ from llama_index import Document
5
+
6
+ def prepare_table_config(crop_page):
7
+ """Prepare table查找边界, 要求page为原始page
8
+
9
+ From https://github.com/jsvine/pdfplumber/issues/242
10
+ """
11
+ page = crop_page.root_page # root/parent
12
+ cs = page.curves + page.edges
13
+ def curves_to_edges():
14
+ """See https://github.com/jsvine/pdfplumber/issues/127"""
15
+ edges = []
16
+ for c in cs:
17
+ edges += pdfplumber.utils.rect_to_edges(c)
18
+ return edges
19
+ edges = curves_to_edges()
20
+ return {
21
+ "vertical_strategy": "explicit",
22
+ "horizontal_strategy": "explicit",
23
+ "explicit_vertical_lines": edges,
24
+ "explicit_horizontal_lines": edges,
25
+ "intersection_y_tolerance": 10,
26
+ }
27
+
28
+ def get_text_outside_table(crop_page):
29
+ ts = prepare_table_config(crop_page)
30
+ if len(ts["explicit_vertical_lines"]) == 0 or len(ts["explicit_horizontal_lines"]) == 0:
31
+ return crop_page
32
+
33
+ ### Get the bounding boxes of the tables on the page.
34
+ bboxes = [table.bbox for table in crop_page.root_page.find_tables(table_settings=ts)]
35
+ def not_within_bboxes(obj):
36
+ """Check if the object is in any of the table's bbox."""
37
+ def obj_in_bbox(_bbox):
38
+ """See https://github.com/jsvine/pdfplumber/blob/stable/pdfplumber/table.py#L404"""
39
+ v_mid = (obj["top"] + obj["bottom"]) / 2
40
+ h_mid = (obj["x0"] + obj["x1"]) / 2
41
+ x0, top, x1, bottom = _bbox
42
+ return (h_mid >= x0) and (h_mid < x1) and (v_mid >= top) and (v_mid < bottom)
43
+ return not any(obj_in_bbox(__bbox) for __bbox in bboxes)
44
+
45
+ return crop_page.filter(not_within_bboxes)
46
+ # 请使用 LaTeX 表达公式,行内公式以 $ 包裹,行间公式以 $$ 包裹
47
+
48
+ extract_words = lambda page: page.extract_words(keep_blank_chars=True, y_tolerance=0, x_tolerance=1, extra_attrs=["fontname", "size", "object_type"])
49
+ # dict_keys(['text', 'x0', 'x1', 'top', 'doctop', 'bottom', 'upright', 'direction', 'fontname', 'size'])
50
+
51
+ def get_title_with_cropped_page(first_page):
52
+ title = [] # 处理标题
53
+ x0,top,x1,bottom = first_page.bbox # 获取页面边框
54
+
55
+ for word in extract_words(first_page):
56
+ word = SimpleNamespace(**word)
57
+
58
+ if word.size >= 14:
59
+ title.append(word.text)
60
+ title_bottom = word.bottom
61
+ elif word.text == "Abstract": # 获取页面abstract
62
+ top = word.top
63
+
64
+ user_info = [i["text"] for i in extract_words(first_page.within_bbox((x0,title_bottom,x1,top)))]
65
+ # 裁剪掉上半部分, within_bbox: full_included; crop: partial_included
66
+ return title, user_info, first_page.within_bbox((x0,top,x1,bottom))
67
+
68
+ def get_column_cropped_pages(pages, two_column=True):
69
+ new_pages = []
70
+ for page in pages:
71
+ if two_column:
72
+ left = page.within_bbox((0, 0, page.width/2, page.height),relative=True)
73
+ right = page.within_bbox((page.width/2, 0, page.width, page.height), relative=True)
74
+ new_pages.append(left)
75
+ new_pages.append(right)
76
+ else:
77
+ new_pages.append(page)
78
+
79
+ return new_pages
80
+
81
+ def parse_pdf(filename, two_column = True):
82
+ level = logging.getLogger().level
83
+ if level == logging.getLevelName("DEBUG"):
84
+ logging.getLogger().setLevel("INFO")
85
+
86
+ with pdfplumber.open(filename) as pdf:
87
+ title, user_info, first_page = get_title_with_cropped_page(pdf.pages[0])
88
+ new_pages = get_column_cropped_pages([first_page] + pdf.pages[1:], two_column)
89
+
90
+ chapters = []
91
+ # tuple (chapter_name, [pageid] (start,stop), chapter_text)
92
+ create_chapter = lambda page_start,name_top,name_bottom: SimpleNamespace(
93
+ name=[],
94
+ name_top=name_top,
95
+ name_bottom=name_bottom,
96
+ record_chapter_name = True,
97
+
98
+ page_start=page_start,
99
+ page_stop=None,
100
+
101
+ text=[],
102
+ )
103
+ cur_chapter = None
104
+
105
+ # 按页遍历PDF文档
106
+ for idx, page in enumerate(new_pages):
107
+ page = get_text_outside_table(page)
108
+
109
+ # 按行遍历页面文本
110
+ for word in extract_words(page):
111
+ word = SimpleNamespace(**word)
112
+
113
+ # 检查行文本是否以12号字体打印,如果是,则将其作为新章节开始
114
+ if word.size >= 11: # 出现chapter name
115
+ if cur_chapter is None:
116
+ cur_chapter = create_chapter(page.page_number, word.top, word.bottom)
117
+ elif not cur_chapter.record_chapter_name or (cur_chapter.name_bottom != cur_chapter.name_bottom and cur_chapter.name_top != cur_chapter.name_top):
118
+ # 不再继续写chapter name
119
+ cur_chapter.page_stop = page.page_number # stop id
120
+ chapters.append(cur_chapter)
121
+ # 重置当前chapter信息
122
+ cur_chapter = create_chapter(page.page_number, word.top, word.bottom)
123
+
124
+ # print(word.size, word.top, word.bottom, word.text)
125
+ cur_chapter.name.append(word.text)
126
+ else:
127
+ cur_chapter.record_chapter_name = False # chapter name 结束
128
+ cur_chapter.text.append(word.text)
129
+ else:
130
+ # 处理最后一个章节
131
+ cur_chapter.page_stop = page.page_number # stop id
132
+ chapters.append(cur_chapter)
133
+
134
+ for i in chapters:
135
+ logging.info(f"section: {i.name} pages:{i.page_start, i.page_stop} word-count:{len(i.text)}")
136
+ logging.debug(" ".join(i.text))
137
+
138
+ title = " ".join(title)
139
+ user_info = " ".join(user_info)
140
+ text = f"Article Title: {title}, Information:{user_info}\n"
141
+ for idx, chapter in enumerate(chapters):
142
+ chapter.name = " ".join(chapter.name)
143
+ text += f"The {idx}th Chapter {chapter.name}: " + " ".join(chapter.text) + "\n"
144
+
145
+ logging.getLogger().setLevel(level)
146
+ return Document(text=text, extra_info={"title": title})
147
+
148
+ BASE_POINTS = """
149
+ 1. Who are the authors?
150
+ 2. What is the process of the proposed method?
151
+ 3. What is the performance of the proposed method? Please note down its performance metrics.
152
+ 4. What are the baseline models and their performances? Please note down these baseline methods.
153
+ 5. What dataset did this paper use?
154
+ """
155
+
156
+ READING_PROMPT = """
157
+ You are a researcher helper bot. You can help the user with research paper reading and summarizing. \n
158
+ Now I am going to send you a paper. You need to read it and summarize it for me part by part. \n
159
+ When you are reading, You need to focus on these key points:{}
160
+ """
161
+
162
+ READING_PROMT_V2 = """
163
+ You are a researcher helper bot. You can help the user with research paper reading and summarizing. \n
164
+ Now I am going to send you a paper. You need to read it and summarize it for me part by part. \n
165
+ When you are reading, You need to focus on these key points:{},
166
+
167
+ And You need to generate a brief but informative title for this part.
168
+ Your return format:
169
+ - title: '...'
170
+ - summary: '...'
171
+ """
172
+
173
+ SUMMARY_PROMPT = "You are a researcher helper bot. Now you need to read the summaries of a research paper."
174
+
175
+
176
+ if __name__ == '__main__':
177
+ # Test code
178
+ z = parse_pdf("./build/test.pdf")
179
+ print(z["user_info"])
180
+ print(z["title"])
modules/presets.py CHANGED
@@ -1,12 +1,14 @@
1
  # -*- coding:utf-8 -*-
2
  import gradio as gr
 
3
 
4
  # ChatGPT 设置
5
  initial_prompt = "You are a helpful assistant."
6
- API_URL = "https://api.openai.com/v1/chat/completions"
 
7
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
8
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
9
- HISTORY_DIR = "history"
10
  TEMPLATES_DIR = "templates"
11
 
12
  # 错误信息
@@ -28,7 +30,7 @@ CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
28
  SIM_K = 5
29
  INDEX_QUERY_TEMPRATURE = 1.0
30
 
31
- title = """<h1 align="left" style="min-width:200px; margin-top:0;">川虎ChatGPT 🚀</h1>"""
32
  description = """\
33
  <div align="center" style="margin:16px 0">
34
 
 
1
  # -*- coding:utf-8 -*-
2
  import gradio as gr
3
+ from pathlib import Path
4
 
5
  # ChatGPT 设置
6
  initial_prompt = "You are a helpful assistant."
7
+ API_HOST = "api.openai.com"
8
+ COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
9
  BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
10
  USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
11
+ HISTORY_DIR = Path("history")
12
  TEMPLATES_DIR = "templates"
13
 
14
  # 错误信息
 
30
  SIM_K = 5
31
  INDEX_QUERY_TEMPRATURE = 1.0
32
 
33
+ title = """<h1 align="left" style="min-width:200px; margin-top:6px; white-space: nowrap;">川虎ChatGPT 🚀</h1>"""
34
  description = """\
35
  <div align="center" style="margin:16px 0">
36
 
modules/shared.py CHANGED
@@ -1,8 +1,13 @@
1
- from modules.presets import API_URL
 
 
2
 
3
  class State:
4
  interrupted = False
5
- api_url = API_URL
 
 
 
6
 
7
  def interrupt(self):
8
  self.interrupted = True
@@ -10,15 +15,41 @@ class State:
10
  def recover(self):
11
  self.interrupted = False
12
 
13
- def set_api_url(self, api_url):
14
- self.api_url = api_url
 
 
 
15
 
16
- def reset_api_url(self):
17
- self.api_url = API_URL
18
- return self.api_url
 
 
 
19
 
20
  def reset_all(self):
21
  self.interrupted = False
22
- self.api_url = API_URL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  state = State()
 
1
+ from modules.presets import COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST
2
+ import os
3
+ import queue
4
 
5
  class State:
6
  interrupted = False
7
+ multi_api_key = False
8
+ completion_url = COMPLETION_URL
9
+ balance_api_url = BALANCE_API_URL
10
+ usage_api_url = USAGE_API_URL
11
 
12
  def interrupt(self):
13
  self.interrupted = True
 
15
  def recover(self):
16
  self.interrupted = False
17
 
18
+ def set_api_host(self, api_host):
19
+ self.completion_url = f"https://{api_host}/v1/chat/completions"
20
+ self.balance_api_url = f"https://{api_host}/dashboard/billing/credit_grants"
21
+ self.usage_api_url = f"https://{api_host}/dashboard/billing/usage"
22
+ os.environ["OPENAI_API_BASE"] = f"https://{api_host}/v1"
23
 
24
+ def reset_api_host(self):
25
+ self.completion_url = COMPLETION_URL
26
+ self.balance_api_url = BALANCE_API_URL
27
+ self.usage_api_url = USAGE_API_URL
28
+ os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}/v1"
29
+ return API_HOST
30
 
31
  def reset_all(self):
32
  self.interrupted = False
33
+ self.completion_url = COMPLETION_URL
34
+
35
+ def set_api_key_queue(self, api_key_list):
36
+ self.multi_api_key = True
37
+ self.api_key_queue = queue.Queue()
38
+ for api_key in api_key_list:
39
+ self.api_key_queue.put(api_key)
40
+
41
+ def switching_api_key(self, func):
42
+ if not hasattr(self, "api_key_queue"):
43
+ return func
44
+
45
+ def wrapped(*args, **kwargs):
46
+ api_key = self.api_key_queue.get()
47
+ args = list(args)[1:]
48
+ ret = func(api_key, *args, **kwargs)
49
+ self.api_key_queue.put(api_key)
50
+ return ret
51
+
52
+ return wrapped
53
+
54
 
55
  state = State()
modules/utils.py CHANGED
@@ -21,14 +21,11 @@ from markdown import markdown
21
  from pygments import highlight
22
  from pygments.lexers import get_lexer_by_name
23
  from pygments.formatters import HtmlFormatter
 
24
 
25
  from modules.presets import *
26
- import modules.shared as shared
27
-
28
- logging.basicConfig(
29
- level=logging.INFO,
30
- format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
31
- )
32
 
33
  if TYPE_CHECKING:
34
  from typing import TypedDict
@@ -156,8 +153,11 @@ def construct_assistant(text):
156
  return construct_text("assistant", text)
157
 
158
 
159
- def construct_token_message(token, stream=False):
160
- return f"Token 计数: {token}"