# Joker1212's picture
# Upload 4 files
# e1b00d2 verified
import gradio as gr
import os
import cv2
from rapid_table_det.inference import TableDetector
from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
# Sample images offered by the gr.Examples picker in the UI: five "doc*"
# files followed by five "real*" files, all under the local images/ folder.
example_images = [
    f"images/{file_name}"
    for file_name in (
        "doc1.png",
        "doc2.jpg",
        "doc3.jpg",
        "doc4.jpg",
        "doc5.jpg",
        "real1.jpg",
        "real2.jpeg",
        "real3.jpg",
        "real4.jpg",
        "real5.jpg",
    )
]
# Model-type options, keyed by a human-readable (Chinese) label. Each label maps
# to a one-element list of model-type names so that several entries can be
# concatenated with `+` to form a family of interchangeable models.
model_type_options = {
    label: [model_type]
    for label, model_type in (
        ("YOLO 目标检测", "yolo_obj_det"),  # YOLO object detection
        ("Paddle 目标检测", "paddle_obj_det"),  # Paddle object detection
        ("Paddle 目标检测 (量化)", "paddle_obj_det_s"),  # Paddle object detection (quantized)
        ("YOLO 语义分割", "yolo_edge_det"),  # YOLO semantic segmentation
        ("YOLO 语义分割 (小型)", "yolo_edge_det_s"),  # YOLO semantic segmentation (small)
        ("Paddle 语义分割", "paddle_edge_det"),  # Paddle semantic segmentation
        ("Paddle 语义分割 (量化)", "paddle_edge_det_s"),  # Paddle semantic segmentation (quantized)
        ("Paddle 方向分类", "paddle_cls_det"),  # Paddle direction classification
    )
}
# Pre-build one TableDetector for every combination of object-detection,
# edge-segmentation and direction-classification model, keyed by the
# (obj_model_type, edge_model_type, cls_model_type) triple that run_inference
# uses for lookup. Each model is loaded from models/<model_type>.onnx.
# NOTE(review): every instance gets its own model paths, so the same .onnx file
# may be loaded once per combination — confirm whether TableDetector shares
# sessions before relying on memory usage here.
preinitialized_detectors = {
    (obj_type, edge_type, cls_type): TableDetector(
        obj_model_type=obj_type,
        edge_model_type=edge_type,
        cls_model_type=cls_type,
        obj_model_path=os.path.join("models", f"{obj_type}.onnx"),
        edge_model_path=os.path.join("models", f"{edge_type}.onnx"),
        cls_model_path=os.path.join("models", f"{cls_type}.onnx"),
    )
    for obj_type in (
        model_type_options["YOLO 目标检测"]
        + model_type_options["Paddle 目标检测"]
        + model_type_options["Paddle 目标检测 (量化)"]
    )
    for edge_type in (
        model_type_options["YOLO 语义分割"]
        + model_type_options["YOLO 语义分割 (小型)"]
        + model_type_options["Paddle 语义分割"]
        + model_type_options["Paddle 语义分割 (量化)"]
    )
    for cls_type in model_type_options["Paddle 方向分类"]
}
# Image downscaling helper used for the images returned to the UI.
def resize_image(image, max_size=640):
    """Downscale ``image`` so that its longer side is at most ``max_size`` px.

    The aspect ratio is preserved. An image that already fits is returned
    unchanged (the very same object, not a copy).

    Args:
        image: Array whose ``shape`` starts with (height, width), as accepted
            by ``cv2.resize``.
        max_size: Upper bound, in pixels, for the longer side.

    Returns:
        ``image`` itself if it already fits, otherwise a resized copy.
    """
    height, width = image.shape[:2]
    longest = max(height, width)
    if longest <= max_size:
        return image
    scale = max_size / longest
    # Clamp to 1 px: for very elongated images int() would otherwise truncate
    # the short side to 0, and cv2.resize rejects an empty target size.
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))
    # INTER_AREA is the interpolation OpenCV recommends for shrinking.
    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
