# UIE-X / app.py
# sijunhe's picture
# app.py
# 8bf7635
# raw
# history blame
# 13.7 kB
#-*- coding: UTF-8 -*-
# Copyright 2022 the HuggingFace Team.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import traceback

import gradio as gr
from paddlenlp import Taskflow
from paddlenlp.utils.doc_parser import DocParser
# Module-level helpers shared by every request:
# - doc_parser reads images and draws extraction results onto them.
# - task_instance is the UIE-X information-extraction pipeline, loaded from
#   the Hugging Face Hub repo named by task_path (from_hf_hub=True).
# The schema starts empty; run_taskflow sets it before each call.
doc_parser = DocParser()
task_instance = Taskflow(
    "information_extraction",
    schema="",
    model="uie-x-base",
    from_hf_hub=True,
    task_path="PaddlePaddle/uie-x-base",
)
# Single source of truth for the bundled demos: (file, schema, OCR language).
# ``examples`` feeds gr.Examples (rows of [file, schema]); ``example_files``
# maps a schema back to its file and ``lang_map`` maps a file to its OCR
# language. Both lookups are used by load_example_document.
_EXAMPLE_TABLE = (
    ("business_card.png", "Name;Title;Web Link;Email;Address", "en"),
    ("license.jpeg", "Name;DOB;ISS;EXP", "en"),
    ("invoice.jpeg", "名称;纳税人识别号;开票日期", "ch"),
    ("custom.jpeg", "收发货人;进口口岸;进口日期;运输方式;征免性质;境内目的地;运输工具名称;包装种类;件数;合同协议号", "ch"),
    ("resume.png", "职位;年龄;学校|时间;学校|专业", "ch"),
)
examples = [[file_name, schema_text] for file_name, schema_text, _ in _EXAMPLE_TABLE]
example_files = {schema_text: file_name for file_name, schema_text, _ in _EXAMPLE_TABLE}
lang_map = {file_name: lang for file_name, _, lang in _EXAMPLE_TABLE}
# Translation table for dbc2sbc. Full-width ASCII variants U+FF01..U+FF5E sit
# exactly 0xFEE0 above their half-width counterparts U+0021..U+007E, and
# U+3000 (ideographic space) maps to an ordinary space. Built once at import.
_FULLWIDTH_TO_HALFWIDTH = {code: code - 0xFEE0 for code in range(0xFF01, 0xFF5F)}
_FULLWIDTH_TO_HALFWIDTH[0x3000] = 0x0020


def dbc2sbc(s):
    """Convert full-width characters in ``s`` to their half-width forms.

    Despite the name, this maps full-width (全角) text to half-width (半角).
    Examples: "；" becomes ";", "｜" becomes "|", "Ａ" becomes "A", and the
    ideographic space U+3000 becomes " ". Every other character is left
    unchanged.

    Args:
        s: Input text.

    Returns:
        The converted string.
    """
    # The previous loop set U+3000 to 0x20 and then rejected it with a
    # 0x21..0x7E range check, so the ideographic space was never converted.
    # str.translate also avoids quadratic string concatenation.
    return s.translate(_FULLWIDTH_TO_HALFWIDTH)
def process_path(path):
    """Load the document at ``path`` and build the matching UI updates.

    The returned 5-tuple lines up with the outputs wired to this handler:
    ``(document, image, img_clear_button, output, url_error)``.

    Args:
        path: Local path or URL of the document. Falsy means "nothing to load".

    Returns:
        On success: the path (stored as the document state), a visible
        gallery with the page image, a visible clear button, a hidden JSON
        output and a hidden error box.
        Otherwise: ``None`` plus hidden components. The error box is made
        visible with the error text when loading raised.
    """
    error = None
    if path:
        try:
            images_list = [doc_parser.read_image(path)]
        except Exception as e:  # any load failure is reported to the user
            traceback.print_exc()
            # str(e) can be empty; fall back to the exception class name.
            error = str(e) or type(e).__name__
        else:
            return (
                path,
                gr.update(visible=True, value=images_list),
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
            )
    # The error update must be the 5th element: previously it sat in a 6th
    # slot with no matching output component, so the message was never shown.
    if error is not None:
        url_error_update = gr.update(visible=True, value=error)
    else:
        url_error_update = gr.update(visible=False, value=None)
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        url_error_update,
    )
