xiaoximew's picture
Update
04aa62c
import os
from typing import Dict
import gradio as gr
import pandas as pd
from chat_task.chat import generate_chat
from doc_qa_task.doc_qa import generate_doc_qa
from examples import (
load_examples,
preprocess_docqa_examples,
preprocess_extraction_examples,
preprocess_qa_generator_examples,
)
from extract_data_task.extract import extract_slots
from plugin_task.api import api_plugin_chat
from qa_generator_task.generate_qa import generate_qa_pairs
from plugin_task.plugins import PLUGIN_JSON_SCHEMA
# Locate the directory holding this script so bundled assets resolve
# independently of the process working directory.
abs_path = os.path.abspath(__file__)
current_dir = os.path.split(abs_path)[0]
# Button icons (clear.png, send.svg) are read from the "images" subfolder.
statistic_path = os.path.join(current_dir, "images")

# Runs once at import time, before the UI below is constructed.
# NOTE(review): presumably prepares the data behind the preprocess_*_examples
# helpers imported from `examples` -- confirm in that module.
load_examples()
def clear_session():
    """Return reset values for a chat tab.

    :return: a 2-tuple ``("", None)`` -- an empty string for the input
        textbox and ``None`` for the chatbot history.
    """
    cleared_input = ""
    cleared_history = None
    return cleared_input, cleared_history
def clear_plugin_session(session: Dict):
    """Empty the plugin session in place and return reset values.

    :param session: mutable session mapping; it is cleared in place, so the
        very same object is handed back.
    :return: a 3-tuple ``(session, None, None)`` -- the emptied session
        followed by two ``None`` placeholders.
    """
    session.clear()
    reset_values = (session, None, None)
    return reset_values
def show_custom_fallback_textbox(x):
    """Toggle visibility of a row depending on the selected fallback mode.

    :param x: the selected option; the row is shown only when it equals
        "自定义话术" (the custom-reply option).
    :return: a two-element list -- a ``gr.Row`` carrying the visibility flag
        and a default-constructed ``gr.Textbox``.
    """
    is_custom_reply = x == "自定义话术"
    return [gr.Row(visible=is_custom_reply), gr.Textbox()]
def validate_field_word_count(
    input_text: str, description: str, max_word_count: int = 3000
):
    """
    Ensure a text field is non-empty and within its length limit.

    Despite the name, the length is measured with ``len`` and therefore counts
    characters, not words.

    :param input_text: the text to check
    :param description: field label interpolated into the error message
    :param max_word_count: maximum allowed length (inclusive)
    :return: None when the text is acceptable
    :raises gr.Error: when the text is empty or longer than ``max_word_count``
    """
    text_length = len(input_text)
    if not text_length:
        raise gr.Error(f"{description}不能为空")
    if text_length > max_word_count:
        raise gr.Error(f"{description}字数不能超过{max_word_count}字")
def validate_chat(input_text: str):
    """
    Validate a message typed into the basic chat tab.

    The message must be non-empty and at most 500 characters long; the check
    itself is delegated to ``validate_field_word_count``.

    :param input_text: the user's message
    :return: None
    """
    validate_field_word_count(
        input_text=input_text,
        description="输入",
        max_word_count=500,
    )
def validate_doc_qa(
    input_text: str,
    doc_df: "pd.DataFrame",
    fallback_ratio: str,
    fallback_text_input: str,
):
    """
    Validate the fields of the document-QA tab.

    Checks run in this order:
    1. the custom fallback reply (only when the custom option is selected),
       limited to 100 characters;
    2. the question, limited to 500 characters;
    3. all snippet names and snippet contents joined together, limited to
       2500 characters.

    Note: the combined document text always contains at least the separating
    space, so the "must not be empty" check never triggers for it; only the
    length limit is effectively enforced there.

    :param input_text: the user's question
    :param doc_df: table with the columns "文档片段名称" and "文档片段内容"
    :param fallback_ratio: selected out-of-scope reply mode
    :param fallback_text_input: custom fallback reply text
    :return: None
    """
    if fallback_ratio == "自定义话术":
        validate_field_word_count(fallback_text_input, "自定义话术", 100)
    validate_field_word_count(input_text, "输入", 500)

    # Concatenate names first, then contents, each joined with single spaces.
    snippet_names = " ".join(doc_df["文档片段名称"].tolist())
    snippet_bodies = " ".join(doc_df["文档片段内容"].tolist())
    page_content_full_text = f"{snippet_names} {snippet_bodies}"
    validate_field_word_count(page_content_full_text, "文档信息", 2500)
