import os

import gradio as gr
import nltk
import sentence_transformers
import torch
from duckduckgo_search import ddg
from duckduckgo_search.utils import SESSION
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import JinaEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS

from chatllm import ChatLLM
from chinese_text_splitter import ChineseTextSplitter
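
# NOTE: the listing references several module-level names that are not defined
# in this excerpt (DEVICE, embedding_model_dict, llm_model_dict). A minimal
# sketch is given below; the model identifiers are assumptions for
# illustration, not values confirmed by the original source.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

embedding_model_dict = {
    # UI label -> Hugging Face model id (assumed)
    "text2vec-base": "shibing624/text2vec-base-chinese",
}

llm_model_dict = {
    # UI label -> Hugging Face model id (assumed)
    "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4",
}
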

def load_files(filepaths):
    # Load every uploaded file and concatenate the resulting documents.
    docs = []
    for filepath in filepaths:
        docs += load_file(filepath)
    return docs
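
# load_file() is called above but not defined in this excerpt. Given the
# UnstructuredFileLoader and ChineseTextSplitter imports, a plausible sketch
# follows; ChineseTextSplitter is a local module, so its constructor arguments
# here are an assumption.
def load_file(filepath):
    loader = UnstructuredFileLoader(filepath, mode="elements")
    textsplitter = ChineseTextSplitter()
    return loader.load_and_split(text_splitter=textsplitter)
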

def init_knowledge_vector_store(embedding_model, filepaths):
    # Build a FAISS index over the uploaded documents with the chosen embedding model.
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_dict[embedding_model])
    embeddings.client = sentence_transformers.SentenceTransformer(
        embeddings.model_name, device=DEVICE)
    docs = load_files(filepaths)
    vector_store = FAISS.from_documents(docs, embeddings)
    return vector_store

def predict(input,
            large_language_model,
            embedding_model,
            file_objs,
            VECTOR_SEARCH_TOP_K,
            history_len,
            temperature,
            top_p,
            use_web,
            history=None):
    # Answer a query against the uploaded knowledge base, optionally augmented
    # with web search results.
    if history is None:
        history = []
    filepaths = [file_obj.name for file_obj in file_objs]
    vector_store = init_knowledge_vector_store(embedding_model, filepaths)
    if use_web == 'True':
        web_content = search_web(query=input)
    else:
        web_content = ''
    resp = get_knowledge_based_answer(
        query=input,
        large_language_model=large_language_model,
        vector_store=vector_store,
        VECTOR_SEARCH_TOP_K=VECTOR_SEARCH_TOP_K,
        web_content=web_content,
        chat_history=history,
        history_len=history_len,
        temperature=temperature,
        top_p=top_p,
    )
    history.append((input, resp['result']))
    return '', history, history
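
# search_web(), get_knowledge_based_answer() and clear_session() are used above
# but not shown in this excerpt. The sketches below are assumptions based on
# the imports (ddg, RetrievalQA, PromptTemplate, ChatLLM); in particular the
# attributes set on ChatLLM (history, temperature, top_p) are guesses about
# that local class, not a documented API.
def search_web(query):
    # Concatenate DuckDuckGo result snippets into one context string.
    results = ddg(query)
    return ''.join(result['body'] for result in results) if results else ''


def get_knowledge_based_answer(query,
                               large_language_model,
                               vector_store,
                               VECTOR_SEARCH_TOP_K,
                               web_content,
                               chat_history=None,
                               history_len=3,
                               temperature=0.01,
                               top_p=0.9):
    # Wrap the local ChatLLM in a RetrievalQA chain over the FAISS store.
    # large_language_model would select which checkpoint ChatLLM loads; that
    # step is omitted in this sketch.
    prompt_template = (
        "基于以下已知信息,简洁和专业地回答用户的问题。\n"
        + (f"已知网络检索内容:{web_content}\n" if web_content else "")
        + "已知内容:\n{context}\n问题:\n{question}")
    prompt = PromptTemplate(template=prompt_template,
                            input_variables=["context", "question"])
    llm = ChatLLM()  # assumed interface of the local ChatLLM class
    llm.history = (chat_history or [])[-history_len:] if history_len > 0 else []
    llm.temperature = temperature
    llm.top_p = top_p
    knowledge_chain = RetrievalQA.from_llm(
        llm=llm,
        retriever=vector_store.as_retriever(
            search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        prompt=prompt)
    knowledge_chain.return_source_documents = True
    return knowledge_chain({"query": query})


def clear_session():
    # Reset the chatbot display and the stored conversation state.
    return '', None
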

if __name__ == "__main__":
    block = gr.Blocks()
    with block as demo:
        gr.Markdown("""<h1><center>Knowledge Based-ChatGLM-DongHeng</center></h1>
        <center><font size=3>
        本软件是Dongheng自用中文大模型软件. <br>
        项目内容是基于GLM-6B大模型对本地上传知识库内容进行应答 <br>
        禁止任何商业用途,谢谢!
        </font></center>
        """)
        with gr.Row():
            with gr.Column(scale=1):
                model_choose = gr.Accordion("模型选择")
                with model_choose:
                    large_language_model = gr.Dropdown(
                        list(llm_model_dict.keys()),
                        label="large language model",
                        value="ChatGLM-6B-int4")
                    embedding_model = gr.Dropdown(
                        list(embedding_model_dict.keys()),
                        label="Embedding model",
                        value="text2vec-base")
                files = gr.Files(label='请上传知识库文件, 目前支持txt、docx、md格式',
                                 file_types=['.txt', '.md', '.docx'])
                use_web = gr.Radio(["True", "False"],
                                   label="Web Search",
                                   value="False")
                model_argument = gr.Accordion("模型参数配置")
                with model_argument:
                    VECTOR_SEARCH_TOP_K = gr.Slider(1,
                                                    10,
                                                    value=6,
                                                    step=1,
                                                    label="vector search top k",
                                                    interactive=True)
                    HISTORY_LEN = gr.Slider(0,
                                            3,
                                            value=0,
                                            step=1,
                                            label="history len",
                                            interactive=True)
                    temperature = gr.Slider(0,
                                            1,
                                            value=0.01,
                                            step=0.01,
                                            label="temperature",
                                            interactive=True)
                    top_p = gr.Slider(0,
                                      1,
                                      value=0.9,
                                      step=0.1,
                                      label="top_p",
                                      interactive=True)
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
                message = gr.Textbox(label='请输入问题')
                state = gr.State()
                with gr.Row():
                    clear_history = gr.Button("🧹 清除历史对话")
                    send = gr.Button("🚀 发送")
                send.click(predict,
                           inputs=[
                               message, large_language_model, embedding_model,
                               files, VECTOR_SEARCH_TOP_K, HISTORY_LEN,
                               temperature, top_p, use_web, state
                           ],
                           outputs=[message, chatbot, state])
                clear_history.click(fn=clear_session,
                                    inputs=[],
                                    outputs=[chatbot, state],
                                    queue=False)
                message.submit(predict,
                               inputs=[
                                   message, large_language_model,
                                   embedding_model, files,
                                   VECTOR_SEARCH_TOP_K, HISTORY_LEN,
                                   temperature, top_p, use_web, state
                               ],
                               outputs=[message, chatbot, state])
gr.Markdown("""提醒:<br> | |
1. 使用时请先上传自己的知识文件,并且文件中不含某些特殊字符,否则将返回error. <br> | |
2. 有任何使用请注意这里有一些关键的改动: | |
1. 我将 `file = gr.File(label='请上传知识库文件, 目前支持txt、docx、md格式', file_types=['.txt', '.md', '.docx'])` 更改为 `files = gr.Files(label='请上传知识库文件, 目前支持txt、docx、md格式', file_types=['.txt', '.md', '.docx'])`。这意味着现在您可以上传多个文件。 | |
2. 在 `send.click` 和 `message.submit` 的 `inputs` 参数中,我将 `file` 改成了 `files`。 | |
这样一来,用户就能够在 Gradio 界面中选择并上传多个文件了。这些文件会被传递给 `predict` 函数,然后被合并在一起,送到模型进行处理。 | |
请注意,由于我并不了解你的全部代码和环境,这个修改可能需要一些额外的调整才能在你的环境中正常运行。我推荐你在实际应用这段代码之前,先在一个安全的环境中进行测试。 | |
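
To check the multi-file behaviour in isolation before wiring it into the full app, a minimal sketch (assuming a Gradio 3.x API in which `gr.Files` hands the callback a list of temp-file objects, each with a `.name` path on disk) could look like this:

```python
import gradio as gr

def show_paths(file_objs):
    # Each uploaded file arrives as a temp-file object; .name is its path on disk.
    return "\n".join(f.name for f in file_objs) if file_objs else "no files uploaded"

with gr.Blocks() as demo:
    files = gr.Files(label="上传文件", file_types=['.txt', '.md', '.docx'])
    paths = gr.Textbox(label="received file paths")
    files.change(show_paths, inputs=[files], outputs=[paths])

demo.launch()
```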