"""Gradio demo for RapidTableDetection.

Lets a user upload (or pick) an image, choose the object-detection,
edge-segmentation and direction-classification models, and view the detected
tables drawn on the image together with the perspective-corrected crops.
"""
import itertools
import os

import cv2
import gradio as gr

from rapid_table_det.inference import TableDetector
from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img

# Bundled sample images offered in the Examples picker.
example_images = [
    "images/doc1.png",
    "images/doc2.jpg",
    "images/doc3.jpg",
    "images/doc4.jpg",
    "images/doc5.jpg",
    "images/real1.jpg",
    "images/real2.jpeg",
    "images/real3.jpg",
    "images/real4.jpg",
    "images/real5.jpg",
]

# Model type options: display name -> list of model identifiers.
# Each identifier is also the stem of its ONNX file under ``models/``.
model_type_options = {
    "YOLO 目标检测": ["yolo_obj_det"],
    "Paddle 目标检测": ["paddle_obj_det"],
    "Paddle 目标检测 (量化)": ["paddle_obj_det_s"],
    "YOLO 语义分割": ["yolo_edge_det"],
    "YOLO 语义分割 (小型)": ["yolo_edge_det_s"],
    "Paddle 语义分割": ["paddle_edge_det"],
    "Paddle 语义分割 (量化)": ["paddle_edge_det_s"],
    "Paddle 方向分类": ["paddle_cls_det"],
}

# Model choices per pipeline stage. They are shared by the detector pre-build
# loop and the UI dropdowns, so the two lists cannot drift apart.
OBJ_MODEL_TYPES = (
    model_type_options["YOLO 目标检测"]
    + model_type_options["Paddle 目标检测"]
    + model_type_options["Paddle 目标检测 (量化)"]
)
EDGE_MODEL_TYPES = (
    model_type_options["YOLO 语义分割"]
    + model_type_options["YOLO 语义分割 (小型)"]
    + model_type_options["Paddle 语义分割"]
    + model_type_options["Paddle 语义分割 (量化)"]
)
CLS_MODEL_TYPES = model_type_options["Paddle 方向分类"]


def _model_path(model_type):
    """Return the ONNX file path for *model_type* inside the ``models`` directory."""
    return os.path.join("models", f"{model_type}.onnx")


# Pre-build a TableDetector for every (obj, edge, cls) model combination.
# run_inference only looks detectors up by key and never constructs them.
# NOTE(review): if TableDetector loads its ONNX models in __init__, each model
# file is loaded once per combination it appears in. A lazily-filled cache
# would cut startup time and memory; confirm before changing.
preinitialized_detectors = {}
for obj_model_type, edge_model_type, cls_model_type in itertools.product(
    OBJ_MODEL_TYPES, EDGE_MODEL_TYPES, CLS_MODEL_TYPES
):
    preinitialized_detectors[(obj_model_type, edge_model_type, cls_model_type)] = TableDetector(
        obj_model_type=obj_model_type,
        edge_model_type=edge_model_type,
        cls_model_type=cls_model_type,
        obj_model_path=_model_path(obj_model_type),
        edge_model_path=_model_path(edge_model_type),
        cls_model_path=_model_path(cls_model_type),
    )


def resize_image(image, max_size=640):
    """Downscale *image* so its longer side is at most *max_size* pixels.

    The aspect ratio is preserved, and images already within the limit are
    returned unchanged.

    Args:
        image: Array indexed as ``image.shape[:2] == (height, width)``.
        max_size: Maximum allowed length of the longer side, in pixels.

    Returns:
        The original array, or a resized copy produced by ``cv2.resize``.
    """
    height, width = image.shape[:2]
    longest = max(height, width)
    if longest > max_size:
        scale = max_size / longest
        new_height = int(height * scale)
        new_width = int(width * scale)
        # INTER_AREA is OpenCV's recommended interpolation for shrinking.
        image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
    return image