def validate_qa_pair_generator(input_text: str):
    """
    Validate the source text submitted to the QA-pair generator.

    No explicit limit is passed, so the default maximum of
    ``validate_field_word_count`` applies.

    :param input_text: text from which QA pairs will be generated
    :return: whatever ``validate_field_word_count`` returns
    """
    outcome = validate_field_word_count(input_text, "输入")
    return outcome
def validate_extraction(
    input_text: str,
    extraction_df: "pd.DataFrame",
):
    """
    Validate the fields of the data-extraction tab.

    The field table is flattened to text first, then the input text and the
    flattened field text are each limited to 1500 characters.

    Note: the flattened field text always contains at least the separating
    space, so the "must not be empty" check never triggers for it.

    :param input_text: text to extract data from
    :param extraction_df: table with the columns "字段名称" and "字段描述"
    :return: None
    """
    field_names = " ".join(extraction_df["字段名称"].tolist())
    field_descriptions = " ".join(extraction_df["字段描述"].tolist())
    extraction_full_text = f"{field_names} {field_descriptions}"

    for text, label in (
        (input_text, "输入"),
        (extraction_full_text, "待抽取字段描述"),
    ):
        validate_field_word_count(text, label, 1500)
def validate_plugin(input_text: str):
    """
    Validate a message typed into the plugin tab.

    The message must be non-empty and at most 500 characters long; the check
    itself is delegated to ``validate_field_word_count``.

    :param input_text: the user's message
    :return: None
    """
    plugin_input_limit = 500
    validate_field_word_count(input_text, "输入", plugin_input_limit)
