# OpenLLM / app.py
# (HuggingFace Spaces page header removed: author xusenlin, commit d877fef "Update app.py")
import logging
import os
import re
import shutil
import gradio as gr
import openai
import pandas as pd
from backoff import on_exception, expo
from sqlalchemy import create_engine
from tools.doc_qa import DocQAPromptAdapter
from tools.web.overwrites import postprocess, reload_javascript
from tools.web.presets import (
small_and_beautiful_theme,
title,
description,
description_top,
CONCURRENT_COUNT
)
from tools.web.utils import (
convert_to_markdown,
shared_state,
reset_textbox,
cancel_outputing,
transfer_input,
reset_state,
delete_last_conversation
)
# Root logger: DEBUG level, with timestamp / level / source location per record.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)
# Placeholder key: the self-hosted OpenAI-compatible endpoints used here do not
# validate it, but the openai client requires a non-empty value.
openai.api_key = "xxx"
# Shared adapter used for document (knowledge-base) question answering.
doc_adapter = DocQAPromptAdapter()
def add_llm(model_name, api_base, models):
    """Register a model-name -> API-base mapping.

    Args:
        model_name: display name of the model (also sent as ``model`` to the API).
        api_base: OpenAI-compatible endpoint URL for this model.
        models: current mapping (gr.Json state); may be None on first call.

    Returns:
        Two empty strings (clearing the textboxes), the updated mapping, and a
        Dropdown update listing all registered models with the first selected.
    """
    models = models or {}
    # Ignore incomplete input: both fields must be filled in.
    if model_name and api_base:
        models[model_name] = api_base
    choices = list(models)  # dict keys, insertion-ordered
    return "", "", models, gr.Dropdown.update(
        choices=choices, value=choices[0] if choices else None
    )
def set_openai_env(api_base):
    """Point both the openai client and the doc-QA embeddings at *api_base*.

    Mutates module-level state (``openai.api_base`` and the shared
    ``doc_adapter``); called before every request so that each registered
    model can live on its own endpoint.
    """
    openai.api_base = api_base
    doc_adapter.embeddings.openai_api_base = api_base
def get_file_list():
    """Return the names of all uploaded documents, or [] if none exist yet."""
    store = "doc_store"
    return os.listdir(store) if os.path.exists(store) else []
# Snapshot of the document store taken at import time; seeds the
# "select file" dropdown below (upload_file refreshes it afterwards).
file_list = get_file_list()
def upload_file(file):
    """Move an uploaded temp file into ``doc_store`` and refresh the dropdown.

    Args:
        file: gradio File value (exposes a ``.name`` temp path), or None.

    Returns:
        A Dropdown update with the refreshed file list and the new file
        selected, or None (no update) when nothing was uploaded.
    """
    # BUG FIX: the directory created must match the one files are moved into
    # (the original created "docs" but moved uploads into "doc_store").
    os.makedirs("doc_store", exist_ok=True)
    if file is not None:
        filename = os.path.basename(file.name)
        # FIX: destination path interpolates the real filename (was garbled).
        shutil.move(file.name, os.path.join("doc_store", filename))
        file_list = get_file_list()
        # Bubble the freshly uploaded file to the top of the dropdown.
        file_list.remove(filename)
        file_list.insert(0, filename)
        return gr.Dropdown.update(choices=file_list, value=filename)
def add_vector_store(filename, model_name, models, chunk_size, chunk_overlap):
    """Build (or reload) the vector store for the selected document.

    Args:
        filename: document name inside ``doc_store/``, or None if nothing selected.
        model_name: key into *models* choosing the embedding endpoint.
        models: mapping of model name -> API base URL.
        chunk_size: text-splitting chunk size for embedding.
        chunk_overlap: overlap between consecutive chunks.

    Returns:
        A status message shown in the UI.
    """
    api_base = models[model_name]
    set_openai_env(api_base)
    doc_adapter.chunk_size = chunk_size
    doc_adapter.chunk_overlap = chunk_overlap
    if filename is None:
        return "Please select a file!"
    # One store per document, keyed by basename and extension.
    vs_path = f"vector_store/{filename.split('.')[0]}-{filename.split('.')[-1]}"
    if not os.path.exists(vs_path):
        # FIX: interpolate the actual filename (paths/messages were garbled).
        doc_adapter.create_vector_store(f"doc_store/{filename}", vs_path=vs_path)
        return f"Successfully added vector store for {filename}!"
    doc_adapter.reset_vector_store(vs_path=vs_path)
    return f"Successfully loaded vector store for {filename}!"