# Inference callback wired to the "run" button.
def run_inference(img_path, obj_model_type, edge_model_type, cls_model_type, det_accuracy, use_obj_det, use_edge_det,
                  use_cls_det):
    """Detect tables in an image and return the visualisation, crops and timings.

    Args:
        img_path: Value of the ``img_input`` component. It is passed unchanged
            to the detector and to ``img_loader``.
            NOTE(review): despite the name, gr.Image is created without a
            ``type``, so this is presumably a numpy array rather than a path —
            confirm against the installed Gradio version.
        obj_model_type: Object-detection model type (first element of the
            ``preinitialized_detectors`` key).
        edge_model_type: Edge-segmentation model type (second key element).
        cls_model_type: Direction-classification model type (third key element).
        det_accuracy: Confidence threshold forwarded to the detector.
        use_obj_det: Forwarded to the detector to enable/skip object detection.
        use_edge_det: Forwarded to the detector to enable/skip edge segmentation.
        use_cls_det: Forwarded to the detector to enable/skip direction
            classification.

    Returns:
        ``(visual_img, extract_imgs, time_info)``:
        - the image with every detection drawn on it;
        - a list with one perspective-corrected crop per detected table;
        - a text summary of the three elapsed times reported by the detector.
        Both image outputs are passed through ``resize_image``.

    Raises:
        gr.Error: If no image was supplied (e.g. the input was cleared), so the
            UI shows a readable message instead of a detector traceback.
    """
    if img_path is None:
        raise gr.Error("Please upload or select an image first.")
    detector_key = (obj_model_type, edge_model_type, cls_model_type)
    table_det = preinitialized_detectors[detector_key]
    result, elapse = table_det(
        img_path,
        det_accuracy=det_accuracy,
        use_obj_det=use_obj_det,
        use_edge_det=use_edge_det,
        use_cls_det=use_cls_det
    )
    # Load the image again for drawing and cropping. The BGR -> RGB conversion
    # presumably exists because img_loader follows OpenCV's BGR channel order.
    img = cv2.cvtColor(img_loader(img_path), cv2.COLOR_BGR2RGB)
    visual_img = img.copy()
    extract_imgs = []
    for res in result:
        lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
        # Draw the detection box and the top-left corner/orientation marker.
        visual_img = visuallize(visual_img, res["box"], lt, rt, rb, lb)
        # Perspective-warp the quadrilateral into a table crop; each call gets a
        # fresh copy of the clean image in case extract_table_img mutates it.
        extract_imgs.append(extract_table_img(img.copy(), lt, rt, rb, lb))
    # Downscale outputs before sending them to the browser.
    visual_img = resize_image(visual_img)
    extract_imgs = [resize_image(crop) for crop in extract_imgs]
    obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
    return visual_img, extract_imgs, f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
def update_extract_outputs(visual_img, extract_imgs, time_info):
    """Collapse a one-element list of extracted images to that single element.

    Args:
        visual_img: Passed through unchanged.
        extract_imgs: Sized sequence of extracted images.
        time_info: Passed through unchanged.

    Returns:
        ``(visual_img, extracted, time_info)``. ``extracted`` is the lone item
        when ``extract_imgs`` has exactly one element. Otherwise (including when
        it is empty) it is ``extract_imgs`` unchanged.

    NOTE(review): not referenced by the run_button wiring in this file.
    """
    extracted = extract_imgs[0] if len(extract_imgs) == 1 else extract_imgs
    return visual_img, extracted, time_info
