allinaigc commited on
Commit
9b79aec
1 Parent(s): b244fe4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +402 -0
app.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 1. 完成了用Qwen通义千问作为知识库查询。
3
+ 1. 总共有三个区块:知识库回答,应用来源,相关问题。
4
+ 1. 在Huggingface的API上部署了一个在线BGE的模型,用于回答问题。OpenAI的Emebedding或者Langchain的Embedding都不可以用(会报错: self.d)。
5
+
6
+ 注意事项:
7
+ 1. langchain_KB.py中的代码是用来构建本地知识库的,里面的embeddings需要与rag_response_002.py中的embeddings一致。否则会出错!
8
+ 1. 如果报错sentence_transformer, 主要原因是与matlabplot等各种package的兼容性冲突。目前几个核心python文件的中的package不会冲突,可以看一下。
9
+
10
+ """
11
+
12
+ ##TODO:
13
+
14
+ # -*- coding: utf-8 -*-
15
+ import streamlit as st
16
+ import openai
17
+ import os
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ import csv
22
+ import tempfile
23
+ from tempfile import NamedTemporaryFile
24
+ import pathlib
25
+ from pathlib import Path
26
+ import re
27
+ from re import sub
28
+ from itertools import product
29
+ import time
30
+ from time import sleep
31
+ import streamlit_authenticator as stauth
32
+ from langchain_community.vectorstores import FAISS
33
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
34
+ from langchain_core.output_parsers import StrOutputParser
35
+ from langchain_core.runnables import RunnablePassthrough
36
+ from langchain.llms.base import LLM
37
+ from langchain.llms.utils import enforce_stop_tokens
38
+ from typing import Dict, List, Optional, Tuple, Union
39
+ import requests
40
+ import streamlit as st
41
+ import qwen_response
42
+ import rag_reponse_002
43
+ import dashscope
44
+ from dotenv import load_dotenv
45
+ from datetime import datetime
46
+ import pytz
47
+ from pytz import timezone
48
+
49
+ # def get_current_time():
50
+ # beijing_tz = timezone('Asia/Shanghai')
51
+ # beijing_time = datetime.now(beijing_tz)
52
+ # current_time = beijing_time.strftime('%H:%M:%S')
53
+ # return current_time
54
+
55
+ load_dotenv()
56
+ ### 设置openai的API key
57
+ os.environ["OPENAI_API_KEY"] = os.environ['user_token']
58
+ openai.api_key = os.environ['user_token']
59
+ bing_search_api_key = os.environ['bing_api_key']
60
+ dashscope.api_key = os.environ['dashscope_api_key']
61
+
62
+
63
+ ### Streamlit页面设定。
64
+ st.set_page_config(layout="wide", page_icon="🚀", page_title="本地化国产大模型知识库查询演示")
65
+ st.title("本地化国产大模型知识库查询演示")
66
+ # st.title("大语言模型智能知识库查询中心")
67
+ # st.title("大语言模型本地知识库问答系统")
68
+ # st.subheader("Large Language Model-based Knowledge Base QA System")
69
+ # st.warning("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
70
+ st.caption("_声明:内容由人工智能生成,仅供参考。您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
71
+ # st.caption("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
72
+ # st.info("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
73
+ # st.divider()
74
+
75
+
76
+
77
+ ### upload file
78
+ # username = 'test'
79
+ # path = f'./{username}/faiss_index/index.faiss'
80
+ # if os.path.exists(path):
81
+ # print(f'{path} local KB exists')
82
+ # database_info = pd.read_csv(f'./{username}/database_name.csv')
83
+ # current_database_name = database_info.iloc[-1][0]
84
+ # current_database_date = database_info.iloc[-1][1]
85
+ # database_claim = f"当前知识库为:{current_database_name},创建于{current_database_date}。可以开始提问!"
86
+ # st.markdown(database_claim)
87
+
88
+ # uploaded_file = st.file_uploader(
89
+ # "选择上传一个新知识库", type=(["pdf"]))
90
+ # # 默认状态下没有上传文件,None,会报错。需要判断。
91
+ # if uploaded_file is not None:
92
+ # # uploaded_file_path = upload_file(uploaded_file)
93
+ # upload_file(uploaded_file)
94
+
95
+
96
+ # # ## 创建向量数据库
97
+ # from langchain.embeddings.openai import OpenAIEmbeddings
98
+ # embeddings = OpenAIEmbeddings(disallowed_special=()) ## 这里是联网情况下,部署在Huggingface上后使用。
99
+ # print('embeddings:', embeddings)
100
+
101
+ # embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
102
+ # # embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name) ## 这里是联网情况下连接huggingface后使用。
103
+ # embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/RAG/bge-large-zh') ## 切换成BGE的embedding。
104
+ # embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/RAG/bge-large-zh/') ## 切换成BGE的embedding。
105
+ # embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/chatGLM/My_LocalKB_Project/GanymedeNil_text2vec-large-chinese/') ## 这里会有个“No sentence-transformers model found with name“的warning,但不是error,不影响使用。
106
+
107
+
108
+
109
+ ### authentication with a local yaml file.
110
+ import yaml
111
+ from yaml.loader import SafeLoader
112
+ with open('./config.yaml') as file:
113
+ config = yaml.load(file, Loader=SafeLoader)
114
+ authenticator = stauth.Authenticate(
115
+ config['credentials'],
116
+ config['cookie']['name'],
117
+ config['cookie']['key'],
118
+ config['cookie']['expiry_days'],
119
+ config['preauthorized']
120
+ )
121
+
122
+ user, authentication_status, username = authenticator.login('用户登录', 'main')
123
+
124
+ if authentication_status:
125
+ with st.sidebar:
126
+ st.markdown(
127
+ """
128
+ <style>
129
+ [data-testid="stSidebar"][aria-expanded="true"]{
130
+ min-width: 450px;
131
+ max-width: 450px;
132
+ }
133
+ """,
134
+ unsafe_allow_html=True,
135
+ )
136
+ ### siderbar的题目。
137
+ ### siderbar的题目。
138
+ # st.header(f'**大语言模型专家系统工作设定区**')
139
+ st.header(f'**欢迎 **{username}** 使用本系统** ')
140
+ st.write(f'_Large Language Model Expert System Working Environment_')
141
+ # st.write(f'_Welcome and Hope U Enjoy Staying Here_')
142
+ authenticator.logout('登出', 'sidebar')
143
+
144
+ ### upload模块
145
+ def upload_file(uploaded_file):
146
+ if uploaded_file is not None:
147
+ # filename = uploaded_file.name
148
+ # st.write(filename) # print out the whole file name to validate. not to show in the final version.
149
+ try:
150
+ # if '.pdf' in filename: ### original code here.
151
+ if '.pdf' in uploaded_file.name:
152
+ pdf_filename = uploaded_file.name ### original code here.
153
+ filename = uploaded_file.name
154
+ # print('PDF file:', pdf_filename)
155
+ # with st.status('正在为您解析新知识库...', expanded=False, state='running') as status:
156
+ spinner = st.spinner('正在为您解析新知识库...请耐心等待')
157
+ with spinner:
158
+ ### 以下是langchain方案。
159
+ import langchain_KB
160
+ import save_database_info
161
+
162
+ uploaded_file_name = "File_provided"
163
+ temp_dir = tempfile.TemporaryDirectory()
164
+ # ! working.
165
+ uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
166
+ with open(pdf_filename, 'wb') as output_temporary_file:
167
+ # with open(f'./{username}_upload.pdf', 'wb') as output_temporary_file: ### original code here. 可能会造成在引用信息来源时文件名不对的问题。
168
+ # ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
169
+ # output_temporary_file.write(uploaded_file.getvalue())
170
+ output_temporary_file.write(uploaded_file.getvalue())
171
+
172
+ langchain_KB.langchain_localKB_construct(output_temporary_file, username)
173
+ ## 在屏幕上展示当前知识库的信息,包括名字和加载日期。
174
+ save_database_info.save_database_info(f'./{username}/database_name.csv', pdf_filename, str(datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%d %H:%M")))
175
+ st.markdown('新知识库解析成功,请务必刷新页面,然后开启对话 🔃')
176
+
177
+ return pdf_filename
178
+
179
+ else:
180
+ # if '.csv' in filename: ### original code here.
181
+ if '.csv' in uploaded_file.name:
182
+ print('start the csv file processing...')
183
+ csv_filename = uploaded_file.name
184
+ filename = uploaded_file.name
185
+
186
+ csv_file = pd.read_csv(uploaded_file)
187
+ csv_file.to_csv(f'./{username}/{username}_upload.csv', encoding='utf-8', index=False)
188
+ st.write(csv_file[:3]) # 这里只是显示文件,后面需要定位文件所在的绝对路径。
189
+ else:
190
+ xls_file = pd.read_excel(uploaded_file)
191
+ xls_file.to_csv(f'./{username}_upload.csv', index=False)
192
+ st.write(xls_file[:3])
193
+
194
+ print('end the csv file processing...')
195
+
196
+ # uploaded_file_name = "File_provided"
197
+ # temp_dir = tempfile.TemporaryDirectory()
198
+ # ! working.
199
+ # uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
200
+ # with open('./upload.csv', 'wb') as output_temporary_file:
201
+ # with open(f'./{username}_upload.csv', 'wb') as output_temporary_file:
202
+ # print(f'./{name}_upload.csv')
203
+ # ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
204
+ # output_temporary_file.write(uploaded_file.getvalue())
205
+ # st.write(uploaded_file_path) #* 可以查看文件是否真实存在,然后是否可以
206
+
207
+ except Exception as e:
208
+ st.write(e)
209
+
210
+ ## 以下代码是为了解决上传文件后,文件路径和文件名不对的问题。
211
+ # uploaded_file_name = "File_provided"
212
+ # temp_dir = tempfile.TemporaryDirectory()
213
+ # # ! working.
214
+ # uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
215
+ # # with open('./upload.csv', 'wb') as output_temporary_file:
216
+ # with open(f'./{name}_upload.csv', 'wb') as output_temporary_file:
217
+ # # print(f'./{name}_upload.csv')
218
+ # # ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
219
+ # # output_temporary_file.write(uploaded_file.getvalue())
220
+ # output_temporary_file.write(uploaded_file.getvalue())
221
+ # # st.write(uploaded_file_path) # * 可以查看文件是否真实存在,然后是否可以
222
+ # # st.write('Now file saved successfully.')
223
+
224
+ # return pdf_filename, csv_filename
225
+ return filename
226
+
227
+ path = f'./{username}/faiss_index/index.faiss'
228
+ if os.path.exists(path):
229
+ print(f'{path} local KB exists')
230
+ database_info = pd.read_csv(f'./{username}/database_name.csv', encoding='utf-8', header=None) ## 不加encoding的话,中文名字的PDF会报错。
231
+ print(database_info)
232
+ current_database_name = database_info.iloc[-1][0]
233
+ current_database_date = database_info.iloc[-1][1]
234
+ database_claim = f"当前知识库为:{current_database_name},创建于{current_database_date}。可以开始提问!"
235
+ st.warning(database_claim)
236
+ # st.markdown(database_claim)
237
+
238
+ try:
239
+ uploaded_file = st.file_uploader(
240
+ "选择上传一个新知识库", type=(["pdf"]))
241
+ # 默认状态下没有上传文件,None,会报错。需要判断。
242
+ if uploaded_file is not None:
243
+ # uploaded_file_path = upload_file(uploaded_file)
244
+ upload_file(uploaded_file)
245
+ except Exception as e:
246
+ print(e)
247
+ pass
248
+
249
+ ## 在sidebar上的三个分页显示,用st.tabs实现。
250
+ tab_1, tab_2, tab_3, tab_4 = st.tabs(['使用须知', '模型参数', '提示词模板', '系统角色设定'])
251
+
252
+ # with st.expander(label='**使用须知**', expanded=False):
253
+ with tab_1:
254
+ # st.markdown("#### 快速上手指南")
255
+ # with st.text(body="说明"):
256
+ # st.markdown("* 重启一轮新对话时,只需要刷新页面(按Ctrl/Command + R)即可。")
257
+ with st.text(body="说明"):
258
+ st.markdown("* 为了保护数据与隐私,所有对话均不会被保存,刷新页面立即删除。敬请放心。")
259
+ # with st.text(body="说明"):
260
+ # st.markdown("* “GPT-4”回答质量极佳,但速度缓慢,建议适当使用。")
261
+ with st.text(body="说明"):
262
+ st.markdown("* 查询知识库模式与所有的搜索引擎或者数据库检索方式一样,仅限一轮对话,将不会保持之前的会话记录。")
263
+ with st.text(body="说明"):
264
+ st.markdown("""* 系统的工作流程如下:
265
+ 1. 用户输入问题。
266
+ 1. 系统将问题转换为机器可理解的格式。
267
+ 1. 系统使用大语言模型来生成与问题相关的候选答案。
268
+ 1. 系统使用本地知识库来评估候选答案的准确性。
269
+ 1. 系统返回最准确的答案。""")
270
+
271
+ ## 大模型参数
272
+ # with st.expander(label='**大语言模型参数**', expanded=True):
273
+ with tab_2:
274
+ max_tokens = st.slider(label='Max_Token(生成结果时最大字数)', min_value=100, max_value=8096, value=4096,step=100)
275
+ temperature = st.slider(label='Temperature (温度)', min_value=0.0, max_value=1.0, value=0.8, step=0.1)
276
+ top_p = st.slider(label='Top_P (核采样)', min_value=0.0, max_value=1.0, value=0.6, step=0.1)
277
+ frequency_penalty = st.slider(label='Frequency Penalty (重复度惩罚因子)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
278
+ presence_penalty = st.slider(label='Presence Penalty (控制主题的重复度)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
279
+
280
+ with tab_3:
281
+ # st.markdown("#### Prompt提示词参考资料")
282
+ # with st.expander(label="**大语言模型基础提示词Prompt示例**", expanded=False):
283
+ st.code(
284
+ body="我是一个企业主,我需要关注哪些“存货”相关的数据资源规则?", language='plaintext')
285
+ st.code(
286
+ body="作为零售商,了解哪些关键的库存管理指标对我至关重要?", language='plaintext')
287
+ st.code(body="企业主在监控库存时,应如何确保遵守行业法规和最佳实践?",
288
+ language='plaintext')
289
+ st.code(body="在数字化时代,我应该关注哪些技术工具或平台来优化我的库存数据流程?", language='plaintext')
290
+ st.code(body="我应该如何定期审查和分析这些库存数据以提高运营效率?", language='plaintext')
291
+ st.code(body="如何设置预警系统来避免过度库存或缺货情况?", language='plaintext')
292
+
293
+ with tab_4:
294
+ st.text_area(label='系统角色设定', value='你是一个人工智能,你需要回答我提出的问题,或者完成我交代的任务。你需要使用我提问的语言(如中文、英文)来回答。', height=200, label_visibility='hidden')
295
+
296
+
297
+ elif authentication_status == False:
298
+ st.error('⛔ 用户名或密码错误!')
299
+ elif authentication_status == None:
300
+ st.warning('⬆ 请先登录!')
301
+
302
+
303
+ ### 上传文件的模块
304
+
305
+
306
+
307
+
308
+ #### start: 主程序
309
+ ## 清楚所有对话记录。
310
+ def clear_all():
311
+ st.session_state.conversation = None
312
+ st.session_state.chat_history = None
313
+ st.session_state.messages = []
314
+ message_placeholder = st.empty()
315
+
316
+ ## 只用这一个就可以了。
317
+ st.rerun()
318
+
319
+ return None
320
+
321
+ if "copied" not in st.session_state:
322
+ st.session_state.copied = []
323
+
324
+ if "llm_response" not in st.session_state:
325
+ st.session_state.llm_response = ""
326
+
327
+ # ## copy to clipboard function with a button.
328
+ # def copy_to_clipboard(text):
329
+ # st.session_state.copied.append(text)
330
+ # clipboard.copy(text)
331
+
332
+
333
+ def main():
334
+ # llm = ChatGLM() ## 启动一个实例。
335
+ col1, col2 = st.columns([2, 1])
336
+ # st.markdown('### 数据库查询区')
337
+ # with st.expander(label='**查询企业内部知识库**', expanded=True):
338
+ with col1:
339
+ KB_mode = True
340
+ user_input = st.text_input(label='**📶 大模型数据库对话区**', placeholder='请输入您的问题', label_visibility='visible')
341
+ if user_input:
342
+ ## 非stream输出,原始状态,不需要改变api.py中的内容。
343
+ # with st.status('检索中...', expanded=True, state='running') as status:
344
+ spinner = st.spinner('思考中...请耐心等待')
345
+ with spinner:
346
+ if KB_mode == True:
347
+ # import rag_reponse_001
348
+ # clear_all()
349
+ # response = rag_reponse_001.rag_response(user_input=user_input, k=5) ## working.
350
+ # print('user_input:', user_input)
351
+ response, source = rag_reponse_002.rag_response(username=username, user_input=user_input, k=3)
352
+ print('llm response:', response)
353
+ sim_prompt = f"""你需要根据以下的问题来提出5个可能的后续问题{user_input}
354
+ """
355
+ # sim_questions = chatgpt.chatgpt(user_prompt=sim_prompt) ## chatgpt to get similar questions.
356
+ sim_questions = qwen_response.call_with_messages(sim_prompt)
357
+ if len(user_input) != 0:
358
+ sim_prompt = f"""你需要根据以下的初始问题来提出3个相似的问题和3个后续问题。
359
+ 初始问题是:{user_input}
360
+ --------------------
361
+ 你回答的时候,需要使用如下格式:
362
+ **相似问题:**
363
+ **后续问题:**
364
+
365
+ """
366
+ # sim_prompt = f"""你需要根据以下的问题来提出5个可能的后续问题{user_input}"""
367
+
368
+ ### 这里用chatgpt来生成相似问题。
369
+ # sim_questions = chatgpt.chatgpt(user_prompt=sim_prompt)
370
+
371
+ ### 这里用Qwen来生成相似问题。
372
+ sim_questions = qwen_response.call_with_messages(sim_prompt)
373
+
374
+
375
+ st.markdown(response)
376
+ # st_copy_to_clipboard(text=str(response), show_text=True, before_copy_label="📋", after_copy_label="✅")
377
+
378
+ ## 如果这样使用,每次按button都会重新提交问题。
379
+ # st.button(label="📃", on_click=copy_to_clipboard, args=(response,))
380
+
381
+ st.divider()
382
+ st.caption(source)
383
+ st.divider()
384
+
385
+ ## 初始状态下response未被定义。
386
+ try:
387
+ if response:
388
+ with col2:
389
+ with st.expander(label='## **您可能还会关注以下内容**', expanded=True):
390
+ st.info(sim_questions)
391
+ except:
392
+ pass
393
+
394
+ # st.stop()
395
+
396
+ return None
397
+
398
+ #### End: 主程序
399
+
400
+ if __name__ == '__main__':
401
+ main()
402
+