Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import os
|
2 |
-
|
3 |
import gradio as gr
|
4 |
import nltk
|
5 |
import sentence_transformers
|
@@ -13,140 +12,30 @@ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
|
13 |
from langchain.prompts import PromptTemplate
|
14 |
from langchain.prompts.prompt import PromptTemplate
|
15 |
from langchain.vectorstores import FAISS
|
16 |
-
|
17 |
from chatllm import ChatLLM
|
18 |
from chinese_text_splitter import ChineseTextSplitter
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
"ernie-base": "nghuyong/ernie-3.0-base-zh",
|
25 |
-
"text2vec-base": "GanymedeNil/text2vec-base-chinese",
|
26 |
-
#"ViT-B-32": 'ViT-B-32::laion2b-s34b-b79k'
|
27 |
-
}
|
28 |
-
|
29 |
-
# Maps the model name chosen by the caller to the identifier handed to
# ChatLLM.load_model(model_name_or_path=...).
llm_model_dict = {
    "ChatGLM-6B-int8": "THUDM/chatglm-6b-int8",
    "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4",
    "ChatGLM-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
    # "Minimax": "Minimax"  (entry left disabled, as in the original)
}

# Device preference order: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
|
38 |
-
|
39 |
-
|
40 |
-
def search_web(query, proxy="socks5h://localhost:7890"):
    """Run a web search for *query* and return the result bodies as one string.

    Args:
        query: Search string passed to ``ddg``.
        proxy: Proxy URL installed on ``SESSION`` for both ``http`` and
            ``https`` before searching. The default is the value that used to
            be hard-coded here; pass ``None`` to leave ``SESSION.proxies``
            untouched.

    Returns:
        The ``'body'`` field of every search result concatenated without a
        separator, or ``''`` when the search returns nothing.
    """
    if proxy is not None:
        # SESSION is defined outside this function (presumably the HTTP
        # session used by ``ddg`` -- TODO confirm), so this assignment
        # persists across calls.
        SESSION.proxies = {"http": proxy, "https": proxy}

    results = ddg(query)
    if not results:
        return ''

    # join() builds the string in one pass instead of repeated "+=" copies.
    return ''.join(result['body'] for result in results)
|
52 |
-
|
53 |
-
|
54 |
-
def load_file(filepath):
    """Load one knowledge file and return its content as a list of documents.

    The loader is chosen from the (case-insensitive) file extension:

    * ``.pdf``  -- ``UnstructuredFileLoader`` split with
      ``ChineseTextSplitter(pdf=True)``.
    * ``.xlsx`` / ``.xls`` -- read with ``pandas.read_excel``; every row
      becomes one document.
    * anything else -- ``UnstructuredFileLoader`` in ``"elements"`` mode split
      with ``ChineseTextSplitter(pdf=False)``.

    Args:
        filepath: Path of the file to load.

    Returns:
        A list of document objects suitable for ``FAISS.from_documents``.
    """
    lowered = filepath.lower()

    if lowered.endswith(".pdf"):
        loader = UnstructuredFileLoader(filepath)
        textsplitter = ChineseTextSplitter(pdf=True)
        return loader.load_and_split(textsplitter)

    if lowered.endswith((".xlsx", ".xls")):
        # Imported locally so this block does not depend on a top-of-file
        # import that may not exist.
        from langchain.docstore.document import Document

        df = pd.read_excel(filepath)
        # The previous code returned raw row lists here; the vector store is
        # built with FAISS.from_documents, which needs document objects, so
        # each row is wrapped in a Document whose text is the row's cells
        # joined by spaces.
        return [
            Document(
                page_content=" ".join(str(cell) for cell in row),
                metadata={"source": filepath},
            )
            for row in df.values.tolist()
        ]

    loader = UnstructuredFileLoader(filepath, mode="elements")
    textsplitter = ChineseTextSplitter(pdf=False)
    return loader.load_and_split(text_splitter=textsplitter)
