WayneWuDH committed on
Commit 23ae983
1 Parent(s): 3cb43ab

Update app.py

Files changed (1)
  app.py +94 -195
app.py CHANGED
@@ -1,5 +1,4 @@
 import os
-
 import gradio as gr
 import nltk
 import sentence_transformers
@@ -13,140 +12,30 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 from langchain.vectorstores import FAISS
-
 from chatllm import ChatLLM
 from chinese_text_splitter import ChineseTextSplitter
 
-nltk.data.path.append('./nltk_data')
-
-embedding_model_dict = {
-    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
-    "ernie-base": "nghuyong/ernie-3.0-base-zh",
-    "text2vec-base": "GanymedeNil/text2vec-base-chinese",
-    #"ViT-B-32": 'ViT-B-32::laion2b-s34b-b79k'
-}
-
-llm_model_dict = {
-    "ChatGLM-6B-int8": "THUDM/chatglm-6b-int8",
-    "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4",
-    "ChatGLM-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
-    #"Minimax": "Minimax"
-}
-
-DEVICE = "cuda" if torch.cuda.is_available(
-) else "mps" if torch.backends.mps.is_available() else "cpu"
-
-
-def search_web(query):
-
-    SESSION.proxies = {
-        "http": f"socks5h://localhost:7890",
-        "https": f"socks5h://localhost:7890"
-    }
-    results = ddg(query)
-    web_content = ''
-    if results:
-        for result in results:
-            web_content += result['body']
-    return web_content
-
-
-def load_file(filepath):
-    if filepath.lower().endswith(".pdf"):
-        loader = UnstructuredFileLoader(filepath)
-        textsplitter = ChineseTextSplitter(pdf=True)
-        docs = loader.load_and_split(textsplitter)
-    elif filepath.lower().endswith(".xlsx") or filepath.lower().endswith(".xls"):
-        # Read the Excel file into a pandas DataFrame
-        df = pd.read_excel(filepath)
-        # Convert the DataFrame to a list of strings (or however you want to process the data)
-        docs = df.values.tolist()
-    else:
-        loader = UnstructuredFileLoader(filepath, mode="elements")
-        textsplitter = ChineseTextSplitter(pdf=False)
-        docs = loader.load_and_split(text_splitter=textsplitter)
+def load_files(filepaths):
+    docs = []
+    for filepath in filepaths:
+        docs += load_file(filepath)
     return docs
 
-
-def init_knowledge_vector_store(embedding_model, filepath):
-    if embedding_model == "ViT-B-32":
-        jina_auth_token = os.getenv('jina_auth_token')
-        embeddings = JinaEmbeddings(
-            jina_auth_token=jina_auth_token,
-            model_name=embedding_model_dict[embedding_model])
-    else:
-        embeddings = HuggingFaceEmbeddings(
-            model_name=embedding_model_dict[embedding_model], )
-        embeddings.client = sentence_transformers.SentenceTransformer(
-            embeddings.model_name, device=DEVICE)
-
-    docs = load_file(filepath)
+def init_knowledge_vector_store(embedding_model, filepaths):
+    embeddings = HuggingFaceEmbeddings(
+        model_name=embedding_model_dict[embedding_model], )
+    embeddings.client = sentence_transformers.SentenceTransformer(
+        embeddings.model_name, device=DEVICE)
+
+    docs = load_files(filepaths)
 
     vector_store = FAISS.from_documents(docs, embeddings)
     return vector_store
 
-
-def get_knowledge_based_answer(query,
-                               large_language_model,
-                               vector_store,
-                               VECTOR_SEARCH_TOP_K,
-                               web_content,
-                               history_len,
-                               temperature,
-                               top_p,
-                               chat_history=[]):
-    if web_content:
-        prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
-                            如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
-                            已知网络检索内容:{web_content}""" + """
-                            已知内容:
-                            {context}
-                            问题:
-                            {question}"""
-    else:
-        prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。
-                            如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。
-
-                            已知内容:
-                            {context}
-
-                            问题:
-                            {question}"""
-    prompt = PromptTemplate(template=prompt_template,
-                            input_variables=["context", "question"])
-    chatLLM = ChatLLM()
-    chatLLM.history = chat_history[-history_len:] if history_len > 0 else []
-    if large_language_model == "Minimax":
-        chatLLM.model = 'Minimax'
-    else:
-        chatLLM.load_model(
-            model_name_or_path=llm_model_dict[large_language_model])
-    chatLLM.temperature = temperature
-    chatLLM.top_p = top_p
-
-    knowledge_chain = RetrievalQA.from_llm(
-        llm=chatLLM,
-        retriever=vector_store.as_retriever(
-            search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
-        prompt=prompt)
-    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-        input_variables=["page_content"], template="{page_content}")
-
-    knowledge_chain.return_source_documents = True
-
-    result = knowledge_chain({"query": query})
-    return result
-
-
-def clear_session():
-    return '', None
-
-
 def predict(input,
             large_language_model,
             embedding_model,
-            file_obj,
+            file_objs,
             VECTOR_SEARCH_TOP_K,
             history_len,
             temperature,
@@ -155,12 +44,15 @@ def predict(input,
             history=None):
     if history == None:
         history = []
-    print(file_obj.name)
-    vector_store = init_knowledge_vector_store(embedding_model, file_obj.name)
+
+    filepaths = [file_obj.name for file_obj in file_objs]
+    vector_store = init_knowledge_vector_store(embedding_model, filepaths)
+
     if use_web == 'True':
        web_content = search_web(query=input)
     else:
         web_content = ''
+
     resp = get_knowledge_based_answer(
         query=input,
         large_language_model=large_language_model,
@@ -172,8 +64,9 @@ def predict(input,
         temperature=temperature,
         top_p=top_p,
     )
-    print(resp)
+
     history.append((input, resp['result']))
+
     return '', history, history
 
 
@@ -198,79 +91,85 @@ if __name__ == "__main__":
 
         embedding_model = gr.Dropdown(list(
             embedding_model_dict.keys()),
-                                      label="Embedding model",
-                                      value="text2vec-base")
-
-        file = gr.File(label='请上传知识库文件, 目前支持txt、docx、md格式',
-                       file_types=['.txt', '.md', '.docx'])
-
-        use_web = gr.Radio(["True", "False"],
-                           label="Web Search",
-                           value="False")
-        model_argument = gr.Accordion("模型参数配置")
-
-        with model_argument:
-
-            VECTOR_SEARCH_TOP_K = gr.Slider(
-                1,
-                10,
-                value=6,
-                step=1,
-                label="vector search top k",
-                interactive=True)
-
-            HISTORY_LEN = gr.Slider(0,
-                                    3,
-                                    value=0,
-                                    step=1,
-                                    label="history len",
-                                    interactive=True)
-
-            temperature = gr.Slider(0,
-                                    1,
-                                    value=0.01,
-                                    step=0.01,
-                                    label="temperature",
-                                    interactive=True)
-            top_p = gr.Slider(0,
-                              1,
-                              value=0.9,
-                              step=0.1,
-                              label="top_p",
-                              interactive=True)
+                                      label="Embedding model",
+                                      value="text2vec-base")
+
+        files = gr.Files(label='请上传知识库文件, 目前支持txt、docx、md格式',
+                         file_types=['.txt', '.md', '.docx'])
+
+        use_web = gr.Radio(["True", "False"],
+                           label="Web Search",
+                           value="False")
+        model_argument = gr.Accordion("模型参数配置")
+
+        with model_argument:
+
+            VECTOR_SEARCH_TOP_K = gr.Slider(
+                1,
+                10,
+                value=6,
+                step=1,
+                label="vector search top k",
+                interactive=True)
+
+            HISTORY_LEN = gr.Slider(0,
+                                    3,
+                                    value=0,
+                                    step=1,
+                                    label="history len",
+                                    interactive=True)
+
+            temperature = gr.Slider(0,
+                                    1,
+                                    value=0.01,
+                                    step=0.01,
+                                    label="temperature",
+                                    interactive=True)
+            top_p = gr.Slider(0,
+                              1,
+                              value=0.9,
+                              step=0.1,
+                              label="top_p",
+                              interactive=True)
 
     with gr.Column(scale=4):
         chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
         message = gr.Textbox(label='请输入问题')
         state = gr.State()
 
-    with gr.Row():
-        clear_history = gr.Button("🧹 清除历史对话")
-        send = gr.Button("🚀 发送")
-
-    send.click(predict,
-               inputs=[
-                   message, large_language_model,
-                   embedding_model, file, VECTOR_SEARCH_TOP_K,
-                   HISTORY_LEN, temperature, top_p, use_web,
-                   state
-               ],
-               outputs=[message, chatbot, state])
-    clear_history.click(fn=clear_session,
-                        inputs=[],
-                        outputs=[chatbot, state],
-                        queue=False)
-
-    message.submit(predict,
-                   inputs=[
-                       message, large_language_model,
-                       embedding_model, file,
-                       VECTOR_SEARCH_TOP_K, HISTORY_LEN,
-                       temperature, top_p, use_web, state
-                   ],
-                   outputs=[message, chatbot, state])
-    gr.Markdown("""提醒:<br>
-1. 使用时请先上传自己的知识文件,并且文件中不含某些特殊字符,否则将返回error. <br>
-2. 有任何使用问题,请通过[问题交流区](https://huggingface.co/spaces/thomas-yanxin/LangChain-ChatLLM/discussions)或[Github Issue区](https://github.com/thomas-yanxin/LangChain-ChatGLM-Webui/issues)进行反馈. <br>
-""")
-    demo.queue().launch(server_name='0.0.0.0', share=False)
+    with gr.Row():
+        clear_history = gr.Button("🧹 清除历史对话")
+        send = gr.Button("🚀 发送")
+
+    send.click(predict,
+               inputs=[
+                   message, large_language_model,
+                   embedding_model, files, VECTOR_SEARCH_TOP_K,
+                   HISTORY_LEN, temperature, top_p, use_web,
+                   state
+               ],
+               outputs=[message, chatbot, state])
+    clear_history.click(fn=clear_session,
+                        inputs=[],
+                        outputs=[chatbot, state],
+                        queue=False)
+
+    message.submit(predict,
+                   inputs=[
+                       message, large_language_model,
+                       embedding_model, files,
+                       VECTOR_SEARCH_TOP_K, HISTORY_LEN,
+                       temperature, top_p, use_web, state
+                   ],
+                   outputs=[message, chatbot, state])
+    gr.Markdown("""提醒:<br>
+1. 使用时请先上传自己的知识文件,并且文件中不含某些特殊字符,否则将返回error. <br>
+2. 有任何使用问题,请通过[问题交流区](https://huggingface.co/spaces/thomas-yanxin/LangChain-ChatLLM/discussions)或[Github Issue区](https://github.com/thomas-yanxin/LangChain-ChatGLM-Webui/issues)进行反馈. <br>
+""")
+    demo.queue().launch(server_name='0.0.0.0', share=False)
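
One thing to watch in this diff: the new `load_files` helper still calls `load_file`, `init_knowledge_vector_store` still reads `embedding_model_dict` and `DEVICE`, and `predict` still calls `search_web`, `get_knowledge_based_answer`, and `clear_session`, yet all of those definitions are shown as deleted above. For app.py to keep working, those pieces have to survive the edit. As a reference, here is a minimal sketch of the helpers `load_files` depends on, reconstructed from the removed lines above (the `UnstructuredFileLoader` import path is an assumption, and the Excel branch is omitted because it returned raw rows rather than LangChain documents):

```python
import torch
from langchain.document_loaders import UnstructuredFileLoader  # assumed import path

from chinese_text_splitter import ChineseTextSplitter

# Registry and device selection, copied from the lines removed above.
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec-base": "GanymedeNil/text2vec-base-chinese",
}

DEVICE = "cuda" if torch.cuda.is_available(
) else "mps" if torch.backends.mps.is_available() else "cpu"


def load_file(filepath):
    # PDFs go through the PDF-aware Chinese splitter; everything else is
    # loaded in "elements" mode, as in the removed implementation.
    if filepath.lower().endswith(".pdf"):
        loader = UnstructuredFileLoader(filepath)
        textsplitter = ChineseTextSplitter(pdf=True)
        return loader.load_and_split(textsplitter)
    loader = UnstructuredFileLoader(filepath, mode="elements")
    textsplitter = ChineseTextSplitter(pdf=False)
    return loader.load_and_split(text_splitter=textsplitter)
```

The same applies to `llm_model_dict`, `search_web`, `get_knowledge_based_answer`, and `clear_session`, whose removed bodies appear verbatim in the hunks above.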
Note that there are a few key changes here:

1. `file = gr.File(label='请上传知识库文件, 目前支持txt、docx、md格式', file_types=['.txt', '.md', '.docx'])` was changed to `files = gr.Files(label='请上传知识库文件, 目前支持txt、docx、md格式', file_types=['.txt', '.md', '.docx'])`. This means you can now upload multiple files.

2. In the `inputs` arguments of `send.click` and `message.submit`, `file` was changed to `files`.

With these changes, users can select and upload several files in the Gradio interface. The files are passed to the `predict` function, merged into a single list of documents, and then sent on to the model for processing.
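
Here is a minimal, self-contained sketch of that hand-off (a hypothetical demo, not part of app.py): `gr.Files` passes the callback a list of temp-file wrappers, each exposing the path of the uploaded copy via `.name`, which is exactly what the new `filepaths = [file_obj.name for file_obj in file_objs]` line consumes:

```python
import gradio as gr


def show_paths(file_objs):
    # With gr.Files (plural) the callback receives a list of tempfile
    # wrappers; with gr.File it would receive a single object instead.
    if not file_objs:
        return "nothing uploaded yet"
    return "\n".join(file_obj.name for file_obj in file_objs)


with gr.Blocks() as demo:
    files = gr.Files(label="Upload one or more files",
                     file_types=['.txt', '.md', '.docx'])
    paths = gr.Textbox(label="Temp paths the callback sees")
    files.change(show_paths, inputs=[files], outputs=[paths])

if __name__ == "__main__":
    demo.launch()
```

This list-versus-single-object difference is why every handler's `inputs` list had to be updated along with the component itself.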

Finally, note that this change was drafted without full knowledge of the surrounding code and environment, so it may need some extra adjustments to run correctly in your setup (see in particular the note after the diff about helpers the new code still depends on). It is best to test it in a safe environment before applying it in earnest.
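
A sketch of such a dry run, under the assumption that `load_files` (and the helpers it needs) remain importable from app.py: build the FAISS index from a couple of local files and run a plain similarity search, leaving the ChatGLM model out entirely:

```python
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

from app import load_files  # assumes load_file and friends survive in app.py


def smoke_test(filepaths, query, k=3):
    # Build the store the same way init_knowledge_vector_store does,
    # then run a plain similarity search instead of the full QA chain.
    embeddings = HuggingFaceEmbeddings(
        model_name="GanymedeNil/text2vec-base-chinese")
    docs = load_files(filepaths)
    store = FAISS.from_documents(docs, embeddings)
    for doc in store.similarity_search(query, k=k):
        print(doc.page_content[:80])


if __name__ == "__main__":
    smoke_test(["sample1.txt", "sample2.md"], "测试问题")
```

If the retrieved snippets look sensible for a test query, the multi-file upload path is working end to end before the LLM is ever loaded.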