import os
import gradio as gr
import nltk
import sentence_transformers
import torch
from duckduckgo_search import ddg
from duckduckgo_search.utils import SESSION
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import JinaEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from chatllm import ChatLLM
from chinese_text_splitter import ChineseTextSplitter
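
# --- Hedged assumptions ----------------------------------------------------
# The snippet references DEVICE, embedding_model_dict and llm_model_dict but
# never defines them. The values below are a minimal sketch; swap in the
# checkpoints you actually use.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# UI name -> HuggingFace model id (assumed defaults)
embedding_model_dict = {
    "text2vec-base": "GanymedeNil/text2vec-base-chinese",
}

llm_model_dict = {
    "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4",
}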

def load_files(filepaths):
    docs = []
    for filepath in filepaths:
        docs += load_file(filepath)
    return docs
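
def load_file(filepath):
    # Hedged sketch of the missing single-file loader: parse the document with
    # UnstructuredFileLoader and split it with ChineseTextSplitter; the exact
    # splitter arguments used in the original project are unknown.
    loader = UnstructuredFileLoader(filepath, mode="elements")
    return loader.load_and_split(text_splitter=ChineseTextSplitter())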

def init_knowledge_vector_store(embedding_model, filepaths):
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_dict[embedding_model])
    embeddings.client = sentence_transformers.SentenceTransformer(
        embeddings.model_name, device=DEVICE)

    docs = load_files(filepaths)

    vector_store = FAISS.from_documents(docs, embeddings)
    return vector_store
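
def search_web(query):
    # Hedged sketch of the missing web-search helper: query DuckDuckGo through
    # the legacy `ddg` function and concatenate the result snippets. If you
    # need a proxy, configure it on the imported SESSION object first.
    results = ddg(query)
    web_content = ''
    if results:
        for result in results:
            web_content += result['body']
    return web_content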

def predict(input,
            large_language_model,
            embedding_model,
            file_objs,
            VECTOR_SEARCH_TOP_K,
            history_len,
            temperature,
            top_p,
            use_web,
            history=None):
    if history is None:
        history = []

    filepaths = [file_obj.name for file_obj in file_objs]
    vector_store = init_knowledge_vector_store(embedding_model, filepaths)

    if use_web == 'True':
        web_content = search_web(query=input)
    else:
        web_content = ''

    resp = get_knowledge_based_answer(
        query=input,
        large_language_model=large_language_model,
        vector_store=vector_store,
        VECTOR_SEARCH_TOP_K=VECTOR_SEARCH_TOP_K,
        web_content=web_content,
        chat_history=history,
        history_len=history_len,
        temperature=temperature,
        top_p=top_p,
    )

    history.append((input, resp['result']))

    return '', history, history
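
def get_knowledge_based_answer(query,
                               large_language_model,
                               vector_store,
                               VECTOR_SEARCH_TOP_K,
                               web_content,
                               chat_history=None,
                               history_len=3,
                               temperature=0.01,
                               top_p=0.9):
    # Hedged sketch of the missing QA routine: wrap ChatLLM in a RetrievalQA
    # chain over the FAISS store. The ChatLLM attribute/method names used here
    # (load_model, history, temperature, top_p) are assumptions inferred from
    # how `predict` calls this function and from similar ChatGLM webui code.
    if chat_history is None:
        chat_history = []
    context_blocks = "已知内容:\n{context}"
    if web_content:
        context_blocks += "\n网络检索内容:\n" + web_content
    prompt_template = ("基于以下已知信息,简洁和专业地回答用户的问题。\n" +
                       context_blocks + "\n问题:\n{question}")
    prompt = PromptTemplate(template=prompt_template,
                            input_variables=["context", "question"])

    chat_llm = ChatLLM()
    chat_llm.load_model(model_name_or_path=llm_model_dict[large_language_model])
    chat_llm.temperature = temperature
    chat_llm.top_p = top_p
    chat_llm.history = chat_history[-history_len:] if history_len > 0 else []

    knowledge_chain = RetrievalQA.from_llm(
        llm=chat_llm,
        retriever=vector_store.as_retriever(
            search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        prompt=prompt)
    knowledge_chain.return_source_documents = True
    # `predict` reads resp['result'], the default RetrievalQA output key.
    return knowledge_chain({"query": query})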

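def clear_session():
    # Hedged sketch of the missing clear handler: reset the chatbot widget and
    # the conversation state (outputs=[chatbot, state] in clear_history.click).
    return None, None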

if __name__ == "__main__":
    block = gr.Blocks()
    with block as demo:
        gr.Markdown("""<h1><center>Knowledge Based-ChatGLM-DongHeng</center></h1>
        <center><font size=3>
        本软件是Dongheng自用中文大模型软件. <br>
        项目内容是基于GLM-6B大模型对本地上传知识库内容进行应答 <br>
        禁止任何商业用途,谢谢!
        </font></center>
        """)
        with gr.Row():
            with gr.Column(scale=1):
                model_choose = gr.Accordion("模型选择")
                with model_choose:
                    large_language_model = gr.Dropdown(
                        list(llm_model_dict.keys()),
                        label="large language model",
                        value="ChatGLM-6B-int4")

                    embedding_model = gr.Dropdown(list(
                        embedding_model_dict.keys()),
                        label="Embedding model",
                        value="text2vec-base")

                    files = gr.Files(label='请上传知识库文件, 目前支持txt、docx、md格式',
                                     file_types=['.txt', '.md', '.docx'])

                    use_web = gr.Radio(["True", "False"],
                                       label="Web Search",
                                       value="False")
                    model_argument = gr.Accordion("模型参数配置")

                    with model_argument:

                        VECTOR_SEARCH_TOP_K = gr.Slider(
                            1,
                            10,
                            value=6,
                            step=1,
                            label="vector search top k",
                            interactive=True)

                        HISTORY_LEN = gr.Slider(0,
                                                3,
                                                value=0,
                                                step=1,
                                                label="history len",
                                                interactive=True)

                        temperature = gr.Slider(0,
                                                1,
                                                value=0.01,
                                                step=0.01,
                                                label="temperature",
                                                interactive=True)
                        top_p = gr.Slider(0,
                                          1,
                                          value=0.9,
                                          step=0.1,
                                          label="top_p",
                                          interactive=True)

            with gr.Column(scale=4):
                chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
                message = gr.Textbox(label='请输入问题')
                state = gr.State()

            with gr.Row():
                clear_history = gr.Button("🧹 清除历史对话")
                send = gr.Button("🚀 发送")

            send.click(predict,
                       inputs=[
                           message, large_language_model,
                           embedding_model, files, VECTOR_SEARCH_TOP_K,
                           HISTORY_LEN, temperature, top_p, use_web,
                           state
                       ],
                       outputs=[message, chatbot, state])
            clear_history.click(fn=clear_session,
                                inputs=[],
                                outputs=[chatbot, state],
                                queue=False)

            message.submit(predict,
                           inputs=[
                               message, large_language_model,
                               embedding_model, files,
                               VECTOR_SEARCH_TOP_K, HISTORY_LEN,
                               temperature, top_p, use_web, state
                           ],
                           outputs=[message, chatbot, state])
            gr.Markdown("""提醒:<br>
            1. 使用时请先上传自己的知识文件,并且文件中不含某些特殊字符,否则将返回error. <br>
            2. 有任何使用请注意这里有一些关键的改动:

1. 我将 `file = gr.File(label='请上传知识库文件, 目前支持txt、docx、md格式', file_types=['.txt', '.md', '.docx'])` 更改为 `files = gr.Files(label='请上传知识库文件, 目前支持txt、docx、md格式', file_types=['.txt', '.md', '.docx'])`。这意味着现在您可以上传多个文件。

2. 在 `send.click` 和 `message.submit` 的 `inputs` 参数中,我将 `file` 改成了 `files`。

这样一来,用户就能够在 Gradio 界面中选择并上传多个文件了。这些文件会被传递给 `predict` 函数,然后被合并在一起,送到模型进行处理。

请注意,由于我并不了解你的全部代码和环境,这个修改可能需要一些额外的调整才能在你的环境中正常运行。我推荐你在实际应用这段代码之前,先在一个安全的环境中进行测试。