def add_db(db_user, db_password, db_host, db_port, db_name, databases):
    """Register a MySQL connection configuration under *db_name*.

    Args:
        db_user / db_password / db_host / db_port / db_name: connection fields
            from the UI textboxes (all strings; the port is cast to int).
        databases: current mapping (gr.Json state); may be None on first call.

    Returns:
        Five empty strings (clearing the textboxes), the updated mapping, and
        a Dropdown update listing all registered databases, first selected.
    """
    databases = databases or {}
    # Require every field before accepting the configuration.
    if db_user and db_password and db_host and db_port and db_name:
        databases[db_name] = {
            "user": db_user,
            "password": db_password,
            "host": db_host,
            "port": int(db_port),
        }
    choices = list(databases)  # dict keys, insertion-ordered
    return "", "", "", "", "", databases, gr.Dropdown.update(
        choices=choices, value=choices[0] if choices else None
    )
def get_table_names(select_database, databases):
    """List the tables of the chosen database and pre-select the first one."""
    if not select_database:
        return None  # nothing chosen -> no dropdown update
    cfg = databases[select_database]
    url = (
        f"mysql+pymysql://{cfg['user']}:{cfg['password']}"
        f"@{cfg['host']}:{cfg['port']}/{select_database}"
    )
    engine = create_engine(url)
    tables = [row[0] for row in pd.read_sql("show tables;", con=engine).values]
    return gr.Dropdown.update(choices=tables, value=[tables[0]])
def get_sql_result(x, con):
    """Pull the first ``sql ...;`` statement out of *x*, execute it on *con*,
    and render at most the first 10 result rows as a centered markdown table.
    """
    pattern = r"sql\n(.+?);\n"
    statement = re.findall(pattern, x, re.DOTALL)[0] + ";"
    frame = pd.read_sql(statement, con=con)
    return frame.head(10).to_markdown(numalign="center", stralign="center")
# Retry with exponential backoff (up to 5 attempts) when the endpoint
# rate-limits us.
@on_exception(expo, openai.error.RateLimitError, max_tries=5)
def chat_completions_create(params):
    """Call the OpenAI-compatible chat-completion endpoint with *params*."""
    return openai.ChatCompletion.create(**params)
def predict(
    model_name,
    models,
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_tokens,
    memory_k,
    is_kgqa,
    single_turn,
    is_dbqa,
    select_database,
    select_table,
    databases,
):
    """Stream a chat completion, yielding (chatbot, history, status) triples.

    Generator consumed by gradio: every yield refreshes the chat display.
    Three modes are supported:
      * plain chat — optionally multi-turn, replaying the last *memory_k*
        exchanges as context unless *single_turn* is set;
      * knowledge-base QA (*is_kgqa*) — the user text is wrapped by
        ``doc_adapter`` before being sent;
      * database QA (*is_dbqa*) — the selected tables' DDL is injected as a
        system prompt, and the generated SQL is executed and appended to the
        answer afterwards.
    """
    # Route the request to the endpoint registered for this model.
    api_base = models[model_name]
    set_openai_env(api_base)
    if text == "":
        yield chatbot, history, "Empty context."
        return
    if history is None:
        history = []
    messages = []
    if is_dbqa:
        # Deterministic decoding for SQL generation.
        temperature = 0.0
        db_config = databases[select_database]
        # NOTE: `con` is only bound on this branch; it is reused further down,
        # also only behind `if is_dbqa`.
        con = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{select_database}")
        # Concatenate the CREATE TABLE DDL of every selected table so the
        # model knows the schema it must query against.
        table_schema = ""
        for t in select_table:
            table_schema += pd.read_sql(f"show create table {t};", con=con)["Create Table"][0] + "\n\n"
        table_schema = table_schema.replace("DEFAULT NULL", "")
        messages.append(
            {
                "role": "system",
                "content": f"你现在是一名SQL助手,能够根据用户的问题生成准确的SQL查询。已知SQL的建表语句为:{table_schema}根据上述数据库信息,回答相关问题。"
            },
        )
    else:
        if not single_turn:
            # Replay the last *memory_k* exchanges as conversation context.
            for h in history[-memory_k:]:
                messages.extend(
                    [
                        {
                            "role": "user",
                            "content": h[0]
                        },
                        {
                            "role": "assistant",
                            "content": h[1]
                        }
                    ]
                )
    # The actual question, optionally wrapped with retrieved document context.
    messages.append(
        {
            "role": "user",
            "content": doc_adapter(text) if is_kgqa else text
        }
    )
    params = dict(
        stream=True,
        messages=messages,
        model=model_name,
        top_p=top_p,
        temperature=temperature,
        max_tokens=max_tokens
    )
    res = chat_completions_create(params)
    # x accumulates the streamed answer; a/b are the markdown-rendered chatbot
    # view and the raw history, rebuilt on every delta so the UI updates live.
    x = ""
    for openai_object in res:
        delta = openai_object.choices[0]["delta"]
        if "content" in delta:
            x += delta["content"]
        a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [
            [text, convert_to_markdown(x)]
        ], history + [[text, x]]
        yield a, b, "Generating..."
        # Cooperative cancellation: the "stop" button flips this shared flag.
        if shared_state.interrupted:
            shared_state.recover()
            try:
                yield a, b, "Stop: Success"
                return
            except:
                # NOTE(review): bare except guards a yield into a possibly
                # closed generator; kept as-is (matches upstream behavior).
                pass
    if is_dbqa:
        # Best-effort: run the generated SQL and append the result table to
        # both the rendered and the raw answer; any failure is swallowed so a
        # bad query still leaves the text answer intact.
        try:
            res = get_sql_result(x, con)
            a[-1][-1] += "\n\n" + convert_to_markdown(res)
            b[-1][-1] += "\n\n" + convert_to_markdown(res)
        except:
            pass
    try:
        yield a, b, "Generate: Success"
    except:
        pass