# Build the whole UI: one Blocks app with five tabs (basic chat, document QA,
# plugins, QA-pair generation, data extraction).
with gr.Blocks(
    title="Orion-14B",
    theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
) as demo:

    def user(user_message, history):
        """Append the user's message to the chat history.

        :param user_message: text currently in the input box
        :param history: existing list of [user, reply] pairs, or None
        :return: the unchanged message and the history extended by a new
            turn whose reply is an empty placeholder
        """
        return user_message, (history or []) + [[user_message, ""]]

    # Page banner. The HTML is kept flush-left so the string content is
    # exactly what is handed to the Markdown component.
    gr.Markdown(
        """
<div style="overflow: hidden;color:#fff;display: flex;flex-direction: column;align-items: center; position: relative; width: 100%; height: 180px;background-size: cover; background-image: url(https://www.orionstar.com/res/orics/down/ow001_20240119_8369eca9013416109a2303bf4e329140.png);">
<img style="width: 130px;height: 60px;position: absolute;top:10px;left:10px" src="https://www.orionstar.com/res/orics/down/ow001_20240119_1236eba7ea0ac15931f4518d7f211d47.png"/>
<img style="min-width: 1416px; width: 1416px;height: 100px;margin-top: 30px;" src="https://www.orionstar.com/res/orics/down/ow001_20240119_10c5ca12a57116bda0e35916a28b247f.png"/>
<span style="margin-top: 10px;font-size: 12px;">请在<a href="https://github.com/OrionStarAI/Orion" style="color: white;">Github</a>点击Star支持我们,加入<a href="https://www.orionstar.com/res/orics/down/ow001_20240122_d87e5b4ea66a31493c38fcffe7bdb453.png" style="color: white;">官方微信交流群</a></span>
</div>
"""
    )

    # ------------------------------------------------------------------
    # Tab 1: basic chat
    # ------------------------------------------------------------------
    with gr.Tab("基础能力"):
        chatbot = gr.Chatbot(
            label="Orion-14B-Chat",
            elem_classes="control-height",
            show_copy_button=True,
            min_width=1368,
            height=416,
        )
        chat_text_input = gr.Textbox(label="输入", min_width=1368)
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    [
                        "可以给我讲个笑话吗?",
                        "什么是伟大的诗歌?",
                        "你知道李白吗?",
                        "黑洞是如何工作的?",
                        "在表中插入一条数据,id为1,name为张三,age为18,请问SQL语句是什么?",
                    ],
                    chat_text_input,
                    label="试试问",
                )
            with gr.Column(scale=1):
                with gr.Row(variant="compact"):
                    clear_history = gr.Button(
                        "清除历史",
                        min_width=17,
                        size="sm",
                        scale=1,
                        icon=os.path.join(statistic_path, "clear.png"),
                    )
                    submit = gr.Button(
                        "发送",
                        variant="primary",
                        min_width=17,
                        size="sm",
                        scale=1,
                        icon=os.path.join(statistic_path, "send.svg"),
                    )
        # Each send path is a chain: validate -> echo the user turn into the
        # chat -> generate the reply. A later step runs only if the previous
        # one succeeded.
        chat_text_input.submit(
            fn=validate_chat, inputs=[chat_text_input], outputs=[], queue=False
        ).success(
            user, [chat_text_input, chatbot], [chat_text_input, chatbot], queue=False
        ).success(
            fn=generate_chat,
            inputs=[chat_text_input, chatbot],
            outputs=[chat_text_input, chatbot],
        )
        submit.click(
            fn=validate_chat, inputs=[chat_text_input], outputs=[], queue=False
        ).success(
            user, [chat_text_input, chatbot], [chat_text_input, chatbot], queue=False
        ).success(
            fn=generate_chat,
            inputs=[chat_text_input, chatbot],
            outputs=[chat_text_input, chatbot],
            api_name="chat",
        )
        clear_history.click(
            fn=clear_session, inputs=[], outputs=[chat_text_input, chatbot], queue=False
        )

    # ------------------------------------------------------------------
    # Tab 2: document-grounded QA
    # ------------------------------------------------------------------
    with gr.Tab("基于文档问答"):
        with gr.Row():
            # Left panel: answer options and the editable document table.
            with gr.Column(scale=3, min_width=357, variant="panel"):
                gr.Markdown(
                    '<span style="color:rgba(0, 0, 0, 0.5); font-size: 14px; font-weight: 400; line-height: 28px; letter-spacing: 0em; text-align: left; width: 42px; height: 14px; left: 36px; top: 255px;">配置项</span>'
                )
                citations_radio = gr.Radio(
                    ["开启引用", "关闭引用"], label="引用", value="关闭引用"
                )
                fallback_radio = gr.Radio(
                    ["使用大模型知识", "自定义话术"],
                    label="超纲问题回复",
                    value="自定义话术",
                )
                fallback_text_input = gr.Textbox(
                    label="自定义话术",
                    value="抱歉,我还在学习中,暂时无法回答您的问题。",
                )
                gr.Markdown(
                    '<span style="color:rgba(0, 0, 0, 0.5); font-size: 14px; font-weight: 400; line-height: 28px; letter-spacing: 0em; text-align: left; width: 42px; height: 14px; left: 36px; top: 255px;">文档信息</span>'
                )
                doc_df = gr.Dataframe(
                    headers=["文档片段内容", "文档片段名称"],
                    datatype=["str", "str"],
                    row_count=6,
                    col_count=(2, "fixed"),
                    label="",
                    interactive=True,
                    wrap=True,
                    elem_classes="control-height",
                    height=300,
                )
            # Right panel: conversation, input box, examples and buttons.
            with gr.Column(scale=2, min_width=430):
                chatbot = gr.Chatbot(
                    label="适用场景:预期LLM通过自由知识回答",
                    elem_classes="control-height",
                    show_copy_button=True,
                    min_width=999,
                    height=419,
                )
                doc_qa_input = gr.Textbox(label="输入", min_width=999, max_lines=10)
                with gr.Row():
                    with gr.Column(scale=2):
                        # Picking an example also fills the document table
                        # through preprocess_docqa_examples.
                        gr.Examples(
                            [
                                "哪些情况下不能超车?",
                                "参观须知",
                                "青岛啤酒酒精含量是多少?",
                            ],
                            doc_qa_input,
                            label="试试问",
                            cache_examples=True,
                            fn=preprocess_docqa_examples,
                            outputs=[doc_df],
                        )
                    with gr.Column(scale=1):
                        with gr.Row(variant="compact"):
                            clear_history = gr.Button(
                                "清除历史",
                                min_width=17,
                                size="sm",
                                scale=1,
                                icon=os.path.join(statistic_path, "clear.png"),
                            )
                            submit = gr.Button(
                                "发送",
                                variant="primary",
                                min_width=17,
                                size="sm",
                                scale=1,
                                icon=os.path.join(statistic_path, "send.svg"),
                            )
        doc_qa_input.submit(
            fn=validate_doc_qa,
            inputs=[
                doc_qa_input,
                doc_df,
                fallback_radio,
                fallback_text_input,
            ],
            outputs=[],
            queue=False,
        ).success(
            user, [doc_qa_input, chatbot], [doc_qa_input, chatbot], queue=False
        ).success(
            fn=generate_doc_qa,
            inputs=[
                doc_qa_input,
                chatbot,
                doc_df,
                fallback_radio,
                fallback_text_input,
                citations_radio,
            ],
            outputs=[doc_qa_input, chatbot],
            scroll_to_output=True,
            api_name="doc_qa",
        )
        submit.click(
            fn=validate_doc_qa,
            inputs=[
                doc_qa_input,
                doc_df,
                fallback_radio,
                fallback_text_input,
            ],
            outputs=[],
            queue=False,
        ).success(
            user, [doc_qa_input, chatbot], [doc_qa_input, chatbot], queue=False
        ).success(
            fn=generate_doc_qa,
            inputs=[
                doc_qa_input,
                chatbot,
                doc_df,
                fallback_radio,
                fallback_text_input,
                citations_radio,
            ],
            outputs=[doc_qa_input, chatbot],
            scroll_to_output=True,
        )
        # The handler takes no parameters because no inputs are wired; a
        # one-parameter callable here fails when the button is clicked.
        clear_history.click(
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[doc_df, doc_qa_input, chatbot],
            queue=False,
        )

    # ------------------------------------------------------------------
    # Tab 3: plugin (API-calling) chat
    # ------------------------------------------------------------------
    with gr.Tab("插件能力"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(
                    '<span style="color:rgba(0, 0, 0, 0.5); font-size: 14px; font-weight: 400; line-height: 28px; letter-spacing: 0em; text-align: left; width: 42px; height: 14px; left: 36px; top: 255px;">配置项</span>'
                )
                # One on/off switch per entry of PLUGIN_JSON_SCHEMA, all
                # enabled by default.
                radio_plugins = [
                    gr.Radio(
                        ["开启", "关闭"],
                        label=plugin_json["name_for_human"],
                        value="开启",
                    )
                    for plugin_json in PLUGIN_JSON_SCHEMA
                ]
            with gr.Column(scale=3):
                # Per-user state dict passed to and returned by
                # api_plugin_chat on every turn.
                session = gr.State(value=dict())
                chatbot = gr.Chatbot(
                    label="适用场景:需要LLM调用API解决问题",
                    elem_classes="control-height",
                    show_copy_button=True,
                )
                plugin_text_input = gr.Textbox(label="输入")
                with gr.Row():
                    with gr.Column(scale=2):
                        gr.Examples(
                            [
                                "北京天气怎么样?",
                                "查询物流信息",
                                "每日壁纸",
                                "bing今天的壁纸是什么",
                                "查询手机号码归属地",
                            ],
                            plugin_text_input,
                            label="试试问",
                        )
                    with gr.Column(scale=1):
                        with gr.Row(variant="compact"):
                            clear_history = gr.Button(
                                "清除历史",
                                min_width=17,
                                size="sm",
                                scale=1,
                                icon=os.path.join(statistic_path, "clear.png"),
                            )
                            submit = gr.Button(
                                "发送",
                                variant="primary",
                                min_width=17,
                                size="sm",
                                scale=1,
                                icon=os.path.join(statistic_path, "send.svg"),
                            )
        plugin_text_input.submit(
            fn=validate_plugin,
            inputs=[
                plugin_text_input,
            ],
            outputs=[],
            queue=False,
        ).success(
            user,
            [plugin_text_input, chatbot],
            [plugin_text_input, chatbot],
            scroll_to_output=True,
        ).success(
            fn=api_plugin_chat,
            inputs=[session, plugin_text_input, chatbot, *radio_plugins],
            outputs=[session, plugin_text_input, chatbot],
            scroll_to_output=True,
        )
        submit.click(
            fn=validate_plugin,
            inputs=[
                plugin_text_input,
            ],
            outputs=[],
            queue=False,
        ).success(
            user,
            [plugin_text_input, chatbot],
            [plugin_text_input, chatbot],
            scroll_to_output=True,
        ).success(
            fn=api_plugin_chat,
            inputs=[session, plugin_text_input, chatbot, *radio_plugins],
            outputs=[session, plugin_text_input, chatbot],
            api_name="plugin",
            scroll_to_output=True,
        )
        clear_history.click(
            fn=clear_plugin_session,
            inputs=[session],
            outputs=[session, plugin_text_input, chatbot],
            queue=False,
        )

    # ------------------------------------------------------------------
    # Tab 4: QA-pair generation
    # ------------------------------------------------------------------
    with gr.Tab("生成QA对"):
        with gr.Row(equal_height=True):
            qa_generator_output = gr.Code(
                language="json",
                show_label=False,
                min_width=1368,
            )
        with gr.Row():
            qa_generator_input = gr.Textbox(
                label="输入",
                show_label=True,
                info="",
                min_width=1368,
                lines=5,
                max_lines=10,
            )
        with gr.Row():
            with gr.Column(scale=2):
                # The example labels are abbreviated; the function output is
                # written back into the input box when one is selected.
                gr.Examples(
                    [
                        "第一章 总 则 \n第...",
                        "金字塔,在建筑学上是...",
                        "山西老陈醋是以高粱、...",
                        "室内装饰构造虚拟仿真...",
                        "猎户星空(Orion...",
                    ],
                    qa_generator_input,
                    label="试试问",
                    cache_examples=True,
                    fn=preprocess_qa_generator_examples,
                    outputs=[qa_generator_input],
                )
            with gr.Column(scale=1):
                with gr.Row(variant="compact"):
                    clear = gr.Button(
                        "清除",
                        min_width=17,
                        size="sm",
                        scale=1,
                        icon=os.path.join(statistic_path, "clear.png"),
                    )
                    submit = gr.Button(
                        "发送",
                        variant="primary",
                        min_width=17,
                        size="sm",
                        scale=1,
                        icon=os.path.join(statistic_path, "send.svg"),
                    )
        submit.click(
            fn=validate_qa_pair_generator,
            inputs=[qa_generator_input],
            outputs=[],
        ).success(
            fn=generate_qa_pairs,
            inputs=[qa_generator_input],
            outputs=[qa_generator_output, qa_generator_input],
            scroll_to_output=True,
            api_name="qa_generator",
        )
        # Zero-argument handler to match the empty inputs list.
        clear.click(
            fn=lambda: ("", ""),
            inputs=[],
            outputs=[qa_generator_input, qa_generator_output],
            queue=False,
        )

    # ------------------------------------------------------------------
    # Tab 5: structured data extraction
    # ------------------------------------------------------------------
    with gr.Tab("抽取数据"):
        extract_outpu_df = gr.Dataframe(
            label="",
            headers=["字段名称", "字段抽取结果"],
            datatype=["str", "str"],
            col_count=(2, "fixed"),
            wrap=True,
            elem_classes="control-height",
            height=234,
            row_count=5,
        )
        extract_input = gr.Textbox(label="输入", lines=5, min_width=1368, max_lines=10)
        extraction_df = gr.Dataframe(
            headers=["字段名称", "字段描述"],
            datatype=["str", "str"],
            row_count=3,
            col_count=(2, "fixed"),
            label="",
            interactive=True,
            wrap=True,
            elem_classes="control-height",
            height=180,
        )
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    ["第一条合同当...", "发票编号: IN...", "发件人:John..."],
                    extract_input,
                    label="试试问",
                    cache_examples=True,
                    fn=preprocess_extraction_examples,
                    outputs=[extract_input, extraction_df],
                )
            with gr.Column(scale=1):
                with gr.Row(variant="compact"):
                    clear = gr.Button(
                        "清除历史",
                        min_width=17,
                        size="sm",
                        scale=1,
                        icon=os.path.join(statistic_path, "clear.png"),
                    )
                    submit = gr.Button(
                        "发送",
                        variant="primary",
                        min_width=17,
                        size="sm",
                        scale=1,
                        icon=os.path.join(statistic_path, "send.svg"),
                    )
        submit.click(
            fn=validate_extraction,
            inputs=[extract_input, extraction_df],
            outputs=[],
        ).success(
            fn=extract_slots,
            inputs=[extract_input, extraction_df],
            outputs=[extract_outpu_df],
            scroll_to_output=True,
            api_name="extract",
        )
        # Zero-argument handler to match the empty inputs list.
        clear.click(
            fn=lambda: ("", None, None),
            inputs=[],
            outputs=[
                extract_input,
                extraction_df,
                extract_outpu_df,
            ],
            queue=False,
        )
if __name__ == "__main__":
    # Enable the request queue (at most 40 waiting requests, queue API
    # closed), then start the server listening on all network interfaces
    # without a public share link or the API docs page.
    queued_demo = demo.queue(api_open=False, max_size=40)
    queued_demo.launch(
        height=800,
        share=False,
        server_name="0.0.0.0",
        show_api=False,
        max_threads=4,
    )