File size: 12,602 Bytes
184a47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2930d96
 
184a47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2930d96
 
184a47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2930d96
 
184a47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2930d96
 
 
 
 
 
 
 
 
 
 
 
184a47b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
# 导入必要的库

import sys
import os  # 用于操作系统相关的操作,例如读取环境变量

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import IPython.display  # 用于在 IPython 环境中显示数据,例如图片
import io  # 用于处理流式数据(例如文件流)
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from llm.call_llm import get_completion
from database.create_db import create_db_info
from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self
from qa_chain.QA_chain_self import QA_chain_self
import re

# 导入 dotenv 库的函数
# dotenv 允许您从 .env 文件中读取环境变量
# 这在开发时特别有用,可以避免将敏感信息(如API密钥)硬编码到代码中

# 寻找 .env 文件并加载它的内容
# 这允许您使用 os.environ 来读取在 .env 文件中设置的环境变量
_ = load_dotenv(find_dotenv())
LLM_MODEL_DICT = {
    # "openai": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-32k"],
    # "wenxin": ["ERNIE-Bot", "ERNIE-Bot-4", "ERNIE-Bot-turbo"],
    # "xinhuo": ["Spark-1.5", "Spark-2.0"],
    "zhipuai": ["chatglm_pro", "chatglm_std", "chatglm_lite"]
}


LLM_MODEL_LIST = sum(list(LLM_MODEL_DICT.values()), [])
INIT_LLM = "chatglm_pro"
# EMBEDDING_MODEL_LIST = ['zhipuai', 'openai', 'm3e']
EMBEDDING_MODEL_LIST = ["zhipuai"]
INIT_EMBEDDING_MODEL = "zhipuai"
DEFAULT_DB_PATH = "./knowledge_db"
DEFAULT_PERSIST_PATH = "./vector_db/chroma"
AIGC_AVATAR_PATH = "./figures/aigc_avatar.png"
DATAWHALE_AVATAR_PATH = "./figures/datawhale_avatar.png"
AIGC_LOGO_PATH = "./figures/aigc_logo.png"
DATAWHALE_LOGO_PATH = "./figures/datawhale_logo.png"


def get_model_by_platform(platform):
    return LLM_MODEL_DICT.get(platform, "")


class Model_center:
    """
    存储问答 Chain 的对象

    - chat_qa_chain_self: 以 (model, embedding) 为键存储的带历史记录的问答链。
    - qa_chain_self: 以 (model, embedding) 为键存储的不带历史记录的问答链。
    """

    def __init__(self):
        self.chat_qa_chain_self = {}
        self.qa_chain_self = {}

    def chat_qa_chain_self_answer(
        self,
        question: str,
        chat_history: list = [],
        model: str = "glm-4",
        embedding: str = "embedding-2",
        temperature: float = 0.0,
        top_k: int = 4,
        history_len: int = 3,
        file_path: str = DEFAULT_DB_PATH,
        persist_path: str = DEFAULT_PERSIST_PATH,
    ):
        """
        调用带历史记录的问答链进行回答
        """
        if question == None or len(question) < 1:
            return "", chat_history
        try:
            if (model, embedding) not in self.chat_qa_chain_self:
                self.chat_qa_chain_self[(model, embedding)] = Chat_QA_chain_self(
                    model=model,
                    temperature=temperature,
                    top_k=top_k,
                    chat_history=chat_history,
                    file_path=file_path,
                    persist_path=persist_path,
                    embedding=embedding,
                )
            chain = self.chat_qa_chain_self[(model, embedding)]
            return "", chain.answer(
                question=question, temperature=temperature, top_k=top_k
            )
        except Exception as e:
            return e, chat_history

    def qa_chain_self_answer(
        self,
        question: str,
        chat_history: list = [],
        model: str = "glm-4",
        embedding="embedding-2",
        temperature: float = 0.0,
        top_k: int = 4,
        file_path: str = DEFAULT_DB_PATH,
        persist_path: str = DEFAULT_PERSIST_PATH,
    ):
        """
        调用不带历史记录的问答链进行回答
        """
        if question == None or len(question) < 1:
            return "", chat_history
        try:
            if (model, embedding) not in self.qa_chain_self:
                self.qa_chain_self[(model, embedding)] = QA_chain_self(
                    model=model,
                    temperature=temperature,
                    top_k=top_k,
                    file_path=file_path,
                    persist_path=persist_path,
                    embedding=embedding,
                )
            chain = self.qa_chain_self[(model, embedding)]
            chat_history.append((question, chain.answer(question, temperature, top_k)))
            return "", chat_history
        except Exception as e:
            return e, chat_history

    def clear_history(self):
        if len(self.chat_qa_chain_self) > 0:
            for chain in self.chat_qa_chain_self.values():
                chain.clear_history()