|
69 |
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
-
|
72 |
-
def init_knowledge_vector_store(embedding_model, filepath):
    """Embed the documents of one file and index them in a FAISS store.

    Args:
        embedding_model: Key into ``embedding_model_dict``. The special key
            ``"ViT-B-32"`` selects ``JinaEmbeddings`` (token read from the
            ``jina_auth_token`` environment variable); every other key selects
            ``HuggingFaceEmbeddings``.
        filepath: Path handed to ``load_file``.

    Returns:
        The store produced by ``FAISS.from_documents``.
    """
    model_id = embedding_model_dict[embedding_model]

    # NOTE(review): indentation was lost in the source; the SentenceTransformer
    # client override is assumed to belong to the HuggingFace branch only.
    if embedding_model != "ViT-B-32":
        embedder = HuggingFaceEmbeddings(model_name=model_id)
        # Re-create the underlying client so it runs on the selected DEVICE.
        embedder.client = sentence_transformers.SentenceTransformer(
            embedder.model_name, device=DEVICE)
    else:
        embedder = JinaEmbeddings(
            jina_auth_token=os.getenv('jina_auth_token'),
            model_name=model_id)

    documents = load_file(filepath)
    return FAISS.from_documents(documents, embedder)
|
88 |
|
89 |
-
|
90 |
-
def _build_prompt_template(web_content):
    """Return the QA prompt text, embedding *web_content* when it is non-empty."""
    if web_content:
        # The combined text is parsed as a PromptTemplate afterwards, so any
        # literal brace in the fetched web text must be doubled or it would be
        # read as a template variable.
        safe_web_content = web_content.replace("{", "{{").replace("}", "}}")
        return """基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
已知网络检索内容:""" + safe_web_content + """
已知内容:
{context}
问题:
{question}"""

    return """基于以下已知信息,请简洁并专业地回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。

已知内容:
{context}

问题:
{question}"""


def get_knowledge_based_answer(query,
                               large_language_model,
                               vector_store,
                               VECTOR_SEARCH_TOP_K,
                               web_content,
                               history_len,
                               temperature,
                               top_p,
                               chat_history=None):
    """Answer *query* with a RetrievalQA chain over *vector_store*.

    Args:
        query: The user's question.
        large_language_model: ``"Minimax"`` or a key of ``llm_model_dict``.
        vector_store: Store whose ``as_retriever`` supplies the context.
        VECTOR_SEARCH_TOP_K: Number of documents the retriever returns (``k``).
        web_content: Optional web-search text added to the prompt; falsy
            values select the prompt without a web section.
        history_len: How many trailing ``chat_history`` entries to give the
            model; values ``<= 0`` give it no history.
        temperature: Assigned to ``chatLLM.temperature``.
        top_p: Assigned to ``chatLLM.top_p``.
        chat_history: Previous conversation turns. Defaults to no history
            (``None`` is used instead of a shared mutable ``[]`` default).

    Returns:
        The value returned by calling the chain with ``{"query": query}``;
        source documents are requested via ``return_source_documents``.
    """
    if chat_history is None:
        chat_history = []

    prompt = PromptTemplate(template=_build_prompt_template(web_content),
                            input_variables=["context", "question"])

    chatLLM = ChatLLM()
    chatLLM.history = chat_history[-history_len:] if history_len > 0 else []
    if large_language_model == "Minimax":
        chatLLM.model = 'Minimax'
    else:
        chatLLM.load_model(
            model_name_or_path=llm_model_dict[large_language_model])
    chatLLM.temperature = temperature
    chatLLM.top_p = top_p

    knowledge_chain = RetrievalQA.from_llm(
        llm=chatLLM,
        retriever=vector_store.as_retriever(
            search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        prompt=prompt)
    # Insert each retrieved document into the context as its raw text only.
    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
        input_variables=["page_content"], template="{page_content}")
    knowledge_chain.return_source_documents = True

    return knowledge_chain({"query": query})
|
140 |
-
|
141 |
-
|
142 |
-
def clear_session():
    """Return the pair of reset values used to clear the conversation.

    Returns:
        A 2-tuple ``('', None)``: an empty string followed by ``None``.
    """
    cleared_display = ''
    cleared_state = None
    return cleared_display, cleared_state