# Build the Gradio interface.
# - Left panel: input image, example picker, model selectors, detection options.
# - Right panel: the outputs.
# The CSS centres the badge row (.header-links). Its inline comment "调整间距"
# means "adjust spacing".
# NOTE(review): .scrollable-container is not used by any component in this block.
with gr.Blocks(
    css="""
.scrollable-container {
overflow-x: auto;
white-space: nowrap;
}
.header-links {
text-align: center;
}
.header-links a {
display: inline-block;
text-align: center;
margin-right: 10px; /* 调整间距 */
}
"""
) as demo:
    # Page title linking to the project repository.
    gr.HTML(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidTableDetection'>RapidTableDetection</a></h1>"
    )
    # Badge row (Python versions, OS, SemVer, code style, license), laid out by
    # the .header-links CSS rules.
    # NOTE(review): the license badge links to the TableStructureRec
    # repository's LICENSE rather than this project's — confirm intended.
    gr.HTML('''
<div class="header-links">
<a href=""><img src="https://img.shields.io/badge/Python->=3.8,<3.12-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Mac%2C%20Win-pink.svg"></a>
<a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
<a href="https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE"><img alt="GitHub" src="https://img.shields.io/badge/license-Apache 2.0-blue"></a>
</div>
''')
    with gr.Row():
        # Left panel: input image, examples, model selection and options.
        with gr.Column(variant="panel", scale=1):
            # Input image, preloaded with one of the bundled examples.
            # NOTE(review): no `type` is given, so Gradio's default (presumably
            # "numpy") decides what run_inference receives as img_path — confirm.
            img_input = gr.Image(label="Upload or Select Image", sources="upload", value="images/real1.jpg")
            # Example picker: choosing an entry loads it into img_input.
            examples = gr.Examples(
                examples=example_images,
                examples_per_page=len(example_images),  # show every example on a single page
                inputs=img_input,
                fn=lambda x: x,  # identity: simply passes the chosen image path through
                outputs=img_input,
                cache_examples=False
            )
            # Model selectors. Their choices are the same model_type_options
            # concatenations used to build preinitialized_detectors, so every
            # selectable (obj, edge, cls) triple has a prebuilt detector.
            obj_model_type = gr.Dropdown(
                choices=model_type_options["YOLO 目标检测"] + model_type_options["Paddle 目标检测"] +
                        model_type_options["Paddle 目标检测 (量化)"],
                value="yolo_obj_det",
                label="obj det model")
            edge_model_type = gr.Dropdown(
                choices=model_type_options["YOLO 语义分割"] + model_type_options["YOLO 语义分割 (小型)"] +
                        model_type_options["Paddle 语义分割"] + model_type_options["Paddle 语义分割 (量化)"],
                value="yolo_edge_det",
                label="edge seg model")
            cls_model_type = gr.Dropdown(choices=model_type_options["Paddle 方向分类"],
                                         value="paddle_cls_det",
                                         label="direction cls model")
            # Threshold passed to run_inference as det_accuracy (the label reads
            # "object-detection confidence threshold").
            det_accuracy = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="目标检测置信度阈值")
            # Per-stage toggles passed to run_inference.
            use_obj_det = gr.Checkbox(value=True, label="use obj det")
            use_edge_det = gr.Checkbox(value=True, label="use edge seg")
            use_cls_det = gr.Checkbox(value=True, label="use direction cls")
            run_button = gr.Button("run")
        # Right panel: annotated image, gallery of extracted tables, timings.
        with gr.Column(scale=2):
            visual_output = gr.Image(label="output visualize")
            extract_outputs = gr.Gallery(label="extracted images", object_fit="contain", columns=1, preview=True)
            time_output = gr.Textbox(label="elapsed")
    # Inputs follow run_inference's positional-parameter order. Its three
    # return values fill the three output components in order.
    run_button.click(
        fn=run_inference,
        inputs=[img_input, obj_model_type, edge_model_type, cls_model_type, det_accuracy, use_obj_det, use_edge_det,
                use_cls_det],
        outputs=[visual_output, extract_outputs, time_output]
    )
# Launch the Gradio app. This runs at module level (there is no __main__ guard).
demo.launch()