def format_chat_prompt(message, chat_history):
    """
    该函数用于格式化聊天 prompt。

    参数:
    message: 当前的用户消息。
    chat_history: 聊天历史记录。

    返回:
    prompt: 格式化后的 prompt。
    """
    # 初始化一个空字符串,用于存放格式化后的聊天 prompt。
    prompt = ""
    # 遍历聊天历史记录。
    for turn in chat_history:
        # 从聊天记录中提取用户和机器人的消息。
        user_message, bot_message = turn
        # 更新 prompt,加入用户和机器人的消息。
        prompt = f"{prompt}\nUser: {user_message}\nAssistant: {bot_message}"
    # 将当前的用户消息也加入到 prompt中,并预留一个位置给机器人的回复。
    prompt = f"{prompt}\nUser: {message}\nAssistant:"
    # 返回格式化后的 prompt。
    return prompt


def respond(
    message, chat_history, llm, history_len=3, temperature=0.1, max_tokens=2048
):
    """
    该函数用于生成机器人的回复。

    参数:
    message: 当前的用户消息。
    chat_history: 聊天历史记录。

    返回:
    "": 空字符串表示没有内容需要显示在界面上,可以替换为真正的机器人回复。
    chat_history: 更新后的聊天历史记录
    """
    if message == None or len(message) < 1:
        return "", chat_history
    try:
        # 限制 history 的记忆长度
        chat_history = chat_history[-history_len:] if history_len > 0 else []
        # 调用上面的函数,将用户的消息和聊天历史记录格式化为一个 prompt。
        formatted_prompt = format_chat_prompt(message, chat_history)
        # 使用llm对象的predict方法生成机器人的回复(注意:llm对象在此代码中并未定义)。
        bot_message = get_completion(
            formatted_prompt, llm, temperature=temperature, max_tokens=max_tokens
        )
        # 将bot_message中\n换为<br/>
        bot_message = re.sub(r"\\n", "<br/>", bot_message)
        # 将用户的消息和机器人的回复加入到聊天历史记录中。
        chat_history.append((message, bot_message))
        # 返回一个空字符串和更新后的聊天历史记录(这里的空字符串可以替换为真正的机器人回复,如果需要显示在界面上)。
        return "", chat_history
    except Exception as e:
        return e, chat_history


model_center = Model_center()

