Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| import cv2 | |
| from rapid_table_det.inference import TableDetector | |
| from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img | |
# Relative paths of the demo images shown in the example picker; all live in
# the local "images/" directory.
example_images = [
    f"images/{file_name}"
    for file_name in (
        "doc1.png",
        "doc2.jpg",
        "doc3.jpg",
        "doc4.jpg",
        "doc5.jpg",
        "real1.jpg",
        "real2.jpeg",
        "real3.jpg",
        "real4.jpg",
        "real5.jpg",
    )
]
# Model-type options: display label -> one-element list of model identifiers.
# The identifiers are also the file stems of the ONNX weights loaded from
# "models/<identifier>.onnx" when the detectors are built.
model_type_options = {
    label: [model_id]
    for label, model_id in (
        ("YOLO 目标检测", "yolo_obj_det"),
        ("Paddle 目标检测", "paddle_obj_det"),
        ("Paddle 目标检测 (量化)", "paddle_obj_det_s"),
        ("YOLO 语义分割", "yolo_edge_det"),
        ("YOLO 语义分割 (小型)", "yolo_edge_det_s"),
        ("Paddle 语义分割", "paddle_edge_det"),
        ("Paddle 语义分割 (量化)", "paddle_edge_det_s"),
        ("Paddle 方向分类", "paddle_cls_det"),
    )
}
# Pre-build one TableDetector for every (object-detection, edge-segmentation,
# direction-classification) combination the dropdowns can produce, keyed by
# that triple, so serving a request is just a dictionary lookup. Each model's
# weights are read from "models/<model_type>.onnx".
# NOTE(review): with the current option lists this constructs 3 x 4 x 1 = 12
# detectors, each handed its own three model paths; whether instances share
# loaded weights depends on TableDetector — check memory use if options grow.
preinitialized_detectors = {
    (obj_type, edge_type, cls_type): TableDetector(
        obj_model_type=obj_type,
        edge_model_type=edge_type,
        cls_model_type=cls_type,
        obj_model_path=os.path.join("models", f"{obj_type}.onnx"),
        edge_model_path=os.path.join("models", f"{edge_type}.onnx"),
        cls_model_path=os.path.join("models", f"{cls_type}.onnx"),
    )
    for obj_type in (
        model_type_options["YOLO 目标检测"]
        + model_type_options["Paddle 目标检测"]
        + model_type_options["Paddle 目标检测 (量化)"]
    )
    for edge_type in (
        model_type_options["YOLO 语义分割"]
        + model_type_options["YOLO 语义分割 (小型)"]
        + model_type_options["Paddle 语义分割"]
        + model_type_options["Paddle 语义分割 (量化)"]
    )
    for cls_type in model_type_options["Paddle 方向分类"]
}
# Image down-scaling helper used for display output.
def resize_image(image, max_size=640):
    """Downscale *image* so that its longer side is at most *max_size* pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (the very same object, not a copy); larger ones are shrunk
    with ``cv2.INTER_AREA`` interpolation.

    Args:
        image: Array whose first two dimensions are (height, width), as read
            via ``image.shape[:2]``.
        max_size: Upper bound, in pixels, for the longer side.

    Returns:
        ``image`` itself if it already fits, otherwise a resized copy.
    """
    height, width = image.shape[:2]
    longest = max(height, width)
    if longest <= max_size:
        return image
    scale = max_size / longest
    # round() avoids float error shaving a pixel off the long side, and the
    # max(1, ...) clamp keeps very elongated images from getting a 0-pixel
    # short side, which cv2.resize rejects.
    new_width = max(1, round(width * scale))
    new_height = max(1, round(height * scale))
    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
# Inference callback wired to the "run" button.
def run_inference(img_path, obj_model_type, edge_model_type, cls_model_type, det_accuracy, use_obj_det,
                  use_edge_det, use_cls_det):
    """Detect tables in one image; return a visualization, the crops and timings.

    Args:
        img_path: Image given to the selected detector and to ``img_loader``.
            Passing ``None`` (no image uploaded/selected) raises ``gr.Error``.
        obj_model_type: Object-detection model identifier.
        edge_model_type: Edge-segmentation model identifier.
        cls_model_type: Direction-classification model identifier. Together
            with the two above it selects a detector from
            ``preinitialized_detectors``.
        det_accuracy: Confidence threshold forwarded to the detector.
        use_obj_det: Whether the detector runs its object-detection stage.
        use_edge_det: Whether the detector runs its edge-segmentation stage.
        use_cls_det: Whether the detector runs its direction-classification stage.

    Returns:
        ``(visual_img, extract_imgs, timing_text)`` where ``visual_img`` is the
        annotated image, ``extract_imgs`` holds one perspective-corrected crop
        per detected table (both passed through ``resize_image``), and
        ``timing_text`` reports the three stage timings returned by the detector.

    Raises:
        gr.Error: If no image was provided.
    """
    if img_path is None:
        # Surface a readable message in the UI instead of a crash deep inside the detector.
        raise gr.Error("Please upload or select an image first.")

    table_det = preinitialized_detectors[(obj_model_type, edge_model_type, cls_model_type)]
    result, elapse = table_det(
        img_path,
        det_accuracy=det_accuracy,
        use_obj_det=use_obj_det,
        use_edge_det=use_edge_det,
        use_cls_det=use_cls_det,
    )

    # NOTE(review): the BGR->RGB conversion assumes img_loader returns OpenCV
    # (BGR) channel order — confirm against rapid_table_det.utils.visuallize.
    img = cv2.cvtColor(img_loader(img_path), cv2.COLOR_BGR2RGB)
    visual_img = img.copy()
    extract_imgs = []
    for res in result:
        lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
        # Draw the detection box and corner markers on the visualization image.
        visual_img = visuallize(visual_img, res["box"], lt, rt, rb, lb)
        # Perspective-warp the table out of a fresh copy of the un-annotated
        # image, so drawn boxes never leak into the crops.
        extract_imgs.append(extract_table_img(img.copy(), lt, rt, rb, lb))

    # Shrink outputs for display.
    visual_img = resize_image(visual_img)
    extract_imgs = [resize_image(crop) for crop in extract_imgs]

    obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
    return (
        visual_img,
        extract_imgs,
        f"obj_det_elapse={obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}",
    )