def retry(
    model_name,
    models,
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_tokens,
    memory_k,
    is_kgqa,
    single_turn,
    is_dbqa,
    select_database,
    select_table,
    databases,
):
    """Drop the last exchange and regenerate an answer for its question.

    Streams the same (chatbot, history, status) triples as ``predict``.
    """
    logging.info("Retry...")
    if not history:
        yield chatbot, history, "Empty context."
        return
    # Remove the failed answer from both the display and the history,
    # keeping the original question so it can be resubmitted.
    chatbot.pop()
    last_question = history.pop()[0]
    # Delegate the whole regeneration stream to predict().
    yield from predict(
        model_name,
        models,
        last_question,
        chatbot,
        history,
        top_p,
        temperature,
        max_tokens,
        memory_k,
        is_kgqa,
        single_turn,
        is_dbqa,
        select_database,
        select_table,
        databases,
    )
# Render chatbot messages through the project's own markdown/html postprocessor.
gr.Chatbot.postprocess = postprocess
# Load the custom stylesheet shipped with the app.
with open("assets/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()
# UI layout: chat panel (left) + configuration tabs (right), followed by the
# event wiring.  Built with the gradio 3.x API (`.style(...)` etc.).  All
# Chinese labels are user-facing runtime strings and are kept verbatim.
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    # Per-session state: the raw chat history and the question in flight.
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Success", elem_id="status_display")
    gr.Markdown(description_top)
    with gr.Row(scale=1).style(equal_height=True):
        # --- Left column: chat display, input box and action buttons ---
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
            with gr.Row():
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("发送")
                with gr.Column(min_width=70, scale=1):
                    cancelBtn = gr.Button("停止")
            with gr.Row():
                emptyBtn = gr.Button(
                    "🧹 新的对话",
                )
                retryBtn = gr.Button("🔄 重新生成")
                delLastBtn = gr.Button("🗑️ 删除最旧对话")
        # --- Right column: configuration tabs ---
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                # Tab: model registration and selection.
                with gr.Tab(label="模型"):
                    model_name = gr.Textbox(
                        placeholder="chatglm",
                        label="模型名称",
                    )
                    api_base = gr.Textbox(
                        placeholder="https://0.0.0.0:80/v1",
                        label="模型接口地址",
                    )
                    add_model = gr.Button("添加模型")
                    with gr.Accordion(open=False, label="所有模型配置"):
                        models = gr.Json()
                    single_turn = gr.Checkbox(label="使用单轮对话", value=False)
                    select_model = gr.Dropdown(
                        choices=[m[0] for m in models.value.items()] if models.value else [],
                        value=[m[0] for m in models.value.items()][0] if models.value else None,
                        label="选择模型",
                        interactive=True,
                    )
                # Tab: knowledge-base (document) question answering.
                with gr.Tab(label="知识库"):
                    is_kgqa = gr.Checkbox(
                        label="使用知识库问答",
                        value=False,
                        interactive=True,
                    )
                    gr.Markdown("""**基于本地知识库生成更加准确的回答!**""")
                    select_file = gr.Dropdown(
                        choices=file_list,
                        label="选择文件",
                        interactive=True,
                        value=file_list[0] if len(file_list) > 0 else None
                    )
                    file = gr.File(
                        label="上传文件",
                        visible=True,
                        file_types=['.txt', '.md', '.docx', '.pdf']
                    )
                    add_vs = gr.Button(value="添加到知识库")
                # Tab: database (SQL) question answering.
                with gr.Tab(label="数据库"):
                    with gr.Accordion(open=False, label="数据库配置"):
                        db_user = gr.Textbox(
                            placeholder="root",
                            label="用户名",
                        )
                        db_password = gr.Textbox(
                            placeholder="password",
                            label="密码",
                            type="password"
                        )
                        db_host = gr.Textbox(
                            placeholder="0.0.0.0",
                            label="主机",
                        )
                        db_port = gr.Textbox(
                            placeholder="3306",
                            label="端口",
                        )
                        db_name = gr.Textbox(
                            placeholder="test",
                            label="数据库名称",
                        )
                        add_database = gr.Button("添加数据库")
                    with gr.Accordion(open=False, label="所有数据库配置"):
                        databases = gr.Json()
                    select_database = gr.Dropdown(
                        choices=[d[0] for d in databases.value.items()] if databases.value else [],
                        value=[d[0] for d in databases.value.items()][0] if databases.value else None,
                        interactive=True,
                        label="选择数据库"
                    )
                    select_table = gr.Dropdown(label="选择表", interactive=True, multiselect=True)
                    is_dbqa = gr.Checkbox(
                        label="使用数据库问答",
                        value=False,
                        interactive=True,
                    )
                # Tab: sampling and chunking parameters.
                with gr.Tab(label="参数"):
                    top_p = gr.Slider(
                        minimum=-0,  # NOTE(review): -0 == 0; kept as-is
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_tokens = gr.Slider(
                        minimum=0,
                        maximum=512,
                        value=512,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    memory_k = gr.Slider(
                        minimum=0,
                        maximum=10,
                        value=5,
                        step=1,
                        interactive=True,
                        label="Max Memory Window Size",
                    )
                    chunk_size = gr.Slider(
                        minimum=100,
                        maximum=1000,
                        value=200,
                        step=100,
                        interactive=True,
                        label="Chunk Size",
                    )
                    chunk_overlap = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=0,
                        step=10,
                        interactive=True,
                        label="Chunk Overlap",
                    )
    gr.Markdown(description)
    # ------------------------- event wiring -------------------------
    # Register a new model / database config.
    add_model.click(
        add_llm,
        inputs=[model_name, api_base, models],
        outputs=[model_name, api_base, models, select_model],
    )
    add_database.click(
        add_db,
        inputs=[db_user, db_password, db_host, db_port, db_name, databases],
        outputs=[db_user, db_password, db_host, db_port, db_name, databases, select_database],
    )
    # Refresh table choices whenever the database selection changes.
    select_database.change(
        get_table_names,
        inputs=[select_database, databases],
        outputs=select_table,
    )
    # Document upload + vector-store creation.
    file.upload(
        upload_file,
        inputs=file,
        outputs=select_file,
    )
    add_vs.click(
        add_vector_store,
        inputs=[select_file, select_model, models, chunk_size, chunk_overlap],
        outputs=status_display,
    )
    # Shared kwargs for the two generation entry points.
    predict_args = dict(
        fn=predict,
        inputs=[
            select_model,
            models,
            user_question,
            chatbot,
            history,
            top_p,
            temperature,
            max_tokens,
            memory_k,
            is_kgqa,
            single_turn,
            is_dbqa,
            select_database,
            select_table,
            databases,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    retry_args = dict(
        fn=retry,
        inputs=[
            select_model,
            models,
            user_question,
            chatbot,
            history,
            top_p,
            temperature,
            max_tokens,
            memory_k,
            is_kgqa,
            single_turn,
            is_dbqa,
            select_database,
            select_table,
            databases,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    reset_args = dict(fn=reset_textbox, inputs=[], outputs=[user_input, status_display])
    cancelBtn.click(cancel_outputing, [], [status_display])
    # Move the textbox content into user_question state, then run predict.
    transfer_input_args = dict(
        fn=transfer_input,
        inputs=[user_input],
        outputs=[user_question, user_input, submitBtn, cancelBtn],
        show_progress=True,
    )
    user_input.submit(**transfer_input_args).then(**predict_args)
    submitBtn.click(**transfer_input_args).then(**predict_args)
    emptyBtn.click(
        reset_state,
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)
    retryBtn.click(**retry_args)
    delLastBtn.click(
        delete_last_conversation,
        [chatbot, history],
        [chatbot, history, status_display],
        show_progress=True,
    )
demo.title = "OpenLLM Chatbot 🚀 "
if __name__ == "__main__":
    # Inject the custom JS overrides, then serve with a bounded request queue.
    reload_javascript()
    demo.queue(concurrency_count=CONCURRENT_COUNT).launch()