|
144 |
-
|
145 |
-
|
146 |
def predict(input,
|
147 |
large_language_model,
|
148 |
embedding_model,
|
149 |
-
|
150 |
VECTOR_SEARCH_TOP_K,
|
151 |
history_len,
|
152 |
temperature,
|
@@ -155,12 +44,15 @@ def predict(input,
|
|
155 |
history=None):
|
156 |
if history == None:
|
157 |
history = []
|
158 |
-
|
159 |
-
|
|
|
|
|
160 |
if use_web == 'True':
|
161 |
web_content = search_web(query=input)
|
162 |
else:
|
163 |
web_content = ''
|
|
|
164 |
resp = get_knowledge_based_answer(
|
165 |
query=input,
|
166 |
large_language_model=large_language_model,
|
@@ -172,8 +64,9 @@ def predict(input,
|
|
172 |
temperature=temperature,
|
173 |
top_p=top_p,
|
174 |
)
|
175 |
-
|
176 |
history.append((input, resp['result']))
|
|
|
177 |
return '', history, history
|
178 |
|
179 |
|
@@ -198,79 +91,85 @@ if __name__ == "__main__":
|
|
198 |
|
199 |
embedding_model = gr.Dropdown(list(
|
200 |
embedding_model_dict.keys()),
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
|
242 |
with gr.Column(scale=4):
|
243 |
chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
|
244 |
message = gr.Textbox(label='请输入问题')
|
245 |
state = gr.State()
|
246 |
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
import gradio as gr
|
3 |
import nltk
|
4 |
import sentence_transformers
|
|
|
12 |
from langchain.prompts import PromptTemplate
|
13 |
from langchain.prompts.prompt import PromptTemplate
|
14 |
from langchain.vectorstores import FAISS
|
|
|
15 |
from chatllm import ChatLLM
|
16 |
from chinese_text_splitter import ChineseTextSplitter
|
17 |
|
18 |
+
def load_files(filepaths):
    """Load several files and return all of their documents in one flat list.

    Args:
        filepaths: Iterable of file paths, each passed to ``load_file``.
            ``None`` is treated as "no files", and a single path given as a
            plain string is treated as a one-element list instead of being
            iterated character by character.

    Returns:
        A list with the documents of every file, in input order. Empty when
        there are no paths.
    """
    if filepaths is None:
        return []
    if isinstance(filepaths, str):
        filepaths = [filepaths]

    docs = []
    for filepath in filepaths:
        docs.extend(load_file(filepath))
    return docs
|
23 |
|
24 |
+
def init_knowledge_vector_store(embedding_model, filepaths):
    """Embed the documents of several files and index them in a FAISS store.

    Args:
        embedding_model: Key into ``embedding_model_dict`` naming the
            HuggingFace embedding model to use.
        filepaths: Paths handed to ``load_files``.

    Returns:
        The store produced by ``FAISS.from_documents``.
    """
    model_id = embedding_model_dict[embedding_model]
    embedder = HuggingFaceEmbeddings(model_name=model_id)
    # Re-create the underlying client so it runs on the selected DEVICE.
    embedder.client = sentence_transformers.SentenceTransformer(
        embedder.model_name, device=DEVICE)

    documents = load_files(filepaths)
    return FAISS.from_documents(documents, embedder)
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
def predict(input,
|
36 |
large_language_model,
|
37 |
embedding_model,
|
38 |
+
file_objs,
|
39 |
VECTOR_SEARCH_TOP_K,
|
40 |
history_len,
|
41 |
temperature,
|
|
|
44 |
history=None):
|
45 |
if history == None:
|
46 |
history = []
|
47 |
+
|
48 |
+
filepaths = [file_obj.name for file_obj in file_objs]
|
49 |
+
vector_store = init_knowledge_vector_store(embedding_model, filepaths)
|
50 |
+
|
51 |
if use_web == 'True':
|
52 |
web_content = search_web(query=input)
|
53 |
else:
|
54 |
web_content = ''
|
55 |
+
|
56 |
resp = get_knowledge_based_answer(
|
57 |
query=input,
|
58 |
large_language_model=large_language_model,
|
|
|
64 |
temperature=temperature,
|
65 |
top_p=top_p,
|
66 |
)
|
67 |
+
|
68 |
history.append((input, resp['result']))
|
69 |
+
|
70 |
return '', history, history
|
71 |
|
72 |
|
|
|
91 |
|
92 |
embedding_model = gr.Dropdown(list(
|
93 |
embedding_model_dict.keys()),
|
94 |
+
label="Embedding model",
|
95 |
+
value="text2vec-base")
|
96 |
+
|
97 |
+
files = gr.Files(label='请上传知识库文件, 目前支持txt、docx、md格式',
|
98 |
+
file_types=['.txt', '.md', '.docx'])
|
99 |
+
|
100 |
+
use_web = gr.Radio(["True", "False"],
|
101 |
+
label="Web Search",
|
102 |
+
value="False")
|
103 |
+
model_argument = gr.Accordion("模型参数配置")
|
104 |
+
|
105 |
+
with model_argument:
|
106 |
+
|
107 |
+
VECTOR_SEARCH_TOP_K = gr.Slider(
|
108 |
+
1,
|
109 |
+
10,
|
110 |
+
value=6,
|
111 |
+
step=1,
|
112 |
+
label="vector search top k",
|
113 |
+
interactive=True)
|
114 |
+
|
115 |
+
HISTORY_LEN = gr.Slider(0,
|
116 |
+
3,
|
117 |
+
value=0,
|
118 |
+
step=1,
|
119 |
+
label="history len",
|
120 |
+
interactive=True)
|
121 |
+
|
122 |
+
temperature = gr.Slider(0,
|
123 |
+
1,
|
124 |
+
value=0.01,
|
125 |
+
step=0.01,
|
126 |
+
label="temperature",
|
127 |
+
interactive=True)
|
128 |
+
top_p = gr.Slider(0,
|
129 |
+
1,
|
130 |
+
value=0.9,
|
131 |
+
step=0.1,
|
132 |
+
label="top_p",
|
133 |
+
interactive=True)
|
134 |
|
135 |
with gr.Column(scale=4):
|
136 |
chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
|
137 |
message = gr.Textbox(label='请输入问题')
|
138 |
state = gr.State()
|
139 |
|
140 |
+
with gr.Row():
|
141 |
+
clear_history = gr.Button("🧹 清除历史对话")
|
142 |
+
send = gr.Button("🚀 发送")
|
143 |
+
|
144 |
+
send.click(predict,
|
145 |
+
inputs=[
|
146 |
+
message, large_language_model,
|
147 |
+
embedding_model, files, VECTOR_SEARCH_TOP_K,
|
148 |
+
HISTORY_LEN, temperature, top_p, use_web,
|
149 |
+
state
|
150 |
+
],
|
151 |
+
outputs=[message, chatbot, state])
|
152 |
+
clear_history.click(fn=clear_session,
|
153 |
+
inputs=[],
|
154 |
+
outputs=[chatbot, state],
|
155 |
+
queue=False)
|
156 |
+
|
157 |
+
message.submit(predict,
|
158 |
+
inputs=[
|
159 |
+
message, large_language_model,
|
160 |
+
embedding_model, files,
|
161 |
+
VECTOR_SEARCH_TOP_K, HISTORY_LEN,
|
162 |
+
temperature, top_p, use_web, state
|
163 |
+
],
|
164 |
+
outputs=[message, chatbot, state])
|
165 |
+
gr.Markdown("""提醒:<br>
|
166 |
+
1. 使用时请先上传自己的知识文件,并且文件中不含某些特殊字符,否则将返回error. <br>
|
167 |
+
2. 有任何使用请注意这里有一些关键的改动:
|
168 |
+
|
169 |
+
1. 我将 `file = gr.File(label='请上传知识库文件, 目前支持txt、docx、md格式', file_types=['.txt', '.md', '.docx'])` 更改为 `files = gr.Files(label='请上传知识库文件, 目前支持txt、docx、md格式', file_types=['.txt', '.md', '.docx'])`。这意味着现在您可以上传多个文件。
|
170 |
+
|
171 |
+
2. 在 `send.click` 和 `message.submit` 的 `inputs` 参数中,我将 `file` 改成了 `files`。
|
172 |
+
|
173 |
+
这样一来,用户就能够在 Gradio 界面中选择并上传多个文件了。这些文件会被传递给 `predict` 函数,然后被合并在一起,送到模型进行处理。
|
174 |
+
|
175 |
+
请注意,由于我并不了解你的全部代码和环境,这个修改可能需要一些额外的调整才能在你的环境中正常运行。我推荐你在实际应用这段代码之前,先在一个安全的环境中进行测试。
|