def run_inference(img_path, obj_model_type, edge_model_type, cls_model_type, det_accuracy, use_obj_det,
                  use_edge_det, use_cls_det):
    """Detect tables in one image and build the outputs shown in the UI.

    Args:
        img_path: Path of the image to process, or ``None`` if none is set.
        obj_model_type: Object-detection model identifier.
        edge_model_type: Edge-segmentation model identifier.
        cls_model_type: Direction-classification model identifier.
        det_accuracy: Object-detection confidence threshold, forwarded to the detector.
        use_obj_det: Whether to run the object-detection stage.
        use_edge_det: Whether to run the edge-segmentation stage.
        use_cls_det: Whether to run the direction-classification stage.

    Returns:
        A tuple with three items:

        * the image with detection boxes drawn on it;
        * a list of perspective-corrected table crops;
        * a text summary of the three stage timings.

        All images are downscaled via ``resize_image``.

    Raises:
        gr.Error: if no image was provided, for example after the user clears
            the input.
    """
    if img_path is None:
        raise gr.Error("Please upload or select an image first.")

    table_det = preinitialized_detectors[(obj_model_type, edge_model_type, cls_model_type)]
    result, elapse = table_det(
        img_path,
        det_accuracy=det_accuracy,
        use_obj_det=use_obj_det,
        use_edge_det=use_edge_det,
        use_cls_det=use_cls_det,
    )

    # Load the image and convert to RGB for display. Assumes img_loader
    # returns BGR (OpenCV order) — TODO confirm against rapid_table_det.
    img = cv2.cvtColor(img_loader(img_path), cv2.COLOR_BGR2RGB)
    visual_img = img.copy()
    extract_imgs = []
    for res in result:
        lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
        # Draw the detection box and the top-left corner orientation marker.
        visual_img = visuallize(visual_img, res["box"], lt, rt, rb, lb)
        # Crop the table with a perspective transform. A fresh copy of the
        # untouched image is passed each time, so crops never contain the drawn
        # markers or edits from earlier crops.
        extract_imgs.append(extract_table_img(img.copy(), lt, rt, rb, lb))

    # Downscale everything for display.
    visual_img = resize_image(visual_img)
    extract_imgs = [resize_image(crop) for crop in extract_imgs]

    obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
    return visual_img, extract_imgs, f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"


def update_extract_outputs(visual_img, extract_imgs, time_info):
    """Unwrap a single-element crop list; otherwise pass the inputs through.

    Not wired to any event in this file.
    """
    if len(extract_imgs) == 1:
        return visual_img, extract_imgs[0], time_info
    return visual_img, extract_imgs, time_info


# Build the Gradio interface.
with gr.Blocks(
    css="""
    .scrollable-container {
        overflow-x: auto;
        white-space: nowrap;
    }
    .header-links {
        text-align: center;
    }
    .header-links a {
        display: inline-block;
        text-align: center;
        margin-right: 10px; /* 调整间距 */
    }
    """
) as demo:
    # NOTE(review): the markup of this header and of the header-links HTML
    # below appears to have been stripped from this copy; only the text is
    # left. Restore from the upstream repository if available.
    gr.HTML("\n\nRapidTableDetection\n\n")
    gr.HTML(''' ''')
    with gr.Row():
        with gr.Column(variant="panel", scale=1):
            # type="filepath" makes run_inference receive a path, which matches
            # its `img_path` parameter and the path-based examples. The default
            # "numpy" type would pass an RGB array that run_inference then
            # treats as BGR.
            img_input = gr.Image(label="Upload or Select Image", sources="upload", type="filepath",
                                 value="images/real1.jpg")

            # Example image picker.
            examples = gr.Examples(
                examples=example_images,
                examples_per_page=len(example_images),
                inputs=img_input,
                fn=lambda x: x,  # just return the image path
                outputs=img_input,
                cache_examples=False,
            )

            obj_model_type = gr.Dropdown(
                choices=OBJ_MODEL_TYPES,
                value="yolo_obj_det",
                label="obj det model",
            )
            edge_model_type = gr.Dropdown(
                choices=EDGE_MODEL_TYPES,
                value="yolo_edge_det",
                label="edge seg model",
            )
            cls_model_type = gr.Dropdown(
                choices=CLS_MODEL_TYPES,
                value="paddle_cls_det",
                label="direction cls model",
            )

            det_accuracy = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="目标检测置信度阈值")
            use_obj_det = gr.Checkbox(value=True, label="use obj det")
            use_edge_det = gr.Checkbox(value=True, label="use edge seg")
            use_cls_det = gr.Checkbox(value=True, label="use direction cls")

            run_button = gr.Button("run")

        with gr.Column(scale=2):
            visual_output = gr.Image(label="output visualize")
            extract_outputs = gr.Gallery(label="extracted images", object_fit="contain", columns=1, preview=True)
            time_output = gr.Textbox(label="elapsed")

    run_button.click(
        fn=run_inference,
        inputs=[img_input, obj_model_type, edge_model_type, cls_model_type, det_accuracy, use_obj_det,
                use_edge_det, use_cls_det],
        outputs=[visual_output, extract_outputs, time_output],
    )

# Launch the Gradio app only when run as a script, so that importing this
# module (e.g. `gradio app.py` reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()