Spaces:
Running
Running
File size: 12,602 Bytes
184a47b 2930d96 184a47b 2930d96 184a47b 2930d96 184a47b 2930d96 184a47b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 |
# 导入必要的库
import sys
import os # 用于操作系统相关的操作,例如读取环境变量
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import IPython.display # 用于在 IPython 环境中显示数据,例如图片
import io # 用于处理流式数据(例如文件流)
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from llm.call_llm import get_completion
from database.create_db import create_db_info
from qa_chain.Chat_QA_chain_self import Chat_QA_chain_self
from qa_chain.QA_chain_self import QA_chain_self
import re
# 导入 dotenv 库的函数
# dotenv 允许您从 .env 文件中读取环境变量
# 这在开发时特别有用,可以避免将敏感信息(如API密钥)硬编码到代码中
# 寻找 .env 文件并加载它的内容
# 这允许您使用 os.environ 来读取在 .env 文件中设置的环境变量
_ = load_dotenv(find_dotenv())
LLM_MODEL_DICT = {
# "openai": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-32k"],
# "wenxin": ["ERNIE-Bot", "ERNIE-Bot-4", "ERNIE-Bot-turbo"],
# "xinhuo": ["Spark-1.5", "Spark-2.0"],
"zhipuai": ["chatglm_pro", "chatglm_std", "chatglm_lite"]
}
LLM_MODEL_LIST = sum(list(LLM_MODEL_DICT.values()), [])
INIT_LLM = "chatglm_pro"
# EMBEDDING_MODEL_LIST = ['zhipuai', 'openai', 'm3e']
EMBEDDING_MODEL_LIST = ["zhipuai"]
INIT_EMBEDDING_MODEL = "zhipuai"
DEFAULT_DB_PATH = "./knowledge_db"
DEFAULT_PERSIST_PATH = "./vector_db/chroma"
AIGC_AVATAR_PATH = "./figures/aigc_avatar.png"
DATAWHALE_AVATAR_PATH = "./figures/datawhale_avatar.png"
AIGC_LOGO_PATH = "./figures/aigc_logo.png"
DATAWHALE_LOGO_PATH = "./figures/datawhale_logo.png"
def get_model_by_platform(platform):
return LLM_MODEL_DICT.get(platform, "")
class Model_center:
"""
存储问答 Chain 的对象
- chat_qa_chain_self: 以 (model, embedding) 为键存储的带历史记录的问答链。
- qa_chain_self: 以 (model, embedding) 为键存储的不带历史记录的问答链。
"""
def __init__(self):
self.chat_qa_chain_self = {}
self.qa_chain_self = {}
def chat_qa_chain_self_answer(
self,
question: str,
chat_history: list = [],
model: str = "glm-4",
embedding: str = "embedding-2",
temperature: float = 0.0,
top_k: int = 4,
history_len: int = 3,
file_path: str = DEFAULT_DB_PATH,
persist_path: str = DEFAULT_PERSIST_PATH,
):
"""
调用带历史记录的问答链进行回答
"""
if question == None or len(question) < 1:
return "", chat_history
try:
if (model, embedding) not in self.chat_qa_chain_self:
self.chat_qa_chain_self[(model, embedding)] = Chat_QA_chain_self(
model=model,
temperature=temperature,
top_k=top_k,
chat_history=chat_history,
file_path=file_path,
persist_path=persist_path,
embedding=embedding,
)
chain = self.chat_qa_chain_self[(model, embedding)]
return "", chain.answer(
question=question, temperature=temperature, top_k=top_k
)
except Exception as e:
return e, chat_history
def qa_chain_self_answer(
self,
question: str,
chat_history: list = [],
model: str = "glm-4",
embedding="embedding-2",
temperature: float = 0.0,
top_k: int = 4,
file_path: str = DEFAULT_DB_PATH,
persist_path: str = DEFAULT_PERSIST_PATH,
):
"""
调用不带历史记录的问答链进行回答
"""
if question == None or len(question) < 1:
return "", chat_history
try:
if (model, embedding) not in self.qa_chain_self:
self.qa_chain_self[(model, embedding)] = QA_chain_self(
model=model,
temperature=temperature,
top_k=top_k,
file_path=file_path,
persist_path=persist_path,
embedding=embedding,
)
chain = self.qa_chain_self[(model, embedding)]
chat_history.append((question, chain.answer(question, temperature, top_k)))
return "", chat_history
except Exception as e:
return e, chat_history
def clear_history(self):
if len(self.chat_qa_chain_self) > 0:
for chain in self.chat_qa_chain_self.values():
chain.clear_history()
def format_chat_prompt(message, chat_history):
"""
该函数用于格式化聊天 prompt。
参数:
message: 当前的用户消息。
chat_history: 聊天历史记录。
返回:
prompt: 格式化后的 prompt。
"""
# 初始化一个空字符串,用于存放格式化后的聊天 prompt。
prompt = ""
# 遍历聊天历史记录。
for turn in chat_history:
# 从聊天记录中提取用户和机器人的消息。
user_message, bot_message = turn
# 更新 prompt,加入用户和机器人的消息。
prompt = f"{prompt}\nUser: {user_message}\nAssistant: {bot_message}"
# 将当前的用户消息也加入到 prompt中,并预留一个位置给机器人的回复。
prompt = f"{prompt}\nUser: {message}\nAssistant:"
# 返回格式化后的 prompt。
return prompt
def respond(
message, chat_history, llm, history_len=3, temperature=0.1, max_tokens=2048
):
"""
该函数用于生成机器人的回复。
参数:
message: 当前的用户消息。
chat_history: 聊天历史记录。
返回:
"": 空字符串表示没有内容需要显示在界面上,可以替换为真正的机器人回复。
chat_history: 更新后的聊天历史记录
"""
if message == None or len(message) < 1:
return "", chat_history
try:
# 限制 history 的记忆长度
chat_history = chat_history[-history_len:] if history_len > 0 else []
# 调用上面的函数,将用户的消息和聊天历史记录格式化为一个 prompt。
formatted_prompt = format_chat_prompt(message, chat_history)
# 使用llm对象的predict方法生成机器人的回复(注意:llm对象在此代码中并未定义)。
bot_message = get_completion(
formatted_prompt, llm, temperature=temperature, max_tokens=max_tokens
)
# 将bot_message中\n换为<br/>
bot_message = re.sub(r"\\n", "<br/>", bot_message)
# 将用户的消息和机器人的回复加入到聊天历史记录中。
chat_history.append((message, bot_message))
# 返回一个空字符串和更新后的聊天历史记录(这里的空字符串可以替换为真正的机器人回复,如果需要显示在界面上)。
return "", chat_history
except Exception as e:
return e, chat_history
model_center = Model_center()
block = gr.Blocks()
with block as demo:
with gr.Row(equal_height=True):
# gr.Image(value=AIGC_LOGO_PATH, scale=1, min_width=10, show_label=False, show_download_button=False, container=False)
with gr.Column(scale=2):
gr.Markdown(
"""<h1><center>大模型应用开发</center></h1>
<center>LLM-UNIVERSE</center>
"""
)
# gr.Image(value=DATAWHALE_LOGO_PATH, scale=1, min_width=10, show_label=False, show_download_button=False, container=False)
with gr.Row():
with gr.Column(scale=4):
# chatbot = gr.Chatbot(height=400, show_copy_button=True, show_share_button=True, avatar_images=(AIGC_AVATAR_PATH, DATAWHALE_AVATAR_PATH))
chatbot = gr.Chatbot(
height=400, show_copy_button=True, show_share_button=True
)
# 创建一个文本框组件,用于输入 prompt。
msg = gr.Textbox(label="Prompt/问题")
with gr.Row():
# 创建提交按钮。
db_with_his_btn = gr.Button("Chat db with history")
db_wo_his_btn = gr.Button("Chat db without history")
llm_btn = gr.Button("Chat with llm")
with gr.Row():
# 创建一个清除按钮,用于清除聊天机器人组件的内容。
clear = gr.ClearButton(components=[chatbot], value="Clear console")
with gr.Column(scale=1):
file = gr.File(
label="请选择知识库目录",
file_count="directory",
file_types=[".txt", ".md", ".docx", ".pdf"],
)
with gr.Row():
init_db = gr.Button("知识库文件向量化")
model_argument = gr.Accordion("参数配置", open=False)
with model_argument:
temperature = gr.Slider(
0,
1,
value=0.01,
step=0.01,
label="llm temperature",
interactive=True,
)
top_k = gr.Slider(
1,
10,
value=3,
step=1,
label="vector db search top k",
interactive=True,
)
history_len = gr.Slider(
0, 5, value=3, step=1, label="history length", interactive=True
)
model_select = gr.Accordion("模型选择")
with model_select:
llm = gr.Dropdown(
LLM_MODEL_LIST,
label="large language model",
value=INIT_LLM,
interactive=True,
)
embeddings = gr.Dropdown(
EMBEDDING_MODEL_LIST,
label="Embedding model",
value=INIT_EMBEDDING_MODEL,
)
# 设置初始化向量数据库按钮的点击事件。当点击时,调用 create_db_info 函数,并传入用户的文件和希望使用的 Embedding 模型。
init_db.click(create_db_info, inputs=[file, embeddings], outputs=[msg])
# 设置按钮的点击事件。当点击时,调用上面定义的 chat_qa_chain_self_answer 函数,并传入用户的消息和聊天历史记录,然后更新文本框和聊天机器人组件。
db_with_his_btn.click(
model_center.chat_qa_chain_self_answer,
inputs=[msg, chatbot, llm, embeddings, temperature, top_k, history_len],
outputs=[msg, chatbot],
)
# 设置按钮的点击事件。当点击时,调用上面定义的 qa_chain_self_answer 函数,并传入用户的消息和聊天历史记录,然后更新文本框和聊天机器人组件。
db_wo_his_btn.click(
model_center.qa_chain_self_answer,
inputs=[msg, chatbot, llm, embeddings, temperature, top_k],
outputs=[msg, chatbot],
)
# 设置按钮的点击事件。当点击时,调用上面定义的 respond 函数,并传入用户的消息和聊天历史记录,然后更新文本框和聊天机器人组件。
llm_btn.click(
respond,
inputs=[msg, chatbot, llm, history_len, temperature],
outputs=[msg, chatbot],
show_progress="minimal",
)
# 设置文本框的提交事件(即按下Enter键时)。功能与上面的 llm_btn 按钮点击事件相同。
msg.submit(
respond,
inputs=[msg, chatbot, llm, history_len, temperature],
outputs=[msg, chatbot],
show_progress="hidden",
)
# 点击后清空后端存储的聊天记录
clear.click(model_center.clear_history)
gr.Markdown(
"""提醒:<br>
1. 使用时请先上传自己的知识文件,不然将会解析项目自带的知识库。
2. 初始化数据库时间可能较长,请耐心等待。
3. 使用中如果出现异常,将会在文本输入框进行展示,请不要惊慌。 <br>
"""
)
# threads to consume the request
gr.close_all()
# 启动新的 Gradio 应用,设置分享功能为 True,并使用环境变量 PORT1 指定服务器端口。
# demo.launch(share=True, server_port=int(os.environ['PORT1']))
# 直接启动
demo.launch(share=True)
|