def update_extract_outputs(visual_img, extract_imgs, time_info):
    """Unwrap a lone extracted image; pass zero or several through as-is.

    Returns ``(visual_img, extracted, time_info)`` where ``extracted`` is the
    single element of ``extract_imgs`` when it holds exactly one item, and
    ``extract_imgs`` itself otherwise.

    NOTE(review): not referenced by any visible event handler in this file.
    """
    extracted = extract_imgs[0] if len(extract_imgs) == 1 else extract_imgs
    return visual_img, extracted, time_info
# Build the Gradio interface: image input plus model/stage options on the
# left; visualization, extracted tables and timings on the right.
with gr.Blocks(
    css="""
    .scrollable-container {
        overflow-x: auto;
        white-space: nowrap;
    }
    .header-links {
        text-align: center;
    }
    .header-links a {
        display: inline-block;
        text-align: center;
        margin-right: 10px; /* 调整间距 */
    }
    """
) as demo:
    gr.HTML(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidTableDetection'>RapidTableDetection</a></h1>"
    )
    gr.HTML('''
    <div class="header-links">
    <a href=""><img src="https://img.shields.io/badge/Python->=3.8,<3.12-aff.svg"></a>
    <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Mac%2C%20Win-pink.svg"></a>
    <a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
    <a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
    <a href="https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE"><img alt="GitHub" src="https://img.shields.io/badge/license-Apache 2.0-blue"></a>
    </div>
    ''')
    with gr.Row():
        with gr.Column(variant="panel", scale=1):
            # type="filepath": run_inference takes `img_path`, re-loads it with
            # img_loader and converts BGR->RGB, i.e. it is written for a file on
            # disk. Gradio's default (type="numpy") would hand over an RGB array.
            img_input = gr.Image(
                label="Upload or Select Image",
                sources="upload",
                type="filepath",
                value="images/real1.jpg",
            )
            # Example-image picker feeding img_input.
            examples = gr.Examples(
                examples=example_images,
                examples_per_page=len(example_images),
                inputs=img_input,
                fn=lambda x: x,  # identity: simply hands the image path back
                outputs=img_input,
                cache_examples=False,
            )
            obj_model_type = gr.Dropdown(
                choices=model_type_options["YOLO 目标检测"]
                + model_type_options["Paddle 目标检测"]
                + model_type_options["Paddle 目标检测 (量化)"],
                value="yolo_obj_det",
                label="obj det model",
            )
            edge_model_type = gr.Dropdown(
                choices=model_type_options["YOLO 语义分割"]
                + model_type_options["YOLO 语义分割 (小型)"]
                + model_type_options["Paddle 语义分割"]
                + model_type_options["Paddle 语义分割 (量化)"],
                value="yolo_edge_det",
                label="edge seg model",
            )
            cls_model_type = gr.Dropdown(
                choices=model_type_options["Paddle 方向分类"],
                value="paddle_cls_det",
                label="direction cls model",
            )
            # Object-detection confidence threshold.
            det_accuracy = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="目标检测置信度阈值")
            use_obj_det = gr.Checkbox(value=True, label="use obj det")
            use_edge_det = gr.Checkbox(value=True, label="use edge seg")
            use_cls_det = gr.Checkbox(value=True, label="use direction cls")
            run_button = gr.Button("run")
        with gr.Column(scale=2):
            visual_output = gr.Image(label="output visualize")
            extract_outputs = gr.Gallery(label="extracted images", object_fit="contain", columns=1, preview=True)
            time_output = gr.Textbox(label="elapsed")
    run_button.click(
        fn=run_inference,
        inputs=[img_input, obj_model_type, edge_model_type, cls_model_type, det_accuracy, use_obj_det,
                use_edge_det, use_cls_det],
        outputs=[visual_output, extract_outputs, time_output],
    )
# Launch the Gradio app only when this file is run as a script, so importing
# the module (e.g. from tests or another app) does not start a server.
if __name__ == "__main__":
    demo.launch()