def process_upload(file):
    """Handle a change of the upload widget by delegating to ``process_path``.

    Args:
        file: The uploaded file as passed by ``gr.File``. This is either an
            object exposing the temp-file path as ``.name``, a plain path
            string, or falsy when the upload was cleared.

    Returns:
        The ``process_path`` result. A cleared upload is treated like an
        empty path, which hides the preview, output and error components.
    """
    if not file:
        # Reuse process_path's "nothing loaded" branch instead of keeping a
        # second hand-written tuple whose length did not match the outputs.
        return process_path(None)
    # Gradio 3 passes a tempfile wrapper with ``.name``; other versions may
    # pass the path string itself, which has no ``.name`` attribute.
    return process_path(getattr(file, "name", file))
def get_schema(schema_str):
    """Parse the UI schema syntax into a UIE schema list and detect its language.

    Syntax: items are separated by ``;``. An item ``A`` is an entity type.
    An item ``S|R1|R2`` is a subject ``S`` with relation types ``R1``, ``R2``.
    Relations of the same subject are merged into one ``{S: [...]}`` dict,
    placed where that subject first appeared. Whitespace around each part is
    stripped. Empty items (e.g. from a trailing ``;``) and empty relations
    are ignored. A subject whose relations are all empty becomes a plain
    entity.

    Args:
        schema_str: Schema text with half-width separators.

    Returns:
        ``(schema_list, schema_lang)``. ``schema_list`` mixes ``str`` entity
        types and ``{subject: [relations]}`` dicts. ``schema_lang`` is
        ``"ch"`` if the text contains any CJK Unified Ideograph
        (U+4E00..U+9FFF), otherwise ``"en"``.
    """
    def _is_ch(s):
        return any("\u4e00" <= ch <= "\u9fff" for ch in s)

    schema_lang = "ch" if _is_ch(schema_str) else "en"
    schema_list = []
    # subject -> the relation list stored inside its dict in schema_list
    relations_by_subject = {}
    for item in schema_str.split(";"):
        subject, *relations = [part.strip() for part in item.split("|")]
        relations = [r for r in relations if r]
        if not subject:
            continue
        if not relations:
            schema_list.append(subject)
        elif subject in relations_by_subject:
            relations_by_subject[subject].extend(relations)
        else:
            relations_by_subject[subject] = relations
            schema_list.append({subject: relations})
    return schema_list, schema_lang
# Guards the shared ``task_instance``: its schema and arguments are mutable
# state set right before each call. Gradio can run handlers concurrently
# (the app launches with enable_queue=False), so without this lock two
# overlapping requests could run with each other's schema.
_taskflow_lock = threading.Lock()


def run_taskflow(document, schema, argument):
    """Configure the shared UIE-X Taskflow and run it on one document.

    Args:
        document: Document path passed to the Taskflow under the ``doc`` key.
        schema: Parsed schema list, as returned by ``get_schema``.
        argument: Extra Taskflow settings (``ocr_lang``, ``schema_lang``,
            ``layout_analysis``) applied via ``set_argument``.

    Returns:
        The raw Taskflow output (indexed with ``[0]`` by the caller).
    """
    with _taskflow_lock:
        task_instance.set_schema(schema)
        task_instance.set_argument(argument)
        return task_instance({'doc': document})
def process_doc(document, schema, ocr_lang, layout_analysis):
    """Run UIE-X extraction on ``document`` and render the results.

    Args:
        document: Path of the loaded document (as produced by
            ``process_path`` or ``load_example_document``), or ``None``.
        schema: Raw schema text from the UI. ``;`` separates targets and
            ``|`` separates a subject from its relations. Empty or
            whitespace-only input falls back to the default schema
            "时间;组织机构;人物" (time; organization; person).
        ocr_lang: OCR language code forwarded to the Taskflow (the UI offers
            ``"ch"`` and ``"en"``).
        layout_analysis: A bool (from ``load_example_document``) or the
            ``"yes"``/``"no"`` string produced by the layout radio button.

    Returns:
        ``(None, None)`` when no document is loaded. Otherwise a pair of
        ``gr.update`` objects: the image gallery (annotated preview) and the
        JSON output (raw prediction).
    """
    # Whitespace-only text is truthy but names no extraction target.
    if not schema or not schema.strip():
        schema = '时间;组织机构;人物'
    if document is None:
        return None, None
    # The radio yields "yes"/"no". Every non-empty string is truthy, so "no"
    # would otherwise be forwarded as an enabled flag (assumes the Taskflow
    # expects a bool here, as the example loader already passes).
    if isinstance(layout_analysis, str):
        layout_analysis = layout_analysis.strip().lower() == "yes"
    # Normalize full-width punctuation (e.g. "；", "｜") before parsing.
    schema, schema_lang = get_schema(dbc2sbc(schema))
    argument = {
        "ocr_lang": ocr_lang,
        "schema_lang": schema_lang,
        "layout_analysis": layout_analysis,
    }
    # A single document is submitted; its result is the first entry.
    prediction = run_taskflow(document, schema, argument)[0]
    img_show = doc_parser.write_image_with_results(
        document,
        result=prediction,
        return_image=True,
    )
    return (
        gr.update(visible=True, value=[img_show]),
        gr.update(visible=True, value=prediction),
    )
