Spaces · Commit fa4087a
Parent(s): c9c16d7
Upload 37 files
- ChuanhuChatbot.py +423 -0
- README.md +2 -2
- assets/custom.css +13 -2
- assets/custom.js +70 -1
- modules/__pycache__/chat_func.cpython-39.pyc +0 -0
- modules/__pycache__/config.cpython-39.pyc +0 -0
- modules/__pycache__/llama_func.cpython-39.pyc +0 -0
- modules/__pycache__/openai_func.cpython-39.pyc +0 -0
- modules/__pycache__/overwrites.cpython-39.pyc +0 -0
- modules/__pycache__/pdf_func.cpython-39.pyc +0 -0
- modules/__pycache__/presets.cpython-39.pyc +0 -0
- modules/__pycache__/proxy_func.cpython-39.pyc +0 -0
- modules/__pycache__/shared.cpython-39.pyc +0 -0
- modules/__pycache__/utils.cpython-39.pyc +0 -0
- modules/chat_func.py +39 -55
- modules/config.py +145 -0
- modules/llama_func.py +33 -111
- modules/openai_func.py +26 -43
- modules/pdf_func.py +180 -0
- modules/presets.py +5 -3
- modules/shared.py +39 -8
- modules/utils.py +74 -58
- requirements.txt +3 -1
- templates/1 中文提示词.json +490 -0
- templates/2 English Prompts.csv +9 -1
- templates/3 繁體提示詞.json +490 -0
- templates/4 川虎的Prompts.json +14 -0
ChuanhuChatbot.py
ADDED
@@ -0,0 +1,423 @@
# -*- coding:utf-8 -*-
import os
import logging
import sys

import gradio as gr

from modules import config
from modules.config import *
from modules.utils import *
from modules.presets import *
from modules.overwrites import *
from modules.chat_func import *
from modules.openai_func import get_usage

gr.Chatbot.postprocess = postprocess
PromptHelper.compact_text_chunks = compact_text_chunks

with open("assets/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()

with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    user_name = gr.State("")
    history = gr.State([])
    token_count = gr.State([])
    promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
    user_api_key = gr.State(my_api_key)
    user_question = gr.State("")
    outputing = gr.State(False)
    topic = gr.State("未命名对话历史记录")

    with gr.Row():
        with gr.Column():
            gr.HTML(title)
            user_info = gr.Markdown(value="", elem_id="user_info")
        gr.HTML('<center><a href="https://huggingface.co/spaces/JohnSmith9982/ChuanhuChatGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a></center>')
    status_display = gr.Markdown(get_geoip(), elem_id="status_display")

    # https://github.com/gradio-app/gradio/pull/3296
    def create_greeting(request: gr.Request):
        if hasattr(request, "username") and request.username: # is not None or is not ""
            logging.info(f"Get User Name: {request.username}")
            return gr.Markdown.update(value=f"User: {request.username}"), request.username
        else:
            return gr.Markdown.update(value=f"User: default", visible=False), ""
    demo.load(create_greeting, inputs=None, outputs=[user_info, user_name])

    with gr.Row().style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
            with gr.Row():
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        elem_id="user_input_tb",
                        show_label=False, placeholder="在这里输入"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("发送", variant="primary")
                    cancelBtn = gr.Button("取消", variant="secondary", visible=False)
            with gr.Row():
                emptyBtn = gr.Button(
                    "🧹 新的对话",
                )
                retryBtn = gr.Button("🔄 重新生成")
                delFirstBtn = gr.Button("🗑️ 删除最旧对话")
                delLastBtn = gr.Button("🗑️ 删除最新对话")
                reduceTokenBtn = gr.Button("♻️ 总结对话")

        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="ChatGPT"):
                    keyTxt = gr.Textbox(
                        show_label=True,
                        placeholder=f"OpenAI API-key...",
                        value=hide_middle_chars(my_api_key),
                        type="password",
                        visible=not HIDE_MY_KEY,
                        label="API-Key",
                    )
                    if multi_api_key:
                        usageTxt = gr.Markdown("多账号模式已开启,无需输入key,可直接开始对话", elem_id="usage_display")
                    else:
                        usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display")
                    model_select_dropdown = gr.Dropdown(
                        label="选择模型", choices=MODELS, multiselect=False, value=MODELS[0]
                    )
                    use_streaming_checkbox = gr.Checkbox(
                        label="实时传输回答", value=True, visible=enable_streaming_option
                    )
                    use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
                    language_select_dropdown = gr.Dropdown(
                        label="选择回复语言(针对搜索&索引功能)",
                        choices=REPLY_LANGUAGES,
                        multiselect=False,
                        value=REPLY_LANGUAGES[0],
                    )
                    index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
                    two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
                    # TODO: formula OCR
                    # formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))

                with gr.Tab(label="Prompt"):
                    systemPromptTxt = gr.Textbox(
                        show_label=True,
                        placeholder=f"在这里输入System Prompt...",
                        label="System prompt",
                        value=initial_prompt,
                        lines=10,
                    ).style(container=False)
                    with gr.Accordion(label="加载Prompt模板", open=True):
                        with gr.Column():
                            with gr.Row():
                                with gr.Column(scale=6):
                                    templateFileSelectDropdown = gr.Dropdown(
                                        label="选择Prompt模板集合文件",
                                        choices=get_template_names(plain=True),
                                        multiselect=False,
                                        value=get_template_names(plain=True)[0],
                                    ).style(container=False)
                                with gr.Column(scale=1):
                                    templateRefreshBtn = gr.Button("🔄 刷新")
                            with gr.Row():
                                with gr.Column():
                                    templateSelectDropdown = gr.Dropdown(
                                        label="从Prompt模板中加载",
                                        choices=load_template(
                                            get_template_names(plain=True)[0], mode=1
                                        ),
                                        multiselect=False,
                                    ).style(container=False)

                with gr.Tab(label="保存/加载"):
                    with gr.Accordion(label="保存/加载对话历史记录", open=True):
                        with gr.Column():
                            with gr.Row():
                                with gr.Column(scale=6):
                                    historyFileSelectDropdown = gr.Dropdown(
                                        label="从列表中加载对话",
                                        choices=get_history_names(plain=True),
                                        multiselect=False,
                                        value=get_history_names(plain=True)[0],
                                    )
                                with gr.Column(scale=1):
                                    historyRefreshBtn = gr.Button("🔄 刷新")
                            with gr.Row():
                                with gr.Column(scale=6):
                                    saveFileName = gr.Textbox(
                                        show_label=True,
                                        placeholder=f"设置文件名: 默认为.json,可选为.md",
                                        label="设置保存文件名",
                                        value="对话历史记录",
                                    ).style(container=True)
                                with gr.Column(scale=1):
                                    saveHistoryBtn = gr.Button("💾 保存对话")
                                    exportMarkdownBtn = gr.Button("📝 导出为Markdown")
                                    gr.Markdown("默认保存于history文件夹")
                            with gr.Row():
                                with gr.Column():
                                    downloadFile = gr.File(interactive=True)

                with gr.Tab(label="高级"):
                    gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
                    default_btn = gr.Button("🔙 恢复默认设置")

                    with gr.Accordion("参数", open=False):
                        top_p = gr.Slider(
                            minimum=-0,
                            maximum=1.0,
                            value=1.0,
                            step=0.05,
                            interactive=True,
                            label="Top-p",
                        )
                        temperature = gr.Slider(
                            minimum=-0,
                            maximum=2.0,
                            value=1.0,
                            step=0.1,
                            interactive=True,
                            label="Temperature",
                        )

                    with gr.Accordion("网络设置", open=False, visible=False):
                        # show the custom api_host first if one is configured
                        apihostTxt = gr.Textbox(
                            show_label=True,
                            placeholder=f"在这里输入API-Host...",
                            label="API-Host",
                            value=config.api_host or shared.API_HOST,
                            lines=1,
                        )
                        changeAPIURLBtn = gr.Button("🔄 切换API地址")
                        proxyTxt = gr.Textbox(
                            show_label=True,
                            placeholder=f"在这里输入代理地址...",
                            label="代理地址(示例:http://127.0.0.1:10809)",
                            value="",
                            lines=2,
                        )
                        changeProxyBtn = gr.Button("🔄 设置代理地址")

    gr.Markdown(description)
    gr.HTML(footer.format(versions=versions_html()), elem_id="footer")
    chatgpt_predict_args = dict(
        fn=predict,
        inputs=[
            user_api_key,
            systemPromptTxt,
            history,
            user_question,
            chatbot,
            token_count,
            top_p,
            temperature,
            use_streaming_checkbox,
            model_select_dropdown,
            use_websearch_checkbox,
            index_files,
            language_select_dropdown,
        ],
        outputs=[chatbot, history, status_display, token_count],
        show_progress=True,
    )

    start_outputing_args = dict(
        fn=start_outputing,
        inputs=[],
        outputs=[submitBtn, cancelBtn],
        show_progress=True,
    )

    end_outputing_args = dict(
        fn=end_outputing, inputs=[], outputs=[submitBtn, cancelBtn]
    )

    reset_textbox_args = dict(
        fn=reset_textbox, inputs=[], outputs=[user_input]
    )

    transfer_input_args = dict(
        fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn, cancelBtn], show_progress=True
    )

    get_usage_args = dict(
        fn=get_usage, inputs=[user_api_key], outputs=[usageTxt], show_progress=False
    )


    # Chatbot
    cancelBtn.click(cancel_outputing, [], [])

    user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
    user_input.submit(**get_usage_args)

    submitBtn.click(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
    submitBtn.click(**get_usage_args)

    emptyBtn.click(
        reset_state,
        outputs=[chatbot, history, token_count, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_textbox_args)

    retryBtn.click(**start_outputing_args).then(
        retry,
        [
            user_api_key,
            systemPromptTxt,
            history,
            chatbot,
            token_count,
            top_p,
            temperature,
            use_streaming_checkbox,
            model_select_dropdown,
            language_select_dropdown,
        ],
        [chatbot, history, status_display, token_count],
        show_progress=True,
    ).then(**end_outputing_args)
    retryBtn.click(**get_usage_args)

    delFirstBtn.click(
        delete_first_conversation,
        [history, token_count],
        [history, token_count, status_display],
    )

    delLastBtn.click(
        delete_last_conversation,
        [chatbot, history, token_count],
        [chatbot, history, token_count, status_display],
        show_progress=True,
    )

    reduceTokenBtn.click(
        reduce_token_size,
        [
            user_api_key,
            systemPromptTxt,
            history,
            chatbot,
            token_count,
            top_p,
            temperature,
            gr.State(sum(token_count.value[-4:])),
            model_select_dropdown,
            language_select_dropdown,
        ],
        [chatbot, history, status_display, token_count],
        show_progress=True,
    )
    reduceTokenBtn.click(**get_usage_args)

    two_column.change(update_doc_config, [two_column], None)

    # ChatGPT
    keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
    keyTxt.submit(**get_usage_args)

    # Template
    templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
    templateFileSelectDropdown.change(
        load_template,
        [templateFileSelectDropdown],
        [promptTemplates, templateSelectDropdown],
        show_progress=True,
    )
    templateSelectDropdown.change(
        get_template_content,
        [promptTemplates, templateSelectDropdown, systemPromptTxt],
        [systemPromptTxt],
        show_progress=True,
    )

    # S&L
    saveHistoryBtn.click(
        save_chat_history,
        [saveFileName, systemPromptTxt, history, chatbot, user_name],
        downloadFile,
        show_progress=True,
    )
    saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
    exportMarkdownBtn.click(
        export_markdown,
        [saveFileName, systemPromptTxt, history, chatbot, user_name],
        downloadFile,
        show_progress=True,
    )
    historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
    historyFileSelectDropdown.change(
        load_chat_history,
        [historyFileSelectDropdown, systemPromptTxt, history, chatbot, user_name],
        [saveFileName, systemPromptTxt, history, chatbot],
        show_progress=True,
    )
    downloadFile.change(
        load_chat_history,
        [downloadFile, systemPromptTxt, history, chatbot, user_name],
        [saveFileName, systemPromptTxt, history, chatbot],
    )

    # Advanced
    default_btn.click(
        reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
    )
    changeAPIURLBtn.click(
        change_api_host,
        [apihostTxt],
        [status_display],
        show_progress=True,
    )
    changeProxyBtn.click(
        change_proxy,
        [proxyTxt],
        [status_display],
        show_progress=True,
    )

logging.info(
    colorama.Back.GREEN
    + "\n川虎的温馨提示:访问 http://localhost:7860 查看界面"
    + colorama.Style.RESET_ALL
)
# Local server on by default, reachable from the host IP by default, no public share link by default
demo.title = "川虎ChatGPT 🚀"

if __name__ == "__main__":
    reload_javascript()
    # if running in Docker
    if dockerflag:
        if authflag:
            demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
                server_name="0.0.0.0",
                server_port=7860,
                auth=auth_list,
                favicon_path="./assets/favicon.ico",
            )
        else:
            demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
                server_name="0.0.0.0",
                server_port=7860,
                share=False,
                favicon_path="./assets/favicon.ico",
            )
    # if not running in Docker
    else:
        if authflag:
            demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
                share=False,
                auth=auth_list,
                favicon_path="./assets/favicon.ico",
                inbrowser=True,
            )
        else:
            demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
                share=False, favicon_path="./assets/favicon.ico", inbrowser=True
            )  # change to share=True to create a public share link
        # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False)  # customizable port
        # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码"))  # username and password can be set
        # demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码"))  # suitable for an Nginx reverse proxy
README.md
CHANGED
@@ -4,8 +4,8 @@ emoji: 🐯
 colorFrom: green
 colorTo: red
 sdk: gradio
-sdk_version: 3.
-app_file:
+sdk_version: 3.24.1
+app_file: ChuanhuChatbot.py
 pinned: false
 license: gpl-3.0
 ---
assets/custom.css
CHANGED
@@ -18,10 +18,22 @@ footer {
   opacity: 0.85;
 }
 
+/* user_info */
+#user_info {
+  white-space: nowrap;
+  margin-top: -1.3em !important;
+  padding-left: 112px !important;
+}
+#user_info p {
+  font-size: .85em;
+  font-family: monospace;
+  color: var(--body-text-color-subdued);
+}
+
 /* status_display */
 #status_display {
   display: flex;
-  min-height:
+  min-height: 2em;
   align-items: flex-end;
   justify-content: flex-end;
 }
@@ -110,7 +122,6 @@ ol:not(.options), ul:not(.options) {
   background-color: var(--neutral-950) !important;
   }
 }
-
 /* chat bubbles */
 [class *= "message"] {
   border-radius: var(--radius-xl) !important;
assets/custom.js
CHANGED
@@ -1 +1,70 @@
-// custom javascript here
+// custom javascript here
+const MAX_HISTORY_LENGTH = 32;
+
+var key_down_history = [];
+var currentIndex = -1;
+var user_input_ta;
+
+var ga = document.getElementsByTagName("gradio-app");
+var targetNode = ga[0];
+var observer = new MutationObserver(function(mutations) {
+    for (var i = 0; i < mutations.length; i++) {
+        if (mutations[i].addedNodes.length) {
+            var user_input_tb = document.getElementById('user_input_tb');
+            if (user_input_tb) {
+                // user_input_tb has been added to the DOM tree;
+                // code that must run once the element has loaded goes here
+                user_input_ta = user_input_tb.querySelector("textarea");
+                if (user_input_ta){
+                    observer.disconnect(); // stop observing
+                    // listen for keydown events on the textarea
+                    user_input_ta.addEventListener("keydown", function (event) {
+                        var value = user_input_ta.value.trim();
+                        // check whether an arrow key was pressed
+                        if (event.code === 'ArrowUp' || event.code === 'ArrowDown') {
+                            // if the box holds text that is not in the history, do nothing
+                            if(value && key_down_history.indexOf(value) === -1)
+                                return;
+                            // for the actions we do handle, prevent the default behavior
+                            event.preventDefault();
+                            var length = key_down_history.length;
+                            if(length === 0) {
+                                currentIndex = -1; // if the history is empty, just reset the selection
+                                return;
+                            }
+                            if (currentIndex === -1) {
+                                currentIndex = length;
+                            }
+                            if (event.code === 'ArrowUp' && currentIndex > 0) {
+                                currentIndex--;
+                                user_input_ta.value = key_down_history[currentIndex];
+                            } else if (event.code === 'ArrowDown' && currentIndex < length - 1) {
+                                currentIndex++;
+                                user_input_ta.value = key_down_history[currentIndex];
+                            }
+                            user_input_ta.selectionStart = user_input_ta.value.length;
+                            user_input_ta.selectionEnd = user_input_ta.value.length;
+                            const input_event = new InputEvent("input", {bubbles: true, cancelable: true});
+                            user_input_ta.dispatchEvent(input_event);
+                        }else if(event.code === "Enter") {
+                            if (value) {
+                                currentIndex = -1;
+                                if(key_down_history.indexOf(value) === -1){
+                                    key_down_history.push(value);
+                                    if (key_down_history.length > MAX_HISTORY_LENGTH) {
+                                        key_down_history.shift();
+                                    }
+                                }
+                            }
+                        }
+                    });
+                    break;
+                }
+            }
+        }
+    }
+});
+
+// watch the target node's child list for changes
+observer.observe(targetNode, { childList: true , subtree: true });
+
modules/__pycache__/chat_func.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/chat_func.cpython-39.pyc and b/modules/__pycache__/chat_func.cpython-39.pyc differ

modules/__pycache__/config.cpython-39.pyc
ADDED
Binary file (3.18 kB).

modules/__pycache__/llama_func.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/llama_func.cpython-39.pyc and b/modules/__pycache__/llama_func.cpython-39.pyc differ

modules/__pycache__/openai_func.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/openai_func.cpython-39.pyc and b/modules/__pycache__/openai_func.cpython-39.pyc differ

modules/__pycache__/overwrites.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/overwrites.cpython-39.pyc and b/modules/__pycache__/overwrites.cpython-39.pyc differ

modules/__pycache__/pdf_func.cpython-39.pyc
ADDED
Binary file (6.13 kB).

modules/__pycache__/presets.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/presets.cpython-39.pyc and b/modules/__pycache__/presets.cpython-39.pyc differ

modules/__pycache__/proxy_func.cpython-39.pyc
ADDED
Binary file (718 Bytes).

modules/__pycache__/shared.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/shared.cpython-39.pyc and b/modules/__pycache__/shared.cpython-39.pyc differ

modules/__pycache__/utils.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/utils.cpython-39.pyc and b/modules/__pycache__/utils.cpython-39.pyc differ
modules/chat_func.py
CHANGED
@@ -13,14 +13,13 @@ import colorama
 from duckduckgo_search import ddg
 import asyncio
 import aiohttp
-
-from llama_index.indices.query.schema import QueryBundle
-from langchain.llms import OpenAIChat
+
 
 from modules.presets import *
 from modules.llama_func import *
 from modules.utils import *
-
+from . import shared
+from modules.config import retrieve_proxy
 
 # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
 
@@ -36,6 +35,7 @@ initial_prompt = "You are a helpful assistant."
 HISTORY_DIR = "history"
 TEMPLATES_DIR = "templates"
 
+@shared.state.switching_api_key  # this decorator is a no-op when multi-account mode is not enabled
 def get_response(
     openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
 ):
@@ -61,20 +61,19 @@ def get_response(
     else:
         timeout = timeout_all
 
-    proxies = get_proxies()
 
-    # 如果有自定义的api-
-    if shared.state.
-        logging.info(f"使用自定义API URL: {shared.state.
+    # if a custom api-host is configured, send the request to it; otherwise use the default settings
+    if shared.state.completion_url != COMPLETION_URL:
+        logging.info(f"使用自定义API URL: {shared.state.completion_url}")
 
-
-
-
-
-
-
-
-
+    with retrieve_proxy():
+        response = requests.post(
+            shared.state.completion_url,
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=timeout,
+        )
 
     return response
 
@@ -146,7 +145,7 @@ def stream_predict(
 
     if fake_input is not None:
        history[-2] = construct_user(fake_input)
-    for chunk in response.iter_lines():
+    for chunk in tqdm(response.iter_lines()):
         if counter == 0:
             counter += 1
             continue
@@ -166,9 +165,7 @@ def stream_predict(
             # decode each line as response data is in bytes
             if chunklength > 6 and "delta" in chunk["choices"][0]:
                 finish_reason = chunk["choices"][0]["finish_reason"]
-                status_text = construct_token_message(
-                    sum(all_token_counts), stream=True
-                )
+                status_text = construct_token_message(all_token_counts)
                 if finish_reason == "stop":
                     yield get_return_value()
                     break
@@ -253,14 +250,6 @@ def predict_all(
         status_text = standard_error_msg + str(response)
     return chatbot, history, status_text, all_token_counts
 
-def is_repeated_string(s):
-    n = len(s)
-    for i in range(1, n // 2 + 1):
-        if n % i == 0:
-            sub = s[:i]
-            if sub * (n // i) == s:
-                return True
-    return False
 
 def predict(
     openai_api_key,
@@ -278,11 +267,12 @@ def predict(
     reply_language="中文",
     should_check_token_count=True,
 ):  # repetition_penalty, top_k
+    from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
+    from llama_index.indices.query.schema import QueryBundle
+    from langchain.llms import OpenAIChat
+
+
     logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
-    if is_repeated_string(inputs):
-        print("================== 有人来浪费了 ======================")
-        yield chatbot+[(inputs, "🖕️🖕️🖕️🖕️🖕️看不起你")], history, "🖕️🖕️🖕️🖕️🖕️🖕️", all_token_counts
-        return
     if should_check_token_count:
         yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
     if reply_language == "跟随问题语言(不稳定)":
@@ -300,12 +290,14 @@ def predict(
         msg = "索引构建完成,获取回答中……"
         logging.info(msg)
         yield chatbot+[(inputs, "")], history, msg, all_token_counts
-
-
-
-
-
-
+        with retrieve_proxy():
+            llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
+            prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
+            from llama_index import ServiceContext
+            service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
+            query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
+            query_bundle = QueryBundle(inputs)
+            nodes = query_object.retrieve(query_bundle)
         reference_results = [n.node.text for n in nodes]
         reference_results = add_source_numbers(reference_results, use_source=False)
         display_reference = add_details(reference_results)
@@ -337,7 +329,7 @@ def predict(
     else:
         display_reference = ""
 
-    if len(openai_api_key)
+    if len(openai_api_key) == 0 and not shared.state.multi_api_key:
         status_text = standard_error_msg + no_apikey_msg
         logging.info(status_text)
         chatbot.append((inputs, ""))
@@ -412,23 +404,15 @@ def predict(
     max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["all"]
 
     if sum(all_token_counts) > max_token and should_check_token_count:
-
+        print(all_token_counts)
+        count = 0
+        while sum(all_token_counts) > max_token - 500 and sum(all_token_counts) > 0:
+            count += 1
+            del all_token_counts[0]
+            del history[:2]
         logging.info(status_text)
+        status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
         yield chatbot, history, status_text, all_token_counts
-        iter = reduce_token_size(
-            openai_api_key,
-            system_prompt,
-            history,
-            chatbot,
-            all_token_counts,
-            top_p,
-            temperature,
-            max_token//2,
-            selected_model=selected_model,
-        )
-        for chatbot, history, status_text, all_token_counts in iter:
-            status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
-            yield chatbot, history, status_text, all_token_counts
 
 
 def retry(
@@ -507,7 +491,7 @@ def reduce_token_size(
     token_count = previous_token_count[-num_chat:] if num_chat > 0 else []
     msg = f"保留了最近{num_chat}轮对话"
     yield chatbot, history, msg + "," + construct_token_message(
-
+        token_count if len(token_count) > 0 else [0],
     ), token_count
     logging.info(msg)
     logging.info("减少token数量完毕")
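The new overflow handling in `predict` drops whole exchanges instead of summarizing: while the running total exceeds `max_token - 500`, it deletes the oldest per-exchange token count and the oldest user/assistant pair together (`del history[:2]`, since history is a flat alternating list). A self-contained sketch of that bookkeeping, with made-up numbers purely for illustration:

```python
# Token counts per exchange and a flat [user, assistant, user, assistant, ...] history.
all_token_counts = [900, 1200, 800, 400]
history = ["u1", "a1", "u2", "a2", "u3", "a3", "u4", "a4"]
max_token = 2000

count = 0
while sum(all_token_counts) > max_token - 500 and sum(all_token_counts) > 0:
    count += 1
    del all_token_counts[0]   # forget the oldest exchange's token count
    del history[:2]           # drop its user message and assistant reply together

print(count, all_token_counts, history)
# 2 [800, 400] ['u3', 'a3', 'u4', 'a4']
```

Deleting in pairs keeps the two lists aligned, so the status message can report exactly how many early rounds the model "forgot".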
modules/config.py
ADDED
@@ -0,0 +1,145 @@
from collections import defaultdict
from contextlib import contextmanager
import os
import logging
import sys
import json

from . import shared


__all__ = [
    "my_api_key",
    "authflag",
    "auth_list",
    "dockerflag",
    "retrieve_proxy",
    "log_level",
    "advance_docs",
    "update_doc_config",
    "multi_api_key",
]

# Add a unified config file to avoid the confusion caused by having too many files (lowest priority)
# It also lays the groundwork for config support for future custom features
if os.path.exists("config.json"):
    with open("config.json", "r", encoding='utf-8') as f:
        config = json.load(f)
else:
    config = {}

## Docker handling: if we are running in Docker
dockerflag = config.get("dockerflag", False)
if os.environ.get("dockerrun") == "yes":
    dockerflag = True

## Handle the api-key and the list of allowed users
my_api_key = config.get("openai_api_key", "") # enter your API key here
my_api_key = os.environ.get("my_api_key", my_api_key)

## Multi-account mechanism
multi_api_key = config.get("multi_api_key", False) # whether to enable the multi-account mechanism
if multi_api_key:
    api_key_list = config.get("api_key_list", [])
    if len(api_key_list) == 0:
        logging.error("多账号模式已开启,但api_key_list为空,请检查config.json")
        sys.exit(1)
    shared.state.set_api_key_queue(api_key_list)

auth_list = config.get("users", []) # actually the list of users
authflag = len(auth_list) > 0 # whether auth is enabled; now derived from the length of auth_list

# Handle a custom api_host; environment variables take priority and are wired in automatically if present
api_host = os.environ.get("api_host", config.get("api_host", ""))
if api_host:
    shared.state.set_api_host(api_host)

if dockerflag:
    if my_api_key == "empty":
        logging.error("Please give a api key!")
        sys.exit(1)
    # auth
    username = os.environ.get("USERNAME")
    password = os.environ.get("PASSWORD")
    if not (isinstance(username, type(None)) or isinstance(password, type(None))):
        auth_list.append((os.environ.get("USERNAME"), os.environ.get("PASSWORD")))
        authflag = True
else:
    if (
        not my_api_key
        and os.path.exists("api_key.txt")
        and os.path.getsize("api_key.txt")
    ):
        with open("api_key.txt", "r") as f:
            my_api_key = f.read().strip()
    if os.path.exists("auth.json"):
        authflag = True
        with open("auth.json", "r", encoding='utf-8') as f:
            auth = json.load(f)
            for _ in auth:
                if auth[_]["username"] and auth[_]["password"]:
                    auth_list.append((auth[_]["username"], auth[_]["password"]))
                else:
                    logging.error("请检查auth.json文件中的用户名和密码!")
                    sys.exit(1)

@contextmanager
def retrieve_openai_api(api_key = None):
    old_api_key = os.environ.get("OPENAI_API_KEY", "")
    if api_key is None:
        os.environ["OPENAI_API_KEY"] = my_api_key
        yield my_api_key
    else:
        os.environ["OPENAI_API_KEY"] = api_key
        yield api_key
    os.environ["OPENAI_API_KEY"] = old_api_key

## Handle logging
log_level = config.get("log_level", "INFO")
logging.basicConfig(
    level=log_level,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)

## Handle proxies:
http_proxy = config.get("http_proxy", "")
https_proxy = config.get("https_proxy", "")
http_proxy = os.environ.get("HTTP_PROXY", http_proxy)
https_proxy = os.environ.get("HTTPS_PROXY", https_proxy)

# Reset the system variables; leave them unset when not needed so a global proxy doesn't cause errors
os.environ["HTTP_PROXY"] = ""
os.environ["HTTPS_PROXY"] = ""

@contextmanager
def retrieve_proxy(proxy=None):
    """
    1. If proxy is None, set the environment variables and return the latest proxy settings
    2. If proxy is not None, update the current proxy configuration but leave the environment variables alone
    """
    global http_proxy, https_proxy
    if proxy is not None:
        http_proxy = proxy
        https_proxy = proxy
        yield http_proxy, https_proxy
    else:
        old_var = os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"]
        os.environ["HTTP_PROXY"] = http_proxy
        os.environ["HTTPS_PROXY"] = https_proxy
        yield http_proxy, https_proxy # return new proxy

        # return old proxy
        os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var


## Handle advance docs
advance_docs = defaultdict(lambda: defaultdict(dict))
advance_docs.update(config.get("advance_docs", {}))
def update_doc_config(two_column_pdf):
    global advance_docs
    if two_column_pdf:
        advance_docs["pdf"]["two_column"] = True
    else:
        advance_docs["pdf"]["two_column"] = False

    logging.info(f"更新后的文件参数为:{advance_docs}")
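`retrieve_proxy` is the piece the other modules lean on: it keeps `HTTP_PROXY`/`HTTPS_PROXY` blank at import time and only exports them for the duration of a `with` block, so only the requests that opt in go through the proxy. A rough standalone equivalent of that scoping behavior (simplified to one variable; the real version also accepts a `proxy` override):

```python
import os
from contextlib import contextmanager

http_proxy = "http://127.0.0.1:10809"  # would normally come from config.json or the environment
os.environ["HTTP_PROXY"] = ""          # keep the global environment clean

@contextmanager
def retrieve_proxy():
    old = os.environ["HTTP_PROXY"]
    os.environ["HTTP_PROXY"] = http_proxy   # visible only inside the with-block
    yield http_proxy
    os.environ["HTTP_PROXY"] = old          # restored afterwards

print(os.environ["HTTP_PROXY"])      # -> "" (empty outside the block)
with retrieve_proxy():
    print(os.environ["HTTP_PROXY"])  # -> "http://127.0.0.1:10809"
print(os.environ["HTTP_PROXY"])     # -> "" again
```

Note that, like the original, this restores the variable only on a clean exit; wrapping the `yield` in try/finally would also cover the case where the body raises.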
modules/llama_func.py
CHANGED
@@ -1,7 +1,6 @@
 import os
 import logging
 
-from llama_index import GPTSimpleVectorIndex, ServiceContext
 from llama_index import download_loader
 from llama_index import (
     Document,
@@ -10,8 +9,6 @@ from llama_index import (
     QuestionAnswerPrompt,
     RefinePrompt,
 )
-from langchain.llms import OpenAI
-from langchain.chat_models import ChatOpenAI
 import colorama
 import PyPDF2
 from tqdm import tqdm
@@ -43,28 +40,40 @@ def get_documents(file_src):
     logging.debug("Loading documents...")
     logging.debug(f"file_src: {file_src}")
     for file in file_src:
-
-
+        filepath = file.name
+        filename = os.path.basename(filepath)
+        file_type = os.path.splitext(filepath)[1]
+        logging.info(f"loading file: {filename}")
+        if file_type == ".pdf":
             logging.debug("Loading PDF...")
-
-
-
-
-
+            try:
+                from modules.pdf_func import parse_pdf
+                from modules.config import advance_docs
+                two_column = advance_docs["pdf"].get("two_column", False)
+                pdftext = parse_pdf(filepath, two_column).text
+            except:
+                pdftext = ""
+                with open(filepath, 'rb') as pdfFileObj:
+                    pdfReader = PyPDF2.PdfReader(pdfFileObj)
+                    for page in tqdm(pdfReader.pages):
+                        pdftext += page.extract_text()
             text_raw = pdftext
-        elif
-            logging.debug("Loading
+        elif file_type == ".docx":
+            logging.debug("Loading Word...")
             DocxReader = download_loader("DocxReader")
             loader = DocxReader()
-            text_raw = loader.load_data(file=
-        elif
+            text_raw = loader.load_data(file=filepath)[0].text
+        elif file_type == ".epub":
             logging.debug("Loading EPUB...")
             EpubReader = download_loader("EpubReader")
             loader = EpubReader()
-            text_raw = loader.load_data(file=
+            text_raw = loader.load_data(file=filepath)[0].text
+        elif file_type == ".xlsx":
+            logging.debug("Loading Excel...")
+            text_raw = excel_to_string(filepath)
         else:
             logging.debug("Loading text file...")
-            with open(
+            with open(filepath, "r", encoding="utf-8") as f:
                 text_raw = f.read()
         text = add_space(text_raw)
         # text = block_split(text)
@@ -84,6 +93,9 @@ def construct_index(
     embedding_limit=None,
     separator=" "
 ):
+    from langchain.chat_models import ChatOpenAI
+    from llama_index import GPTSimpleVectorIndex, ServiceContext
+
     os.environ["OPENAI_API_KEY"] = api_key
     chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
     embedding_limit = None if embedding_limit == 0 else embedding_limit
@@ -101,10 +113,11 @@ def construct_index(
     try:
         documents = get_documents(file_src)
         logging.info("构建索引中……")
-
-
-
-
+        with retrieve_proxy():
+            service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
+            index = GPTSimpleVectorIndex.from_documents(
+                documents, service_context=service_context
+            )
         logging.debug("索引构建完成!")
         os.makedirs("./index", exist_ok=True)
         index.save_to_disk(f"./index/{index_name}.json")
@@ -117,97 +130,6 @@ def construct_index(
         return None
 
 
-def chat_ai(
-    api_key,
-    index,
-    question,
-    context,
-    chatbot,
-    reply_language,
-):
-    os.environ["OPENAI_API_KEY"] = api_key
-
-    logging.info(f"Question: {question}")
-
-    response, chatbot_display, status_text = ask_ai(
-        api_key,
-        index,
-        question,
-        replace_today(PROMPT_TEMPLATE),
-        REFINE_TEMPLATE,
-        SIM_K,
-        INDEX_QUERY_TEMPRATURE,
-        context,
-        reply_language,
-    )
-    if response is None:
-        status_text = "查询失败,请换个问法试试"
-        return context, chatbot
-    response = response
-
-    context.append({"role": "user", "content": question})
-    context.append({"role": "assistant", "content": response})
-    chatbot.append((question, chatbot_display))
-
-    os.environ["OPENAI_API_KEY"] = ""
-    return context, chatbot, status_text
-
-
-def ask_ai(
-    api_key,
-    index,
-    question,
-    prompt_tmpl,
-    refine_tmpl,
-    sim_k=5,
-    temprature=0,
-    prefix_messages=[],
-    reply_language="中文",
-):
-    os.environ["OPENAI_API_KEY"] = api_key
-
-    logging.debug("Index file found")
-    logging.debug("Querying index...")
-    llm_predictor = LLMPredictor(
-        llm=ChatOpenAI(
-            temperature=temprature,
-            model_name="gpt-3.5-turbo-0301",
-            prefix_messages=prefix_messages,
-        )
-    )
-
-    response = None  # Initialize response variable to avoid UnboundLocalError
-    qa_prompt = QuestionAnswerPrompt(prompt_tmpl.replace("{reply_language}", reply_language))
-    rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
-    response = index.query(
-        question,
-        similarity_top_k=sim_k,
-        text_qa_template=qa_prompt,
-        refine_template=rf_prompt,
-        response_mode="compact",
-    )
-
-    if response is not None:
-        logging.info(f"Response: {response}")
-        ret_text = response.response
-        nodes = []
-        for index, node in enumerate(response.source_nodes):
-            brief = node.source_text[:25].replace("\n", "")
-            nodes.append(
-                f"<details><summary>[{index + 1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
-            )
-        new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
-        logging.info(
-            f"Response: {colorama.Fore.BLUE}{ret_text}{colorama.Style.RESET_ALL}"
-        )
-        os.environ["OPENAI_API_KEY"] = ""
-        return ret_text, new_response, f"查询消耗了{llm_predictor.last_token_usage} tokens"
-    else:
-        logging.warning("No response found, returning None")
-        os.environ["OPENAI_API_KEY"] = ""
-        return None
-
-
 def add_space(text):
     punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
     for cn_punc, en_punc in punctuations.items():
modules/openai_func.py
CHANGED
@@ -10,8 +10,8 @@ from modules.presets import (
     read_timeout_prompt
 )
 
-from
-from modules.
+from . import shared
+from modules.config import retrieve_proxy
 import os, datetime
 
 def get_billing_data(openai_api_key, billing_url):
@@ -19,58 +19,35 @@ def get_billing_data(openai_api_key, billing_url):
         "Content-Type": "application/json",
         "Authorization": f"Bearer {openai_api_key}"
     }
-
+
     timeout = timeout_all
-
-
-
-
-
-
-
-
+    with retrieve_proxy():
+        response = requests.get(
+            billing_url,
+            headers=headers,
+            timeout=timeout,
+        )
+
     if response.status_code == 200:
         data = response.json()
         return data
     else:
         raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
-
+
 
 def get_usage(openai_api_key):
     try:
-
-
+        curr_time = datetime.datetime.now()
+        last_day_of_month = get_last_day_of_month(curr_time).strftime("%Y-%m-%d")
+        first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
+        usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
         try:
-
-            total_used = balance_data["total_used"] if balance_data["total_used"] else 0
-            usage_percent = round(total_used / (total_used+balance) * 100, 2)
+            usage_data = get_billing_data(openai_api_key, usage_url)
        except Exception as e:
-            logging.error(f"API
-
-
-
-        if balance == 0:
-            last_day_of_month = datetime.datetime.now().strftime("%Y-%m-%d")
-            first_day_of_month = datetime.datetime.now().replace(day=1).strftime("%Y-%m-%d")
-            usage_url = f"{USAGE_API_URL}?start_date={first_day_of_month}&end_date={last_day_of_month}"
-            try:
-                usage_data = get_billing_data(openai_api_key, usage_url)
-            except Exception as e:
-                logging.error(f"获取API使用情况失败:"+str(e))
-                return f"**获取API使用情况失败**"
-            return f"**本月使用金额** \u3000 ${usage_data['total_usage'] / 100}"
-
-        # return f"**免费额度**(已用/余额)\u3000${total_used} / ${balance}"
-        return f"""\
-<b>免费额度使用情况</b>
-<div class="progress-bar">
-    <div class="progress" style="width: {usage_percent}%;">
-        <span class="progress-text">{usage_percent}%</span>
-    </div>
-</div>
-<div style="display: flex; justify-content: space-between;"><span>已用 ${total_used}</span><span>可用 ${balance}</span></div>
-"""
-
+            logging.error(f"获取API使用情况失败:"+str(e))
+            return f"**获取API使用情况失败**"
+        rounded_usage = "{:.5f}".format(usage_data['total_usage']/100)
+        return f"**本月使用金额** \u3000 ${rounded_usage}"
     except requests.exceptions.ConnectTimeout:
         status_text = standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
         return status_text
@@ -80,3 +57,9 @@ def get_usage(openai_api_key):
     except Exception as e:
         logging.error(f"获取API使用情况失败:"+str(e))
         return standard_error_msg + error_retrieve_prompt
+
+def get_last_day_of_month(any_day):
+    # The day 28 exists in every month. 4 days later, it's always next month
+    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
+    # subtracting the number of the current day brings us back one month
+    return next_month - datetime.timedelta(days=next_month.day)
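The new `get_last_day_of_month` helper relies on a classic calendar trick: day 28 exists in every month, so jumping 4 days past the 28th always lands in the next month, and subtracting that date's day-of-month walks back to the last day of the original month. A quick check of the logic:

```python
import datetime

def get_last_day_of_month(any_day):
    # The day 28 exists in every month; 4 days later is always the next month
    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
    # subtracting that date's day-of-month brings us back one month
    return next_month - datetime.timedelta(days=next_month.day)

print(get_last_day_of_month(datetime.date(2023, 2, 14)))   # 2023-02-28
print(get_last_day_of_month(datetime.date(2024, 2, 14)))   # 2024-02-29 (leap year)
print(get_last_day_of_month(datetime.date(2023, 12, 5)))   # 2023-12-31
```

Using this instead of today's date as `end_date` means the usage query always covers the whole current month, including leap-year Februaries.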
modules/pdf_func.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from types import SimpleNamespace
|
2 |
+
import pdfplumber
|
3 |
+
import logging
|
4 |
+
from llama_index import Document
|
5 |
+
|
6 |
+
def prepare_table_config(crop_page):
|
7 |
+
"""Prepare table查找边界, 要求page为原始page
|
8 |
+
|
9 |
+
From https://github.com/jsvine/pdfplumber/issues/242
|
10 |
+
"""
|
11 |
+
page = crop_page.root_page # root/parent
|
12 |
+
cs = page.curves + page.edges
|
13 |
+
def curves_to_edges():
|
14 |
+
"""See https://github.com/jsvine/pdfplumber/issues/127"""
|
15 |
+
edges = []
|
16 |
+
for c in cs:
|
17 |
+
edges += pdfplumber.utils.rect_to_edges(c)
|
18 |
+
return edges
|
19 |
+
edges = curves_to_edges()
|
20 |
+
return {
|
21 |
+
"vertical_strategy": "explicit",
|
22 |
+
"horizontal_strategy": "explicit",
|
23 |
+
"explicit_vertical_lines": edges,
|
24 |
+
"explicit_horizontal_lines": edges,
|
25 |
+
"intersection_y_tolerance": 10,
|
26 |
+
}
|
27 |
+
|
28 |
+
def get_text_outside_table(crop_page):
|
29 |
+
ts = prepare_table_config(crop_page)
|
30 |
+
if len(ts["explicit_vertical_lines"]) == 0 or len(ts["explicit_horizontal_lines"]) == 0:
|
31 |
+
return crop_page
|
32 |
+
|
33 |
+
### Get the bounding boxes of the tables on the page.
|
34 |
+
bboxes = [table.bbox for table in crop_page.root_page.find_tables(table_settings=ts)]
|
35 |
+
def not_within_bboxes(obj):
|
36 |
+
"""Check if the object is in any of the table's bbox."""
|
37 |
+
def obj_in_bbox(_bbox):
|
38 |
+
"""See https://github.com/jsvine/pdfplumber/blob/stable/pdfplumber/table.py#L404"""
|
39 |
+
v_mid = (obj["top"] + obj["bottom"]) / 2
|
40 |
+
h_mid = (obj["x0"] + obj["x1"]) / 2
|
41 |
+
x0, top, x1, bottom = _bbox
|
42 |
+
return (h_mid >= x0) and (h_mid < x1) and (v_mid >= top) and (v_mid < bottom)
|
43 |
+
return not any(obj_in_bbox(__bbox) for __bbox in bboxes)
|
44 |
+
|
45 |
+
return crop_page.filter(not_within_bboxes)
|
46 |
+
# 请使用 LaTeX 表达公式,行内公式以 $ 包裹,行间公式以 $$ 包裹
|
47 |
+
|
extract_words = lambda page: page.extract_words(keep_blank_chars=True, y_tolerance=0, x_tolerance=1, extra_attrs=["fontname", "size", "object_type"])
# dict_keys(['text', 'x0', 'x1', 'top', 'doctop', 'bottom', 'upright', 'direction', 'fontname', 'size'])


def get_title_with_cropped_page(first_page):
    title = []  # collect the title words
    x0, top, x1, bottom = first_page.bbox  # page bounding box

    for word in extract_words(first_page):
        word = SimpleNamespace(**word)

        if word.size >= 14:
            title.append(word.text)
            title_bottom = word.bottom
        elif word.text == "Abstract":  # locate the abstract heading
            top = word.top

    user_info = [i["text"] for i in extract_words(first_page.within_bbox((x0, title_bottom, x1, top)))]
    # Crop away the header; within_bbox keeps fully contained objects, crop keeps partially contained ones
    return title, user_info, first_page.within_bbox((x0, top, x1, bottom))


def get_column_cropped_pages(pages, two_column=True):
    new_pages = []
    for page in pages:
        if two_column:
            left = page.within_bbox((0, 0, page.width / 2, page.height), relative=True)
            right = page.within_bbox((page.width / 2, 0, page.width, page.height), relative=True)
            new_pages.append(left)
            new_pages.append(right)
        else:
            new_pages.append(page)

    return new_pages

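# Note on the relative=True crops above: within_bbox then interprets the
# coordinates relative to the page's own (possibly already cropped) bbox,
# so (0, 0, width / 2, height) is always the left column, even for the
# title-cropped first page.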
def parse_pdf(filename, two_column=True):
    level = logging.getLogger().level
    if level == logging.getLevelName("DEBUG"):
        logging.getLogger().setLevel("INFO")

    with pdfplumber.open(filename) as pdf:
        title, user_info, first_page = get_title_with_cropped_page(pdf.pages[0])
        new_pages = get_column_cropped_pages([first_page] + pdf.pages[1:], two_column)

        chapters = []
        # each chapter records its name, its name's position, the page range (start, stop), and its text
        create_chapter = lambda page_start, name_top, name_bottom: SimpleNamespace(
            name=[],
            name_top=name_top,
            name_bottom=name_bottom,
            record_chapter_name=True,

            page_start=page_start,
            page_stop=None,

            text=[],
        )
        cur_chapter = None

        # Walk the PDF page by page
        for page in new_pages:
            page = get_text_outside_table(page)

            # Walk the page text line by line
            for word in extract_words(page):
                word = SimpleNamespace(**word)

                # A line set in a large font (>= 11 pt) marks a chapter heading
                if word.size >= 11:  # a chapter name appears
                    if cur_chapter is None:
                        cur_chapter = create_chapter(page.page_number, word.top, word.bottom)
                    elif not cur_chapter.record_chapter_name or (word.bottom != cur_chapter.name_bottom and word.top != cur_chapter.name_top):
                        # A heading at a new position: close the current chapter
                        cur_chapter.page_stop = page.page_number  # stop id
                        chapters.append(cur_chapter)
                        # Reset the current chapter info
                        cur_chapter = create_chapter(page.page_number, word.top, word.bottom)

                    # print(word.size, word.top, word.bottom, word.text)
                    cur_chapter.name.append(word.text)
                else:
                    cur_chapter.record_chapter_name = False  # the chapter name has ended
                    cur_chapter.text.append(word.text)

        # Handle the last chapter once every page has been processed
        cur_chapter.page_stop = page.page_number  # stop id
        chapters.append(cur_chapter)

    for i in chapters:
        logging.info(f"section: {i.name} pages:{i.page_start, i.page_stop} word-count:{len(i.text)}")
        logging.debug(" ".join(i.text))

    title = " ".join(title)
    user_info = " ".join(user_info)
    text = f"Article Title: {title}, Information: {user_info}\n"
    for idx, chapter in enumerate(chapters):
        chapter.name = " ".join(chapter.name)
        text += f"The {idx}th Chapter {chapter.name}: " + " ".join(chapter.text) + "\n"

    logging.getLogger().setLevel(level)
    return Document(text=text, extra_info={"title": title})


BASE_POINTS = """
1. Who are the authors?
2. What is the process of the proposed method?
3. What is the performance of the proposed method? Please note down its performance metrics.
4. What are the baseline models and their performances? Please note down these baseline methods.
5. What dataset did this paper use?
"""

READING_PROMPT = """
You are a researcher helper bot. You can help the user with research paper reading and summarizing. \n
Now I am going to send you a paper. You need to read it and summarize it for me part by part. \n
When you are reading, you need to focus on these key points:{}
"""

READING_PROMT_V2 = """
You are a researcher helper bot. You can help the user with research paper reading and summarizing. \n
Now I am going to send you a paper. You need to read it and summarize it for me part by part. \n
When you are reading, you need to focus on these key points:{},

And you need to generate a brief but informative title for this part.
Your return format:
- title: '...'
- summary: '...'
"""

SUMMARY_PROMPT = "You are a researcher helper bot. Now you need to read the summaries of a research paper."


if __name__ == '__main__':
    # Test code
    z = parse_pdf("./build/test.pdf")
    print(z.extra_info["title"])
    print(z.text)
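Taken together, a minimal sketch of how this module might be driven end to end; the paper path is a placeholder, and wiring the resulting prompt into an actual chat call is left to chat_func.py:

    from modules.pdf_func import parse_pdf, READING_PROMPT, BASE_POINTS

    doc = parse_pdf("papers/example.pdf", two_column=True)  # hypothetical path
    system_prompt = READING_PROMPT.format(BASE_POINTS)      # inject the key points
    print(doc.extra_info["title"])
    print(system_prompt)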
modules/presets.py
CHANGED
@@ -1,12 +1,14 @@
 # -*- coding:utf-8 -*-
 import gradio as gr
+from pathlib import Path
 
 # ChatGPT settings
 initial_prompt = "You are a helpful assistant."
-
+API_HOST = "api.openai.com"
+COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
 BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
 USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
-HISTORY_DIR = "history"
+HISTORY_DIR = Path("history")
 TEMPLATES_DIR = "templates"
 
 # Error messages
@@ -28,7 +30,7 @@ CONCURRENT_COUNT = 100  # number of users allowed at the same time
 SIM_K = 5
 INDEX_QUERY_TEMPRATURE = 1.0
 
-title = """<h1 align="left" style="min-width:200px; margin-top:
+title = """<h1 align="left" style="min-width:200px; margin-top:6px; white-space: nowrap;">川虎ChatGPT 🚀</h1>"""
 description = """\
 <div align="center" style="margin:16px 0">
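Because HISTORY_DIR is now a pathlib.Path rather than a plain string, callers can build history paths with the / operator; a small sketch (the per-user layout shown is hypothetical):

    from pathlib import Path

    HISTORY_DIR = Path("history")
    log_file = HISTORY_DIR / "user" / "2023-03.json"  # hypothetical layout
    print(log_file)  # history/user/2023-03.json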
modules/shared.py
CHANGED
@@ -1,8 +1,13 @@
-from modules.presets import
+from modules.presets import COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST
+import os
+import queue
 
 class State:
     interrupted = False
-
+    multi_api_key = False
+    completion_url = COMPLETION_URL
+    balance_api_url = BALANCE_API_URL
+    usage_api_url = USAGE_API_URL
 
     def interrupt(self):
         self.interrupted = True
@@ -10,15 +15,41 @@ class State:
     def recover(self):
         self.interrupted = False
 
-    def
-        self.
+    def set_api_host(self, api_host):
+        self.completion_url = f"https://{api_host}/v1/chat/completions"
+        self.balance_api_url = f"https://{api_host}/dashboard/billing/credit_grants"
+        self.usage_api_url = f"https://{api_host}/dashboard/billing/usage"
+        os.environ["OPENAI_API_BASE"] = f"https://{api_host}/v1"
 
-    def
-        self.
-
+    def reset_api_host(self):
+        self.completion_url = COMPLETION_URL
+        self.balance_api_url = BALANCE_API_URL
+        self.usage_api_url = USAGE_API_URL
+        os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}/v1"
+        return API_HOST
 
     def reset_all(self):
         self.interrupted = False
-        self.
+        self.completion_url = COMPLETION_URL
+
+    def set_api_key_queue(self, api_key_list):
+        self.multi_api_key = True
+        self.api_key_queue = queue.Queue()
+        for api_key in api_key_list:
+            self.api_key_queue.put(api_key)
+
+    def switching_api_key(self, func):
+        if not hasattr(self, "api_key_queue"):
+            return func
+
+        def wrapped(*args, **kwargs):
+            api_key = self.api_key_queue.get()
+            args = list(args)[1:]
+            ret = func(api_key, *args, **kwargs)
+            self.api_key_queue.put(api_key)
+            return ret
+
+        return wrapped
+
 
 state = State()
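The new multi-key queue round-robins API keys across concurrent requests: switching_api_key pops a key from the queue, substitutes it for the caller's first positional argument, and puts it back once the call returns. A brief sketch; the keys and predict_stub below are hypothetical stand-ins for the real completion calls in chat_func.py:

    from modules import shared

    shared.state.set_api_key_queue(["sk-key-one", "sk-key-two"])  # hypothetical keys

    def predict_stub(openai_api_key, prompt):
        # stand-in for a real completion call
        return f"answered {prompt!r} with {openai_api_key}"

    predict = shared.state.switching_api_key(predict_stub)
    # the caller's first argument is discarded; a queued key is used instead
    print(predict("ignored-key", "What is attention?"))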
modules/utils.py
CHANGED
@@ -21,14 +21,11 @@ from markdown import markdown
 from pygments import highlight
 from pygments.lexers import get_lexer_by_name
 from pygments.formatters import HtmlFormatter
 
 from modules.presets import *
-
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
-)
 
 if TYPE_CHECKING:
     from typing import TypedDict
@@ -156,8 +153,11 @@ def construct_assistant(text):
     return construct_text("assistant", text)
 
 
-def construct_token_message(