block = gr.Blocks()
with block as demo:
    with gr.Row(equal_height=True):
        # gr.Image(value=AIGC_LOGO_PATH, scale=1, min_width=10, show_label=False, show_download_button=False, container=False)

        with gr.Column(scale=2):
            gr.Markdown(
                """<h1><center>大模型应用开发</center></h1>
                <center>LLM-UNIVERSE</center>
                """
            )
        # gr.Image(value=DATAWHALE_LOGO_PATH, scale=1, min_width=10, show_label=False, show_download_button=False, container=False)

    with gr.Row():
        with gr.Column(scale=4):
            # chatbot = gr.Chatbot(height=400, show_copy_button=True, show_share_button=True, avatar_images=(AIGC_AVATAR_PATH, DATAWHALE_AVATAR_PATH))
            chatbot = gr.Chatbot(
                height=400, show_copy_button=True, show_share_button=True
            )
            # 创建一个文本框组件,用于输入 prompt。
            msg = gr.Textbox(label="Prompt/问题")

            with gr.Row():
                # 创建提交按钮。
                db_with_his_btn = gr.Button("Chat db with history")
                db_wo_his_btn = gr.Button("Chat db without history")
                llm_btn = gr.Button("Chat with llm")
            with gr.Row():
                # 创建一个清除按钮,用于清除聊天机器人组件的内容。
                clear = gr.ClearButton(components=[chatbot], value="Clear console")

        with gr.Column(scale=1):
            file = gr.File(
                label="请选择知识库目录",
                file_count="directory",
                file_types=[".txt", ".md", ".docx", ".pdf"],
            )
            with gr.Row():
                init_db = gr.Button("知识库文件向量化")
            model_argument = gr.Accordion("参数配置", open=False)
            with model_argument:
                temperature = gr.Slider(
                    0,
                    1,
                    value=0.01,
                    step=0.01,
                    label="llm temperature",
                    interactive=True,
                )

                top_k = gr.Slider(
                    1,
                    10,
                    value=3,
                    step=1,
                    label="vector db search top k",
                    interactive=True,
                )

                history_len = gr.Slider(
                    0, 5, value=3, step=1, label="history length", interactive=True
                )

            model_select = gr.Accordion("模型选择")
            with model_select:
                llm = gr.Dropdown(
                    LLM_MODEL_LIST,
                    label="large language model",
                    value=INIT_LLM,
                    interactive=True,
                )

                embeddings = gr.Dropdown(
                    EMBEDDING_MODEL_LIST,
                    label="Embedding model",
                    value=INIT_EMBEDDING_MODEL,
                )

        # 设置初始化向量数据库按钮的点击事件。当点击时,调用 create_db_info 函数,并传入用户的文件和希望使用的 Embedding 模型。
        init_db.click(create_db_info, inputs=[file, embeddings], outputs=[msg])

        # 设置按钮的点击事件。当点击时,调用上面定义的 chat_qa_chain_self_answer 函数,并传入用户的消息和聊天历史记录,然后更新文本框和聊天机器人组件。
        db_with_his_btn.click(
            model_center.chat_qa_chain_self_answer,
            inputs=[msg, chatbot, llm, embeddings, temperature, top_k, history_len],
            outputs=[msg, chatbot],
        )
        # 设置按钮的点击事件。当点击时,调用上面定义的 qa_chain_self_answer 函数,并传入用户的消息和聊天历史记录,然后更新文本框和聊天机器人组件。
        db_wo_his_btn.click(
            model_center.qa_chain_self_answer,
            inputs=[msg, chatbot, llm, embeddings, temperature, top_k],
            outputs=[msg, chatbot],
        )
        # 设置按钮的点击事件。当点击时,调用上面定义的 respond 函数,并传入用户的消息和聊天历史记录,然后更新文本框和聊天机器人组件。
        llm_btn.click(
            respond,
            inputs=[msg, chatbot, llm, history_len, temperature],
            outputs=[msg, chatbot],
            show_progress="minimal",
        )

        # 设置文本框的提交事件(即按下Enter键时)。功能与上面的 llm_btn 按钮点击事件相同。
        msg.submit(
            respond,
            inputs=[msg, chatbot, llm, history_len, temperature],
            outputs=[msg, chatbot],
            show_progress="hidden",
        )
        # 点击后清空后端存储的聊天记录
        clear.click(model_center.clear_history)
    gr.Markdown(
        """提醒:<br>
    1. 使用时请先上传自己的知识文件,不然将会解析项目自带的知识库。
    2. 初始化数据库时间可能较长,请耐心等待。
    3. 使用中如果出现异常,将会在文本输入框进行展示,请不要惊慌。 <br>
    """
    )
# threads to consume the request
gr.close_all()
# 启动新的 Gradio 应用,设置分享功能为 True,并使用环境变量 PORT1 指定服务器端口。
# demo.launch(share=True, server_port=int(os.environ['PORT1']))
# 直接启动
demo.launch(share=True)