AllenYkl committed on
Commit c5ac60d (1 parent: b5ea958)

Upload 8 files

bin_public/app/Chatbot.py ADDED
@@ -0,0 +1,259 @@
+# -*- coding:utf-8 -*-
+
+import sys
+from overwrites import *
+from chat_func import *
+from bin_public.utils.tools import *
+from bin_public.utils.utils_db import *
+from bin_public.config.presets import *
+
+my_api_key = ""
+
+# if we are running in Docker
+if os.environ.get('dockerrun') == 'yes':
+    dockerflag = True
+else:
+    dockerflag = False
+
+authflag = False
+
+if dockerflag:
+    my_api_key = os.environ.get('my_api_key')
+    if my_api_key == "empty":
+        print("Please give an API key!")
+        sys.exit(1)
+    # auth
+    username = os.environ.get('USERNAME')
+    password = os.environ.get('PASSWORD')
+    if username is not None and password is not None:
+        authflag = True
+else:
+    '''if not my_api_key and os.path.exists("api_key.txt") and os.path.getsize("api_key.txt"):  # file that holds the API key
+        with open("api_key.txt", "r") as f:
+            my_api_key = f.read().strip()'''
+
+    if os.path.exists("auth.json"):
+        with open("auth.json", "r") as f:
+            auth = json.load(f)
+            username = auth["username"]
+            password = auth["password"]
+            if username != "" and password != "":
+                authflag = True
+
+gr.Chatbot.postprocess = postprocess
+PromptHelper.compact_text_chunks = compact_text_chunks
+
+with gr.Blocks(css=customCSS) as demo:
+    history = gr.State([])
+    token_count = gr.State([])
+    invite_code = gr.State()
+    promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
+    TRUECONSTANT = gr.State(True)
+    FALSECONSTANT = gr.State(False)
+    topic = gr.State("未命名对话历史记录")
+
+    # gr.HTML("""
+    #     <div style="text-align: center; margin-top: 20px;">
+    # """)
+    gr.HTML(title)
+
+    with gr.Row(scale=1).style(equal_height=True):
+        with gr.Column(scale=5):
+            with gr.Row(scale=1):
+                chatbot = gr.Chatbot().style(height=600)  # .style(color_map=("#1D51EE", "#585A5B"))
+            with gr.Row(scale=1):
+                with gr.Column(scale=12):
+                    user_input = gr.Textbox(show_label=False, placeholder="在这里输入").style(
+                        container=False)
+                with gr.Column(min_width=50, scale=1):
+                    submitBtn = gr.Button("🚀", variant="primary")
+            with gr.Row(scale=1):
+                emptyBtn = gr.Button("🧹 新的对话")
+                retryBtn = gr.Button("🔄 重新生成")
+                delLastBtn = gr.Button("🗑️ 删除一条对话")
+                reduceTokenBtn = gr.Button("♻️ 总结对话")
+
+        with gr.Column():
+            with gr.Column(min_width=50, scale=1):
+                status_display = gr.Markdown("status: ready")
+                with gr.Tab(label="ChatGPT"):
+                    keyTXT = gr.Textbox(show_label=True, placeholder="OpenAI API-key...",
+                                        type="password", visible=not HIDE_MY_KEY, label="API-Key/Invite-Code")
+
+                    keyTxt = gr.Textbox(visible=False)
+
+                    key_button = gr.Button("Enter")
+
+                    model_select_dropdown = gr.Dropdown(label="选择模型", choices=MODELS, multiselect=False,
+                                                        value=MODELS[0])
+                    with gr.Accordion("参数", open=False):
+                        temperature = gr.Slider(minimum=0, maximum=2.0, value=1.0,
+                                                step=0.1, interactive=True, label="Temperature")
+                        top_p = gr.Slider(minimum=0, maximum=1.0, value=1.0, step=0.05,
+                                          interactive=True, label="Top-p (nucleus sampling)", visible=False)
+                    use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
+                    use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
+                    index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
+
+                with gr.Tab(label="Prompt"):
+                    systemPromptTxt = gr.Textbox(show_label=True, placeholder="在这里输入System Prompt...",
+                                                 label="System prompt", value=initial_prompt).style(container=True)
+                    with gr.Accordion(label="加载Prompt模板", open=True):
+                        with gr.Column():
+                            with gr.Row():
+                                with gr.Column(scale=6):
+                                    templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件",
+                                                                             choices=get_template_names(plain=True),
+                                                                             multiselect=False,
+                                                                             value=get_template_names(plain=True)[0])
+                                with gr.Column(scale=1):
+                                    templateRefreshBtn = gr.Button("🔄 刷新")
+                            with gr.Row():
+                                with gr.Column():
+                                    templateSelectDropdown = gr.Dropdown(
+                                        label="从Prompt模板中加载",
+                                        choices=load_template(get_template_names(plain=True)[0], mode=1),
+                                        multiselect=False,
+                                        value=load_template(get_template_names(plain=True)[0], mode=1)[0])
+
+                with gr.Tab(label="保存/加载"):
+                    with gr.Accordion(label="保存/加载对话历史记录", open=True):
+                        with gr.Column():
+                            with gr.Row():
+                                with gr.Column(scale=6):
+                                    historyFileSelectDropdown = gr.Dropdown(
+                                        label="从列表中加载对话",
+                                        choices=get_history_names(plain=True),
+                                        multiselect=False,
+                                        value=get_history_names(plain=True)[0],
+                                        visible=False
+                                    )
+
+                            with gr.Row():
+                                with gr.Column(scale=6):
+                                    saveFileName = gr.Textbox(
+                                        show_label=True,
+                                        placeholder="设置文件名: 默认为.json,可选为.md",
+                                        label="设置保存文件名",
+                                        value="对话历史记录",
+                                    ).style(container=True)
+                                with gr.Column(scale=1):
+                                    saveHistoryBtn = gr.Button("💾 保存对话")
+                                    exportMarkdownBtn = gr.Button("📝 导出为Markdown")
+                                    # gr.Markdown("默认保存于history文件夹")
+                            with gr.Row():
+                                with gr.Column():
+                                    downloadFile = gr.File(interactive=True)
+
+    gr.HTML("""
+        <div style="text-align: center; margin-top: 20px; margin-bottom: 20px;">
+    """)
+    gr.Markdown(description)
+
+    # If the input is an API key, use it as-is; if it is an invite code, fetch the central API key
+    key_button.click(key_preprocessing, [keyTXT], [status_display, keyTxt, invite_code])
+
+    user_input.submit(predict, [
+        keyTxt,
+        invite_code,
+        systemPromptTxt,
+        history,
+        user_input,
+        chatbot,
+        token_count,
+        top_p,
+        temperature,
+        use_streaming_checkbox,
+        model_select_dropdown,
+        use_websearch_checkbox,
+        index_files],
+        [chatbot, history, status_display, token_count], show_progress=True)
+    user_input.submit(reset_textbox, [], [user_input])
+
+    submitBtn.click(predict, [
+        keyTxt,
+        invite_code,
+        systemPromptTxt,
+        history,
+        user_input,
+        chatbot,
+        token_count,
+        top_p,
+        temperature,
+        use_streaming_checkbox,
+        model_select_dropdown,
+        use_websearch_checkbox,
+        index_files],
+        [chatbot, history, status_display, token_count], show_progress=True)
+
+    submitBtn.click(reset_textbox, [], [user_input])
+
+    emptyBtn.click(reset_state, outputs=[chatbot, history, token_count, status_display], show_progress=True)
+
+    retryBtn.click(retry,
+                   [keyTxt, invite_code, systemPromptTxt, history, chatbot, token_count, top_p, temperature,
+                    use_streaming_checkbox, model_select_dropdown],
+                   [chatbot, history, status_display, token_count], show_progress=True)
+
+    delLastBtn.click(delete_last_conversation, [chatbot, history, token_count],
+                     [chatbot, history, token_count, status_display], show_progress=True)
+
+    reduceTokenBtn.click(reduce_token_size,
+                         [keyTxt, invite_code, systemPromptTxt, history, chatbot, token_count, top_p,
+                          temperature, use_streaming_checkbox, model_select_dropdown],
+                         [chatbot, history, status_display, token_count], show_progress=True)
+    # History
+    saveHistoryBtn.click(
+        save_chat_history,
+        [saveFileName, systemPromptTxt, history, chatbot],
+        downloadFile,
+        show_progress=True,
+    )
+    saveHistoryBtn.click(get_history_names, None, [historyFileSelectDropdown])
+    exportMarkdownBtn.click(
+        export_markdown,
+        [saveFileName, systemPromptTxt, history, chatbot],
+        downloadFile,
+        show_progress=True,
+    )
+    # historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
+    historyFileSelectDropdown.change(
+        load_chat_history,
+        [historyFileSelectDropdown, systemPromptTxt, history, chatbot],
+        [saveFileName, systemPromptTxt, history, chatbot],
+        show_progress=True,
+    )
+    downloadFile.change(
+        load_chat_history,
+        [downloadFile, systemPromptTxt, history, chatbot],
+        [saveFileName, systemPromptTxt, history, chatbot],
+    )
+
+    # Template
+    templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
+
+    templateFileSelectDropdown.change(load_template, [templateFileSelectDropdown],
+                                      [promptTemplates, templateSelectDropdown], show_progress=True)
+
+    templateSelectDropdown.change(get_template_content, [promptTemplates, templateSelectDropdown, systemPromptTxt],
+                                  [systemPromptTxt], show_progress=True)
+
+logging.info("\n访问 http://localhost:7860 查看界面")
+# By default, start a local server, directly reachable by IP, without creating a public share link
+demo.title = "ChatGPT-长江商学院 🚀"
+
+if __name__ == "__main__":
+    # if running in Docker
+    if dockerflag:
+        if authflag:
+            demo.queue().launch(server_name="0.0.0.0", server_port=7860, auth=(username, password))
+        else:
+            demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
+    # if not running in Docker
+    else:
+        if authflag:
+            demo.queue().launch(share=False, auth=(username, password))
+        else:
+            demo.queue().launch(share=False)  # change to share=True to create a public share link
bin_public/app/app.py ADDED
@@ -0,0 +1,258 @@
+# -*- coding:utf-8 -*-
+import sys
+from overwrites import *
+from chat_func import *
+from bin_public.utils.tools import *
+from bin_public.utils.utils_db import *
+from bin_public.config.presets import *
+
+my_api_key = ""
+
+# if we are running in Docker
+if os.environ.get('dockerrun') == 'yes':
+    dockerflag = True
+else:
+    dockerflag = False
+
+authflag = False
+
+if dockerflag:
+    my_api_key = os.environ.get('my_api_key')
+    if my_api_key == "empty":
+        print("Please give an API key!")
+        sys.exit(1)
+    # auth
+    username = os.environ.get('USERNAME')
+    password = os.environ.get('PASSWORD')
+    if username is not None and password is not None:
+        authflag = True
+else:
+    '''if not my_api_key and os.path.exists("api_key.txt") and os.path.getsize("api_key.txt"):  # file that holds the API key
+        with open("api_key.txt", "r") as f:
+            my_api_key = f.read().strip()'''
+
+    if os.path.exists("auth.json"):
+        with open("auth.json", "r") as f:
+            auth = json.load(f)
+            username = auth["username"]
+            password = auth["password"]
+            if username != "" and password != "":
+                authflag = True
+
+gr.Chatbot.postprocess = postprocess
+PromptHelper.compact_text_chunks = compact_text_chunks
+
+with gr.Blocks(css=customCSS) as demo:
+    history = gr.State([])
+    token_count = gr.State([])
+    invite_code = gr.State()
+    promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
+    TRUECONSTANT = gr.State(True)
+    FALSECONSTANT = gr.State(False)
+    topic = gr.State("未命名对话历史记录")
+
+    # gr.HTML("""
+    #     <div style="text-align: center; margin-top: 20px;">
+    # """)
+    gr.HTML(title)
+
+    with gr.Row(scale=1).style(equal_height=True):
+        with gr.Column(scale=5):
+            with gr.Row(scale=1):
+                chatbot = gr.Chatbot().style(height=600)  # .style(color_map=("#1D51EE", "#585A5B"))
+            with gr.Row(scale=1):
+                with gr.Column(scale=12):
+                    user_input = gr.Textbox(show_label=False, placeholder="在这里输入").style(
+                        container=False)
+                with gr.Column(min_width=50, scale=1):
+                    submitBtn = gr.Button("🚀", variant="primary")
+            with gr.Row(scale=1):
+                emptyBtn = gr.Button("🧹 新的对话")
+                retryBtn = gr.Button("🔄 重新生成")
+                delLastBtn = gr.Button("🗑️ 删除一条对话")
+                reduceTokenBtn = gr.Button("♻️ 总结对话")
+
+        with gr.Column():
+            with gr.Column(min_width=50, scale=1):
+                status_display = gr.Markdown("status: ready")
+                with gr.Tab(label="ChatGPT"):
+                    keyTXT = gr.Textbox(show_label=True, placeholder="OpenAI API-key...",
+                                        type="password", visible=not HIDE_MY_KEY, label="API-Key/Invite-Code")
+
+                    keyTxt = gr.Textbox(visible=False)
+
+                    key_button = gr.Button("Enter")
+
+                    model_select_dropdown = gr.Dropdown(label="选择模型", choices=MODELS, multiselect=False,
+                                                        value=MODELS[0])
+                    with gr.Accordion("参数", open=False):
+                        temperature = gr.Slider(minimum=0, maximum=2.0, value=1.0,
+                                                step=0.1, interactive=True, label="Temperature")
+                        top_p = gr.Slider(minimum=0, maximum=1.0, value=1.0, step=0.05,
+                                          interactive=True, label="Top-p (nucleus sampling)", visible=False)
+                    use_streaming_checkbox = gr.Checkbox(label="实时传输回答", value=True, visible=enable_streaming_option)
+                    use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
+                    index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
+
+                with gr.Tab(label="Prompt"):
+                    systemPromptTxt = gr.Textbox(show_label=True, placeholder="在这里输入System Prompt...",
+                                                 label="System prompt", value=initial_prompt).style(container=True)
+                    with gr.Accordion(label="加载Prompt模板", open=True):
+                        with gr.Column():
+                            with gr.Row():
+                                with gr.Column(scale=6):
+                                    templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件",
+                                                                             choices=get_template_names(plain=True),
+                                                                             multiselect=False,
+                                                                             value=get_template_names(plain=True)[0])
+                                with gr.Column(scale=1):
+                                    templateRefreshBtn = gr.Button("🔄 刷新")
+                            with gr.Row():
+                                with gr.Column():
+                                    templateSelectDropdown = gr.Dropdown(
+                                        label="从Prompt模板中加载",
+                                        choices=load_template(get_template_names(plain=True)[0], mode=1),
+                                        multiselect=False,
+                                        value=load_template(get_template_names(plain=True)[0], mode=1)[0])
+
+                with gr.Tab(label="保存/加载"):
+                    with gr.Accordion(label="保存/加载对话历史记录", open=True):
+                        with gr.Column():
+                            with gr.Row():
+                                with gr.Column(scale=6):
+                                    historyFileSelectDropdown = gr.Dropdown(
+                                        label="从列表中加载对话",
+                                        choices=get_history_names(plain=True),
+                                        multiselect=False,
+                                        value=get_history_names(plain=True)[0],
+                                        visible=False
+                                    )
+
+                            with gr.Row():
+                                with gr.Column(scale=6):
+                                    saveFileName = gr.Textbox(
+                                        show_label=True,
+                                        placeholder="设置文件名: 默认为.json,可选为.md",
+                                        label="设置保存文件名",
+                                        value="对话历史记录",
+                                    ).style(container=True)
+                                with gr.Column(scale=1):
+                                    saveHistoryBtn = gr.Button("💾 保存对话")
+                                    exportMarkdownBtn = gr.Button("📝 导出为Markdown")
+                                    # gr.Markdown("默认保存于history文件夹")
+                            with gr.Row():
+                                with gr.Column():
+                                    downloadFile = gr.File(interactive=True)
+
+    gr.HTML("""
+        <div style="text-align: center; margin-top: 20px; margin-bottom: 20px;">
+    """)
+    gr.Markdown(description)
+
+    # If the input is an API key, use it as-is; if it is an invite code, fetch the central API key
+    key_button.click(key_preprocessing, [keyTXT], [status_display, keyTxt, invite_code])
+
+    user_input.submit(predict, [
+        keyTxt,
+        invite_code,
+        systemPromptTxt,
+        history,
+        user_input,
+        chatbot,
+        token_count,
+        top_p,
+        temperature,
+        use_streaming_checkbox,
+        model_select_dropdown,
+        use_websearch_checkbox,
+        index_files],
+        [chatbot, history, status_display, token_count], show_progress=True)
+    user_input.submit(reset_textbox, [], [user_input])
+
+    submitBtn.click(predict, [
+        keyTxt,
+        invite_code,
+        systemPromptTxt,
+        history,
+        user_input,
+        chatbot,
+        token_count,
+        top_p,
+        temperature,
+        use_streaming_checkbox,
+        model_select_dropdown,
+        use_websearch_checkbox,
+        index_files],
+        [chatbot, history, status_display, token_count], show_progress=True)
+
+    submitBtn.click(reset_textbox, [], [user_input])
+
+    emptyBtn.click(reset_state, outputs=[chatbot, history, token_count, status_display], show_progress=True)
+
+    retryBtn.click(retry,
+                   [keyTxt, invite_code, systemPromptTxt, history, chatbot, token_count, top_p, temperature,
+                    use_streaming_checkbox, model_select_dropdown],
+                   [chatbot, history, status_display, token_count], show_progress=True)
+
+    delLastBtn.click(delete_last_conversation, [chatbot, history, token_count],
+                     [chatbot, history, token_count, status_display], show_progress=True)
+
+    reduceTokenBtn.click(reduce_token_size,
+                         [keyTxt, invite_code, systemPromptTxt, history, chatbot, token_count, top_p,
+                          temperature, use_streaming_checkbox, model_select_dropdown],
+                         [chatbot, history, status_display, token_count], show_progress=True)
+    # History
+    saveHistoryBtn.click(
+        save_chat_history,
+        [saveFileName, systemPromptTxt, history, chatbot],
+        downloadFile,
+        show_progress=True,
+    )
+    saveHistoryBtn.click(get_history_names, None, [historyFileSelectDropdown])
+    exportMarkdownBtn.click(
+        export_markdown,
+        [saveFileName, systemPromptTxt, history, chatbot],
+        downloadFile,
+        show_progress=True,
+    )
+    # historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
+    historyFileSelectDropdown.change(
+        load_chat_history,
+        [historyFileSelectDropdown, systemPromptTxt, history, chatbot],
+        [saveFileName, systemPromptTxt, history, chatbot],
+        show_progress=True,
+    )
+    downloadFile.change(
+        load_chat_history,
+        [downloadFile, systemPromptTxt, history, chatbot],
+        [saveFileName, systemPromptTxt, history, chatbot],
+    )
+
+    # Template
+    templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
+
+    templateFileSelectDropdown.change(load_template, [templateFileSelectDropdown],
+                                      [promptTemplates, templateSelectDropdown], show_progress=True)
+
+    templateSelectDropdown.change(get_template_content, [promptTemplates, templateSelectDropdown, systemPromptTxt],
+                                  [systemPromptTxt], show_progress=True)
+
+logging.info("\n访问 http://localhost:7860 查看界面")
+# By default, start a local server, directly reachable by IP, without creating a public share link
+demo.title = "ChatGPT-长江商学院 🚀"
+
+if __name__ == "__main__":
+    # if running in Docker
+    if dockerflag:
+        if authflag:
+            demo.queue().launch(server_name="0.0.0.0", server_port=7860, auth=(username, password))
+        else:
+            demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
+    # if not running in Docker
+    else:
+        if authflag:
+            demo.queue().launch(share=False, auth=(username, password))
+        else:
+            demo.queue().launch(share=False)  # change to share=True to create a public share link
bin_public/app/chat_func.py ADDED
@@ -0,0 +1,452 @@
+# -*- coding:utf-8 -*-
+from __future__ import annotations
+
+import requests
+import urllib3
+
+from tqdm import tqdm
+from duckduckgo_search import ddg
+from llama_func import *
+from bin_public.utils.tools import *
+from bin_public.utils.utils_db import *
+# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
+
+if TYPE_CHECKING:
+    from typing import TypedDict
+
+    class DataframeData(TypedDict):
+        headers: List[str]
+        data: List[List[str | int | bool]]
+
+
+initial_prompt = "You are a helpful assistant."
+API_URL = "https://api.openai.com/v1/chat/completions"
+HISTORY_DIR = "history"
+TEMPLATES_DIR = r"/templates"
+
+
+def get_response(
+    openai_api_key, system_prompt, history, temperature, top_p, stream, selected_model
+):
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {openai_api_key}",
+    }
+
+    history = [construct_system(system_prompt), *history]
+
+    payload = {
+        "model": selected_model,
+        "messages": history,  # [{"role": "user", "content": f"{inputs}"}],
+        "temperature": temperature,  # 1.0,
+        "top_p": top_p,  # 1.0,
+        "n": 1,
+        "stream": stream,
+        "presence_penalty": 0,
+        "frequency_penalty": 0,
+    }
+    if stream:
+        timeout = timeout_streaming
+    else:
+        timeout = timeout_all
+
+    # Read proxy settings from the environment
+    http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
+    https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
+
+    # Use the proxy settings if they exist
+    proxies = {}
+    if http_proxy:
+        logging.info(f"Using HTTP proxy: {http_proxy}")
+        proxies["http"] = http_proxy
+    if https_proxy:
+        logging.info(f"Using HTTPS proxy: {https_proxy}")
+        proxies["https"] = https_proxy
+
+    # Send the request through the proxies if any are configured, otherwise with default settings
+    if proxies:
+        response = requests.post(
+            API_URL,
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=timeout,
+            proxies=proxies,
+        )
+    else:
+        response = requests.post(
+            API_URL,
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=timeout,
+        )
+    return response
+
+
+def stream_predict(
+    openai_api_key,
+    system_prompt,
+    history,
+    inputs,
+    chatbot,
+    all_token_counts,
+    top_p,
+    temperature,
+    selected_model,
+    fake_input=None
+):
+    def get_return_value():
+        return chatbot, history, status_text, all_token_counts
+
+    logging.info("实时回答模式")
+    partial_words = ""
+    counter = 0
+    status_text = "开始实时传输回答……"
+    history.append(construct_user(inputs))
+    history.append(construct_assistant(""))
+    if fake_input:
+        chatbot.append((parse_text(fake_input), ""))
+    else:
+        chatbot.append((parse_text(inputs), ""))
+    user_token_count = 0
+    if len(all_token_counts) == 0:
+        system_prompt_token_count = count_token(construct_system(system_prompt))
+        user_token_count = (
+            count_token(construct_user(inputs)) + system_prompt_token_count
+        )
+    else:
+        user_token_count = count_token(construct_user(inputs))
+    all_token_counts.append(user_token_count)
+    logging.info(f"输入token计数: {user_token_count}")
+    yield get_return_value()
+    try:
+        response = get_response(
+            openai_api_key,
+            system_prompt,
+            history,
+            temperature,
+            top_p,
+            True,
+            selected_model,
+        )
+    except requests.exceptions.ConnectTimeout:
+        status_text = (
+            standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
+        )
+        yield get_return_value()
+        return
+    except requests.exceptions.ReadTimeout:
+        status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
+        yield get_return_value()
+        return
+
+    yield get_return_value()
+    error_json_str = ""
+
+    for chunk in tqdm(response.iter_lines()):
+        if counter == 0:
+            counter += 1
+            continue
+        counter += 1
+        # check whether each line is non-empty
+        if chunk:
+            chunk = chunk.decode()
+            chunklength = len(chunk)
+            try:
+                chunk = json.loads(chunk[6:])
+            except json.JSONDecodeError:
+                logging.info(chunk)
+                error_json_str += chunk
+                status_text = f"JSON解析错误。请重置对话。收到的内容: {error_json_str}"
+                yield get_return_value()
+                continue
+            # decode each line as response data is in bytes
+            if chunklength > 6 and "delta" in chunk["choices"][0]:
+                finish_reason = chunk["choices"][0]["finish_reason"]
+                status_text = construct_token_message(
+                    sum(all_token_counts), stream=True
+                )
+                if finish_reason == "stop":
+                    yield get_return_value()
+                    break
+                try:
+                    partial_words = (
+                        partial_words + chunk["choices"][0]["delta"]["content"]
+                    )
+                except KeyError:
+                    status_text = (
+                        standard_error_msg
+                        + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: "
+                        + str(sum(all_token_counts))
+                    )
+                    yield get_return_value()
+                    break
+                history[-1] = construct_assistant(partial_words)
+                chatbot[-1] = (chatbot[-1][0], parse_text(partial_words))
+                all_token_counts[-1] += 1
+                yield get_return_value()
+
+
+def predict_all(
+    openai_api_key,
+    system_prompt,
+    history,
+    inputs,
+    chatbot,
+    all_token_counts,
+    top_p,
+    temperature,
+    selected_model,
+    fake_input=None
+):
+    logging.info("一次性回答模式")
+    history.append(construct_user(inputs))
+    history.append(construct_assistant(""))
+    if fake_input:
+        chatbot.append((parse_text(fake_input), ""))
+    else:
+        chatbot.append((parse_text(inputs), ""))
+    all_token_counts.append(count_token(construct_user(inputs)))
+    try:
+        response = get_response(
+            openai_api_key,
+            system_prompt,
+            history,
+            temperature,
+            top_p,
+            False,
+            selected_model,
+        )
+    except requests.exceptions.ConnectTimeout:
+        status_text = (
+            standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
+        )
+        return chatbot, history, status_text, all_token_counts
+    except requests.exceptions.ProxyError:
+        status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
+        return chatbot, history, status_text, all_token_counts
+    except requests.exceptions.SSLError:
+        status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
+        return chatbot, history, status_text, all_token_counts
+    response = json.loads(response.text)
+    content = response["choices"][0]["message"]["content"]
+    history[-1] = construct_assistant(content)
+    chatbot[-1] = (chatbot[-1][0], parse_text(content))
+    total_token_count = response["usage"]["total_tokens"]
+    all_token_counts[-1] = total_token_count - sum(all_token_counts)
+    status_text = construct_token_message(total_token_count)
+    return chatbot, history, status_text, all_token_counts
+
+
+def predict(
+    openai_api_key,
+    invite_code,
+    system_prompt,
+    history,
+    inputs,
+    chatbot,
+    all_token_counts,
+    top_p,
+    temperature,
+    stream=False,
+    selected_model=MODELS[0],
+    use_websearch=False,
+    files=None,
+    should_check_token_count=True,
+):  # repetition_penalty, top_k
+    logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
+    if files:
+        msg = "构建索引中……(这可能需要比较久的时间)"
+        logging.info(msg)
+        yield chatbot, history, msg, all_token_counts
+        index = construct_index(openai_api_key, file_src=files)
+        msg = "索引构建完成,获取回答中……"
+        yield chatbot, history, msg, all_token_counts
+        history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot)
+        yield chatbot, history, status_text, all_token_counts
+        return
+
+    old_inputs = ""
+    link_references = []
+    if use_websearch:
+        search_results = ddg(inputs, max_results=5)
+        old_inputs = inputs
+        web_results = []
+        for idx, result in enumerate(search_results):
+            logging.info(f"搜索结果{idx + 1}:{result}")
+            domain_name = urllib3.util.parse_url(result["href"]).host
+            web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
+            link_references.append(f"[{idx+1}]: [{domain_name}]({result['href']})")
+        inputs = (
+            replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
+            .replace("{query}", inputs)
+            .replace("{web_results}", "\n\n".join(web_results))
+        )
+
+    if len(openai_api_key) != 51:
+        status_text = standard_error_msg + no_apikey_msg
+        logging.info(status_text)
+        chatbot.append((parse_text(inputs), ""))
+        if len(history) == 0:
+            history.append(construct_user(inputs))
+            history.append("")
+            all_token_counts.append(0)
+        else:
+            history[-2] = construct_user(inputs)
+        yield chatbot, history, status_text, all_token_counts
+        return
+
+    yield chatbot, history, "开始生成回答……", all_token_counts
+
+    if stream:
+        logging.info("使用流式传输")
+        iter = stream_predict(
+            openai_api_key,
+            system_prompt,
+            history,
+            inputs,
+            chatbot,
+            all_token_counts,
+            top_p,
+            temperature,
+            selected_model,
+            fake_input=old_inputs
+        )
+        for chatbot, history, status_text, all_token_counts in iter:
+            yield chatbot, history, status_text, all_token_counts
+    else:
+        logging.info("不使用流式传输")
+        chatbot, history, status_text, all_token_counts = predict_all(
+            openai_api_key,
+            system_prompt,
+            history,
+            inputs,
+            chatbot,
+            all_token_counts,
+            top_p,
+            temperature,
+            selected_model,
+            fake_input=old_inputs
+        )
+        yield chatbot, history, status_text, all_token_counts
+
+    logging.info(f"传输完毕。当前token计数为{all_token_counts}")
+    if len(history) > 1 and history[-1]['content'] != inputs:
+        # logging.info("回答为:" + colorama.Fore.BLUE + f"{history[-1]['content']}" + colorama.Style.RESET_ALL)
+        try:
+            token = all_token_counts[-1]
+        except IndexError:
+            token = 0
+        holo_query_insert_chat_message(invite_code, inputs, history[-1]['content'], token, history)
+
+    if use_websearch:
+        response = history[-1]['content']
+        response += "\n\n" + "\n".join(link_references)
+        logging.info("Added link references.")
+        logging.info(response)
+        chatbot[-1] = (parse_text(old_inputs), response)
+        yield chatbot, history, status_text, all_token_counts
+
+    if stream:
+        max_token = max_token_streaming
+    else:
+        max_token = max_token_all
+
+    if sum(all_token_counts) > max_token and should_check_token_count:
+        status_text = f"精简token中{all_token_counts}/{max_token}"
+        logging.info(status_text)
+        yield chatbot, history, status_text, all_token_counts
+        iter = reduce_token_size(
+            openai_api_key,
+            invite_code,
+            system_prompt,
+            history,
+            chatbot,
+            all_token_counts,
+            top_p,
+            temperature,
+            stream=False,
+            selected_model=selected_model,
+            hidden=True,
+        )
+        for chatbot, history, status_text, all_token_counts in iter:
+            status_text = f"Token 达到上限,已自动降低Token计数至 {status_text}"
+            yield chatbot, history, status_text, all_token_counts
+
+
+def retry(
+    openai_api_key,
+    invite_code,
+    system_prompt,
+    history,
+    chatbot,
+    token_count,
+    top_p,
+    temperature,
+    stream=False,
+    selected_model=MODELS[0],
+):
+    logging.info("重试中……")
+    if len(history) == 0:
+        yield chatbot, history, f"{standard_error_msg}上下文是空的", token_count
+        return
+    history.pop()
+    inputs = history.pop()["content"]
+    token_count.pop()
+    iter = predict(
+        openai_api_key,
+        invite_code,
+        system_prompt,
+        history,
+        inputs,
+        chatbot,
+        token_count,
+        top_p,
+        temperature,
+        stream=stream,
+        selected_model=selected_model,
+    )
+    logging.info("重试完毕")
+    for x in iter:
+        yield x
+
+
+def reduce_token_size(
+    openai_api_key,
+    invite_code,
+    system_prompt,
+    history,
+    chatbot,
+    token_count,
+    top_p,
+    temperature,
+    stream=False,
+    selected_model=MODELS[0],
+    hidden=False,
+):
+    logging.info("开始减少token数量……")
+    iter = predict(
+        openai_api_key,
+        invite_code,
+        system_prompt,
+        history,
+        summarize_prompt,
+        chatbot,
+        token_count,
+        top_p,
+        temperature,
+        stream=stream,
+        selected_model=selected_model,
+        should_check_token_count=False,
+    )
+    logging.info(f"chatbot: {chatbot}")
+    for chatbot, history, status_text, previous_token_count in iter:
+        history = history[-2:]
+        token_count = previous_token_count[-1:]
+        if hidden:
+            chatbot.pop()
+        yield chatbot, history, construct_token_message(
+            sum(token_count), stream=stream
+        ), token_count
+    logging.info("减少token数量完毕")
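A note on the `chunk[6:]` slice in `stream_predict`: the OpenAI streaming endpoint frames each event as a server-sent-event line with a `data: ` prefix (six characters), which the handler strips before `json.loads`. A minimal sketch of that framing, where the sample line is illustrative rather than captured API output:

```python
import json

# One non-empty line from response.iter_lines(), as stream_predict expects it.
# The "data: " prefix is exactly what chunk[6:] removes before JSON decoding.
sample_line = b'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}'

chunk = json.loads(sample_line.decode()[6:])
delta = chunk["choices"][0]["delta"].get("content", "")
print(delta)  # -> Hi
```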
bin_public/app/llama_func.py ADDED
@@ -0,0 +1,187 @@
+from llama_index import GPTSimpleVectorIndex
+from llama_index import download_loader
+from llama_index import (
+    Document,
+    LLMPredictor,
+    PromptHelper,
+    QuestionAnswerPrompt,
+    RefinePrompt,
+)
+from langchain.llms import OpenAI
+import colorama
+
+from bin_public.utils.tools import *
+
+
+def get_documents(file_src):
+    documents = []
+    index_name = ""
+    logging.debug("Loading documents...")
+    logging.debug(f"file_src: {file_src}")
+    for file in file_src:
+        logging.debug(f"file: {file.name}")
+        index_name += file.name
+        if os.path.splitext(file.name)[1] == ".pdf":
+            logging.debug("Loading PDF...")
+            CJKPDFReader = download_loader("CJKPDFReader")
+            loader = CJKPDFReader()
+            documents += loader.load_data(file=file.name)
+        elif os.path.splitext(file.name)[1] == ".docx":
+            logging.debug("Loading DOCX...")
+            DocxReader = download_loader("DocxReader")
+            loader = DocxReader()
+            documents += loader.load_data(file=file.name)
+        elif os.path.splitext(file.name)[1] == ".epub":
+            logging.debug("Loading EPUB...")
+            EpubReader = download_loader("EpubReader")
+            loader = EpubReader()
+            documents += loader.load_data(file=file.name)
+        else:
+            logging.debug("Loading text file...")
+            with open(file.name, "r", encoding="utf-8") as f:
+                text = add_space(f.read())
+            documents += [Document(text)]
+    index_name = sha1sum(index_name)
+    return documents, index_name
+
+
+def construct_index(
+    api_key,
+    file_src,
+    max_input_size=4096,
+    num_outputs=1,
+    max_chunk_overlap=20,
+    chunk_size_limit=600,
+    embedding_limit=None,
+    separator=" ",
+    num_children=10,
+    max_keywords_per_chunk=10,
+):
+    os.environ["OPENAI_API_KEY"] = api_key
+    chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
+    embedding_limit = None if embedding_limit == 0 else embedding_limit
+    separator = " " if separator == "" else separator
+
+    llm_predictor = LLMPredictor(
+        llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
+    )
+    prompt_helper = PromptHelper(
+        max_input_size,
+        num_outputs,
+        max_chunk_overlap,
+        embedding_limit,
+        chunk_size_limit,
+        separator=separator,
+    )
+    documents, index_name = get_documents(file_src)
+    if os.path.exists(f"./index/{index_name}.json"):
+        logging.info("找到了缓存的索引文件,加载中……")
+        return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
+    else:
+        try:
+            logging.debug("构建索引中……")
+            index = GPTSimpleVectorIndex(
+                documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
+            )
+            # os.makedirs("./index", exist_ok=True)
+            # index.save_to_disk(f"./index/{index_name}.json")
+            return index
+        except Exception as e:
+            print(e)
+            return None
+
+
+def chat_ai(
+    api_key,
+    index,
+    question,
+    context,
+    chatbot,
+):
+    os.environ["OPENAI_API_KEY"] = api_key
+
+    logging.info(f"Question: {question}")
+
+    response, chatbot_display, status_text = ask_ai(
+        api_key,
+        index,
+        question,
+        replace_today(PROMPT_TEMPLATE),
+        REFINE_TEMPLATE,
+        SIM_K,
+        INDEX_QUERY_TEMPERATURE,
+        context,
+    )
+    if response is None:
+        status_text = "查询失败,请换个问法试试"
+        return context, chatbot, status_text
+
+    context.append({"role": "user", "content": question})
+    context.append({"role": "assistant", "content": response})
+    chatbot.append((question, chatbot_display))
+
+    os.environ["OPENAI_API_KEY"] = ""
+    return context, chatbot, status_text
+
+
+def ask_ai(
+    api_key,
+    index,
+    question,
+    prompt_tmpl,
+    refine_tmpl,
+    sim_k=1,
+    temperature=0,
+    prefix_messages=[],
+):
+    os.environ["OPENAI_API_KEY"] = api_key
+
+    logging.debug("Index file found")
+    logging.debug("Querying index...")
+    llm_predictor = LLMPredictor(
+        llm=OpenAI(
+            temperature=temperature,
+            model_name="gpt-3.5-turbo-0301",
+            prefix_messages=prefix_messages,
+        )
+    )
+
+    response = None  # Initialize response variable to avoid UnboundLocalError
+    qa_prompt = QuestionAnswerPrompt(prompt_tmpl)
+    rf_prompt = RefinePrompt(refine_tmpl)
+    response = index.query(
+        question,
+        llm_predictor=llm_predictor,
+        similarity_top_k=sim_k,
+        text_qa_template=qa_prompt,
+        refine_template=rf_prompt,
+        response_mode="compact",
+    )
+
+    if response is not None:
+        logging.info(f"Response: {response}")
+        ret_text = response.response
+        nodes = []
+        for idx, node in enumerate(response.source_nodes):
+            brief = node.source_text[:25].replace("\n", "")
+            nodes.append(
+                f"<details><summary>[{idx+1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
+            )
+        new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
+        logging.info(
+            f"Response: {colorama.Fore.BLUE}{ret_text}{colorama.Style.RESET_ALL}"
+        )
+        os.environ["OPENAI_API_KEY"] = ""
+        return ret_text, new_response, f"查询消耗了{llm_predictor.last_token_usage} tokens"
+    else:
+        logging.warning("No response found, returning None")
+        os.environ["OPENAI_API_KEY"] = ""
+        return None, None, ""
+
+
+def add_space(text):
+    punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
+    for cn_punc, en_punc in punctuations.items():
+        text = text.replace(cn_punc, en_punc)
+    return text
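For reference, the index cache key in `get_documents`/`construct_index` is the SHA-1 of the concatenated upload names (not the file contents), and `construct_index` probes `./index/<sha1>.json` before rebuilding; the save call is currently commented out. A small sketch, with hypothetical file names:

```python
import hashlib

def sha1sum(name: str) -> str:
    # Same helper as in bin_public/utils/tools.py: hashes the string itself.
    sha1 = hashlib.sha1()
    sha1.update(name.encode("utf-8"))
    return sha1.hexdigest()

index_name = sha1sum("report.pdf" + "notes.docx")  # hypothetical uploads
print(f"./index/{index_name}.json")  # path construct_index checks for a cached index
```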
bin_public/app/overwrites.py ADDED
@@ -0,0 +1,38 @@
+from __future__ import annotations
+
+from llama_index import Prompt
+from typing import List, Tuple
+import mdtex2html
+
+from llama_func import *
+
+
+def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
+    logging.debug("Compacting text chunks...🚀🚀🚀")
+    combined_str = [c.strip() for c in text_chunks if c.strip()]
+    combined_str = [f"[{index+1}] {c}" for index, c in enumerate(combined_str)]
+    combined_str = "\n\n".join(combined_str)
+    # resplit based on self.max_chunk_overlap
+    text_splitter = self.get_text_splitter_given_prompt(prompt, 1, padding=1)
+    return text_splitter.split_text(combined_str)
+
+
+def postprocess(
+    self, y: List[Tuple[str | None, str | None]]
+) -> List[Tuple[str | None, str | None]]:
+    """
+    Parameters:
+        y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
+    Returns:
+        List of tuples representing the message and response. Each message and response will be a string of HTML.
+    """
+    if y is None:
+        return []
+    for i, (message, response) in enumerate(y):
+        y[i] = (
+            # None if message is None else markdown.markdown(message),
+            # None if response is None else markdown.markdown(response),
+            None if message is None else message,
+            None if response is None else mdtex2html.convert(response, extensions=['fenced_code', 'codehilite', 'tables']),
+        )
+    return y
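These two functions are never called directly here; `Chatbot.py`/`app.py` install them by plain attribute assignment, monkey-patching the library classes so that every `gr.Chatbot` instance and every `PromptHelper` picks up the replacements:

```python
import gradio as gr
from llama_index import PromptHelper
from overwrites import postprocess, compact_text_chunks

# As done at the top of Chatbot.py / app.py in this commit.
gr.Chatbot.postprocess = postprocess
PromptHelper.compact_text_chunks = compact_text_chunks
```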
bin_public/config/presets.py CHANGED
@@ -116,6 +116,9 @@ pre code {
 }
 """
 
+SIM_K = 5
+INDEX_QUERY_TEMPERATURE = 1.0
+
 summarize_prompt = "你是谁?我们刚才聊了什么?"  # prompt used to summarize the conversation
 # MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-4","gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"]  # available models
 MODELS = ["gpt-3.5-turbo-0301"]
@@ -144,3 +147,40 @@ max_token_all = 3500  # maximum token count for non-streaming chat
 timeout_all = 200  # timeout for non-streaming chat
 enable_streaming_option = True  # whether to show the checkbox that toggles streaming responses
 HIDE_MY_KEY = False  # set this to True to hide your API key in the UI
+
+WEBSEARCH_PTOMPT_TEMPLATE = """\
+Web search results:
+
+{web_results}
+Current date: {current_date}
+
+Instructions: Using the provided web search results, write a comprehensive reply to the given query. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.
+Query: {query}
+Reply in 中文"""
+
+PROMPT_TEMPLATE = """\
+Context information is below.
+---------------------
+{context_str}
+---------------------
+Current date: {current_date}.
+Using the provided context information, write a comprehensive reply to the given query.
+Make sure to cite results using [number] notation after the reference.
+If the provided context information refer to multiple subjects with the same name, write separate answers for each subject.
+Use prior knowledge only if the given context didn't provide enough information.
+Answer the question: {query_str}
+Reply in 中文
+"""
+
+REFINE_TEMPLATE = """\
+The original question is as follows: {query_str}
+We have provided an existing answer: {existing_answer}
+We have the opportunity to refine the existing answer
+(only if needed) with some more context below.
+------------
+{context_msg}
+------------
+Given the new context, refine the original answer to better
+Answer in the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch.
+If the context isn't useful, return the original answer.
+"""
bin_public/utils/tools.py ADDED
@@ -0,0 +1,299 @@
+# -*- coding:utf-8 -*-
+from __future__ import annotations
+from typing import TYPE_CHECKING, List
+import logging
+import json
+import os
+import datetime
+import hashlib
+import csv
+
+import gradio as gr
+from pypinyin import lazy_pinyin
+import tiktoken
+
+from bin_public.config.presets import *
+
+# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
+
+if TYPE_CHECKING:
+    from typing import TypedDict
+
+    class DataframeData(TypedDict):
+        headers: List[str]
+        data: List[List[str | int | bool]]
+
+
+initial_prompt = "You are a helpful assistant."
+API_URL = "https://api.openai.com/v1/chat/completions"
+HISTORY_DIR = "history"
+TEMPLATES_DIR = "templates"
+
+
+def count_token(message):
+    encoding = tiktoken.get_encoding("cl100k_base")
+    input_str = f"role: {message['role']}, content: {message['content']}"
+    length = len(encoding.encode(input_str))
+    return length
+
+
+def parse_text(text):
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = '<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>" + line
+    text = "".join(lines)
+    return text
+
+
+def construct_text(role, text):
+    return {"role": role, "content": text}
+
+
+def construct_user(text):
+    return construct_text("user", text)
+
+
+def construct_system(text):
+    return construct_text("system", text)
+
+
+def construct_assistant(text):
+    return construct_text("assistant", text)
+
+
+def construct_token_message(token, stream=False):
+    return f"Token 计数: {token}"
+
+
+def delete_last_conversation(chatbot, history, previous_token_count):
+    if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
+        logging.info("由于包含报错信息,只删除chatbot记录")
+        chatbot.pop()
+        return chatbot, history
+    if len(history) > 0:
+        logging.info("删除了一组对话历史")
+        history.pop()
+        history.pop()
+    if len(chatbot) > 0:
+        logging.info("删除了一组chatbot对话")
+        chatbot.pop()
+    if len(previous_token_count) > 0:
+        logging.info("删除了一组对话的token计数记录")
+        previous_token_count.pop()
+    return (
+        chatbot,
+        history,
+        previous_token_count,
+        construct_token_message(sum(previous_token_count)),
+    )
+
+
+def save_file(filename, system, history, chatbot):
+    logging.info("保存对话历史中……")
+    os.makedirs(HISTORY_DIR, exist_ok=True)
+    if filename.endswith(".json"):
+        json_s = {"system": system, "history": history, "chatbot": chatbot}
+        with open(os.path.join(HISTORY_DIR, filename), "w") as f:
+            json.dump(json_s, f)
+    elif filename.endswith(".md"):
+        md_s = f"system: \n- {system} \n"
+        for data in history:
+            md_s += f"\n{data['role']}: \n- {data['content']} \n"
+        with open(os.path.join(HISTORY_DIR, filename), "w", encoding="utf8") as f:
+            f.write(md_s)
+    logging.info("保存对话历史完毕")
+    return os.path.join(HISTORY_DIR, filename)
+
+
+def save_chat_history(filename, system, history, chatbot):
+    if filename == "":
+        return
+    if not filename.endswith(".json"):
+        filename += ".json"
+    return save_file(filename, system, history, chatbot)
+
+
+def export_markdown(filename, system, history, chatbot):
+    if filename == "":
+        return
+    if not filename.endswith(".md"):
+        filename += ".md"
+    return save_file(filename, system, history, chatbot)
+
+
+def load_chat_history(filename, system, history, chatbot):
+    logging.info("加载对话历史中……")
+    if not isinstance(filename, str):
+        filename = filename.name
+    try:
+        with open(os.path.join(HISTORY_DIR, filename), "r") as f:
+            json_s = json.load(f)
+        try:
+            if isinstance(json_s["history"][0], str):
+                logging.info("历史记录格式为旧版,正在转换……")
+                new_history = []
+                for index, item in enumerate(json_s["history"]):
+                    if index % 2 == 0:
+                        new_history.append(construct_user(item))
+                    else:
+                        new_history.append(construct_assistant(item))
+                json_s["history"] = new_history
+                logging.info(new_history)
+        except Exception:
+            # no chat history to convert
+            pass
+        logging.info("加载对话历史完毕")
+        return filename, json_s["system"], json_s["history"], json_s["chatbot"]
+    except FileNotFoundError:
+        logging.info("没有找到对话历史文件,不执行任何操作")
+        return filename, system, history, chatbot
+
+
+def sorted_by_pinyin(lst):
+    return sorted(lst, key=lambda char: lazy_pinyin(char)[0][0])
+
+
+def get_file_names(dir, plain=False, filetypes=[".json"]):
+    logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
+    files = []
+    try:
+        for filetype in filetypes:
+            files += [f for f in os.listdir(dir) if f.endswith(filetype)]
+    except FileNotFoundError:
+        files = []
+    files = sorted_by_pinyin(files)
+    if files == []:
+        files = [""]
+    if plain:
+        return files
+    else:
+        return gr.Dropdown.update(choices=files)
+
+
+def get_history_names(plain=False):
+    logging.info("获取历史记录文件名列表")
+    return get_file_names(HISTORY_DIR, plain)
+
+
+def load_template(filename, mode=0):
+    logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
+    lines = []
+    logging.info("Loading template...")
+    if filename.endswith(".json"):
+        with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
+            lines = json.load(f)
+        lines = [[i["act"], i["prompt"]] for i in lines]
+    else:
+        with open(
+            os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8"
+        ) as csvfile:
+            reader = csv.reader(csvfile)
+            lines = list(reader)
+        lines = lines[1:]
+    if mode == 1:
+        return sorted_by_pinyin([row[0] for row in lines])
+    elif mode == 2:
+        return {row[0]: row[1] for row in lines}
+    else:
+        choices = sorted_by_pinyin([row[0] for row in lines])
+        return {row[0]: row[1] for row in lines}, gr.Dropdown.update(
+            choices=choices, value=choices[0]
+        )
+
+
+def get_template_names(plain=False):
+    logging.info("获取模板文件名列表")
+    return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", ".json"])
+
+
+def get_template_content(templates, selection, original_system_prompt):
+    logging.info(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
+    try:
+        return templates[selection]
+    except KeyError:
+        return original_system_prompt
+
+
+def reset_state():
+    logging.info("重置状态")
+    return [], [], [], construct_token_message(0)
+
+
+def reset_textbox():
+    return gr.update(value="")
+
+
+def reset_default():
+    global API_URL
+    API_URL = "https://api.openai.com/v1/chat/completions"
+    os.environ.pop("HTTPS_PROXY", None)
+    os.environ.pop("https_proxy", None)
+    return gr.update(value=API_URL), gr.update(value=""), "API URL 和代理已重置"
+
+
+def change_api_url(url):
+    global API_URL
+    API_URL = url
+    msg = f"API地址更改为了{url}"
+    logging.info(msg)
+    return msg
+
+
+def change_proxy(proxy):
+    os.environ["HTTPS_PROXY"] = proxy
+    msg = f"代理更改为了{proxy}"
+    logging.info(msg)
+    return msg
+
+
+def hide_middle_chars(s):
+    if len(s) <= 8:
+        return s
+    else:
+        head = s[:4]
+        tail = s[-4:]
+        hidden = "*" * (len(s) - 8)
+        return head + hidden + tail
+
+
+def submit_key(key):
+    key = key.strip()
+    msg = f"API密钥更改为了{hide_middle_chars(key)}"
+    logging.info(msg)
+    return key, msg
+
+
+def sha1sum(filename):
+    sha1 = hashlib.sha1()
+    sha1.update(filename.encode("utf-8"))
+    return sha1.hexdigest()
+
+
+def replace_today(prompt):
+    today = datetime.datetime.today().strftime("%Y-%m-%d")
+    return prompt.replace("{current_date}", today)
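Per its own log line, `load_template` has three return modes: mode 0 returns both the template dict and a dropdown update, mode 1 returns only the sorted name list (used to populate the dropdown), and mode 2 returns only the dict. A usage sketch, where the file name is hypothetical:

```python
from bin_public.utils.tools import load_template

templates, dropdown_update = load_template("prompts.json", mode=0)  # dict + gr.Dropdown.update
names = load_template("prompts.json", mode=1)                       # template names, sorted by pinyin
templates = load_template("prompts.json", mode=2)                   # {act: prompt} dict
```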
bin_public/utils/utils_db.py CHANGED
@@ -1,6 +1,5 @@
 import psycopg2
 import datetime
-#from bin_public.config.config import HOLOGRES_CONFIG
 from bin_public.config.presets import *
 from dateutil import tz
 import os