def load_example_document(img, schema, ocr_lang, layout_analysis):
    """Load the bundled example selected in ``gr.Examples`` and run extraction.

    Triggered when the hidden example image changes. The ``ocr_lang`` and
    ``layout_analysis`` values coming from the UI are not used. Instead, the
    ``lang_map`` entry of the example file supplies the OCR language (the
    part before any "-"), and layout analysis is enabled only if that entry
    contains a "-".

    Returns:
        ``(document, schema, preview, clear_button_update, answer)``. All
        entries are ``None`` and the clear button is hidden when ``img`` is
        ``None``.
    """
    if img is None:
        return None, None, None, gr.update(visible=False), None

    document = example_files[schema]
    example_lang, *lang_suffix = lang_map[document].split("-")
    preview, answer = process_doc(document, schema, example_lang, bool(lang_suffix))
    return document, schema, preview, gr.update(visible=True), answer
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8."""
    with open(file_path, encoding="utf-8") as handle:
        return handle.read()
# Custom stylesheet passed to gr.Blocks(css=CSS) below. Rules starting with
# "#" target components by elem_id (e.g. url-textbox, short-upload-box,
# select-a-file, file-clear, submit-button, url-error).
# NOTE(review): no component in this file sets elem_id "prompt" or "answer",
# so the #prompt / #answer rules appear unused here. Confirm before removing.
CSS = """
#prompt input {
font-size: 16px;
}
#url-textbox {
padding: 0 !important;
}
#short-upload-box .w-full {
min-height: 10rem !important;
}
/* I think something like this can be used to re-shape
* the table
*/
/*
.gr-samples-table tr {
display: inline;
}
.gr-samples-table .p-2 {
width: 100px;
}
*/
#select-a-file {
width: 100%;
}
#file-clear {
padding-top: 2px !important;
padding-bottom: 2px !important;
padding-left: 8px !important;
padding-right: 8px !important;
margin-top: 10px;
}
.gradio-container .gr-button-primary {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700;
}
.gradio-container.dark button#submit-button {
background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
border: 1px solid #B0DCCC;
border-radius: 8px;
color: #1B8700
}
table.gr-samples-table tr td {
border: none;
outline: none;
}
table.gr-samples-table tr td:first-of-type {
width: 0%;
}
div#short-upload-box div.absolute {
display: none !important;
}
gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
gap: 0px 2%;
}
gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
gap: 0px;
}
gradio-app h2, .gradio-app h2 {
padding-top: 10px;
}
#answer {
overflow-y: scroll;
color: white;
background: #666;
border-color: #666;
font-size: 20px;
font-weight: bold;
}
#answer span {
color: white;
}
#answer textarea {
color:white;
background: #777;
border-color: #777;
font-size: 18px;
}
#url-error input {
color: red;
}
"""
# UI layout and event wiring. Left column: choose a document (URL, upload
# or example). Middle column: schema and options. The JSON output starts
# hidden.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(read_content("header.html"))
    # The adjacent literals are concatenated, so each sentence ends with an
    # explicit space. The bold closer must not be preceded by whitespace.
    gr.Markdown(
        "Open-sourced by PaddleNLP, **UIE-X 🧾** is a universal information extraction engine for both scanned documents and text inputs. "
        "It supports Entity Extraction, Relation Extraction and Event Extraction tasks. "
        "UIE-X performs well in zero-shot settings, enabled by a flexible schema that lets you specify extraction targets in simple natural language. "
        "Moreover, PaddleNLP provides a comprehensive and easy-to-use fine-tuning and few-shot customization workflow. "
        "For more details, please visit our [GitHub](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/information_extraction)."
    )
    # Session state holding the path of the currently loaded document.
    # gr.State replaces the deprecated gr.Variable alias.
    document = gr.State()
    # Hidden components that receive the values of a clicked example row.
    # example_image.change (wired below) then loads that example.
    example_schema = gr.Textbox(visible=False)
    example_image = gr.Image(visible=False)
    with gr.Row(equal_height=True):
        with gr.Column():
            with gr.Row():
                gr.Markdown("## 1. Select a file 选择文件", elem_id="select-a-file")
                img_clear_button = gr.Button(
                    "Clear", variant="secondary", elem_id="file-clear", visible=False
                )
            image = gr.Gallery(visible=False)
            with gr.Row(equal_height=True):
                with gr.Column():
                    with gr.Row():
                        url = gr.Textbox(
                            show_label=False,
                            placeholder="URL",
                            lines=1,
                            max_lines=1,
                            elem_id="url-textbox",
                        )
                        submit = gr.Button("Get")
                    url_error = gr.Textbox(
                        visible=False,
                        elem_id="url-error",
                        max_lines=1,
                        interactive=False,
                        label="Error",
                    )
            gr.Markdown("— or —")
            upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
            gr.Examples(
                examples=examples,
                inputs=[example_image, example_schema],
            )
        with gr.Column():
            gr.Markdown("## 2. Information Extraction 信息抽取 ")
            gr.Markdown("### 👉 Set a schema 设置schema")
            gr.Markdown("Entity extraction: entity type should be separated by ';', e.g. **Person;Organization**")
            gr.Markdown("实体抽取:实体类别之间以';'分割,例如 **人物;组织机构**")
            gr.Markdown("Relation extraction: set the subject and relation type, separated by '|', e.g. **Person|Date;Person|Email**")
            gr.Markdown("关系抽取:需配置主体和关系类别,中间以'|'分割,例如 **人物|出生时间;人物|邮箱**")
            gr.Markdown("### 👉 Model customization 模型定制")
            gr.Markdown("We recommend further improving extraction performance in specific domains through [data annotation & fine-tuning](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/information_extraction/document).")
            gr.Markdown("我们建议通过[数据标注+微调](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/information_extraction/document)的流程进一步增强模型在特定场景的效果")
            schema = gr.Textbox(
                label="Schema",
                placeholder="e.g. Name|Company;Name|Position;Email;Phone Number",
                lines=1,
                max_lines=1,
            )
            ocr_lang = gr.Radio(
                choices=["ch", "en"],
                value="en",
                label="OCR语言 / OCR Language (Please choose ch for Chinese images.)",
            )
            # Yields the strings "yes"/"no"; process_doc receives the string.
            layout_analysis = gr.Radio(
                choices=["yes", "no"],
                value="no",
                label="版面分析 / Layout analysis (Better extraction for multi-line text)",
            )
            with gr.Row():
                clear_button = gr.Button("Clear", variant="secondary")
                submit_button = gr.Button(
                    "Submit", variant="primary", elem_id="submit-button"
                )
            with gr.Column():
                output = gr.JSON(label="Output", visible=False)

    # Both clear buttons reset every listed component to its empty/hidden
    # state. The lambda's single argument (the clear_button input) is ignored.
    for cb in [img_clear_button, clear_button]:
        cb.click(
            lambda _: (
                gr.update(visible=False, value=None),
                None,
                gr.update(visible=False, value=None),
                gr.update(visible=False),
                None,
                None,
                None,
                gr.update(visible=False, value=None),
                None,
            ),
            inputs=clear_button,
            outputs=[
                image,
                document,
                output,
                img_clear_button,
                example_image,
                upload,
                url,
                url_error,
                schema,
            ],
        )

    # Loading a document: by upload or by URL.
    upload.change(
        fn=process_upload,
        inputs=[upload],
        outputs=[document, image, img_clear_button, output, url_error],
    )
    submit.click(
        fn=process_path,
        inputs=[url],
        outputs=[document, image, img_clear_button, output, url_error],
    )

    # Running extraction: Enter in the schema box or the Submit button.
    schema.submit(
        fn=process_doc,
        inputs=[document, schema, ocr_lang, layout_analysis],
        outputs=[image, output],
    )
    submit_button.click(
        fn=process_doc,
        inputs=[document, schema, ocr_lang, layout_analysis],
        outputs=[image, output],
    )

    # Selecting an example row loads that file and runs extraction right away.
    example_image.change(
        fn=load_example_document,
        inputs=[example_image, example_schema, ocr_lang, layout_analysis],
        outputs=[document, schema, image, img_clear_button, output],
    )
    gr.HTML(read_content("footer.html"))
if __name__ == "__main__":
    # Start the Gradio server; enable_queue=False turns off Gradio's request
    # queue for this app.
    launch_options = {"enable_queue": False}
    demo.launch(**launch_options)