File size: 7,291 Bytes
8570e66
3b5e13f
 
 
 
 
 
 
 
 
 
 
8570e66
3b5e13f
 
 
9e87980
3b5e13f
 
 
9e87980
3b5e13f
 
 
8570e66
 
3b5e13f
 
 
 
 
 
 
 
 
 
 
2a39c5a
 
 
 
 
 
3b5e13f
 
8570e66
3b5e13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8570e66
9e87980
8570e66
 
3b5e13f
9e87980
3b5e13f
 
9e87980
3b5e13f
9e87980
3b5e13f
 
 
 
 
 
9e87980
 
3b5e13f
 
 
 
 
9e87980
3b5e13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e87980
3b5e13f
 
 
 
9e87980
3b5e13f
 
 
 
 
 
 
 
 
 
 
 
 
c887e15
 
 
 
 
 
3b5e13f
c887e15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e87980
3b5e13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8570e66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import threading
import time

import cv2
import gradio as gr
from lineless_table_rec import LinelessTableRecognition
from paddleocr import PPStructure
from rapid_table import RapidTable
from rapidocr_onnxruntime import RapidOCR
from table_cls import TableCls
from wired_table_rec import WiredTableRecognition

from utils import plot_rec_box, LoadImage, format_html, box_4_2_poly_to_box_4_1
# Shared image loader; process_image() runs every input image through it.
img_loader = LoadImage()

# ONNX table-structure model used by the SLANet RapidTable engine and by PPStructure.
table_rec_path = "models/table_rec/ch_ppstructure_mobile_v2_SLANet.onnx"

# OCR text-detection models offered in the UI: dropdown label -> ONNX file.
det_model_dir = {
    "mobile_det": "models/ocr/ch_PP-OCRv4_det_infer.onnx",
}

# OCR text-recognition models offered in the UI: dropdown label -> ONNX file.
rec_model_dir = {
    "mobile_rec": "models/ocr/ch_PP-OCRv4_rec_infer.onnx",
}

# Table engines selectable in the UI; the first entry is the default choice.
table_engine_list = [
    "auto",
    "RapidTable(SLANet)",
    "RapidTable(SLANet-plus)",
    "wired_table_v2",
    "pp_table",
    "wired_table_v1",
    "lineless_table"
]

# Paths of the example images shown in the UI.
example_images = [
    "images/wired1.png",
    "images/wired2.png",
    "images/wired3.png",
    "images/lineless1.png",
    "images/wired4.jpg",
    "images/lineless2.png",
    "images/wired5.jpg",
    "images/lineless3.jpg",
    "images/wired6.jpg",
]

# Table engines are constructed once at import time and reused for every request.
rapid_table_engine = RapidTable(model_path=table_rec_path)
# RapidTable with its built-in default model; presumably SLANet-plus given the
# UI label it is mapped to in select_table_model() -- confirm.
SLANet_plus_table_Engine = RapidTable()
wired_table_engine_v1 = WiredTableRecognition(version="v1")
wired_table_engine_v2 = WiredTableRecognition(version="v2")
lineless_table_engine = LinelessTableRecognition()
# Classifier used by the "auto" engine type to choose between wired and lineless.
table_cls = TableCls()

# One RapidOCR engine and one PPStructure engine per (detection, recognition)
# model pair, both keyed "<det>_<rec>" -- the same key format that
# select_ocr_model() and select_table_model() build when looking them up.
ocr_engine_dict = {}
pp_engine_dict = {}
for det_model, det_model_path in det_model_dir.items():
    for rec_model, rec_model_path in rec_model_dir.items():
        key = f"{det_model}_{rec_model}"
        ocr_engine_dict[key] = RapidOCR(
            det_model_path=det_model_path,
            rec_model_path=rec_model_path,
        )
        pp_engine_dict[key] = PPStructure(
            layout=False,
            show_log=False,
            table=True,
            use_onnx=True,
            table_model_dir=table_rec_path,
            det_model_dir=det_model_path,
            rec_model_dir=rec_model_path,
        )


def select_ocr_model(det_model, rec_model):
    """Look up the prebuilt RapidOCR engine for a detection/recognition pair.

    Args:
        det_model: detection model label (a key of ``det_model_dir``).
        rec_model: recognition model label (a key of ``rec_model_dir``).

    Returns:
        The engine stored in ``ocr_engine_dict`` under ``"<det_model>_<rec_model>"``.

    Raises:
        KeyError: if no engine was built for that pair.
    """
    engine_key = "{}_{}".format(det_model, rec_model)
    return ocr_engine_dict[engine_key]


def select_table_model(img, table_engine_type, det_model, rec_model):
    """Pick the table-recognition engine for the requested engine type.

    Args:
        img: the loaded image; only consulted when ``table_engine_type`` is
            ``"auto"``, where it is classified by ``table_cls``.
        table_engine_type: one of the names in ``table_engine_list``.
        det_model: OCR detection model label; used only for ``"pp_table"``.
        rec_model: OCR recognition model label; used only for ``"pp_table"``.

    Returns:
        ``(engine, engine_type)`` where ``engine_type`` is the name of the
        engine actually chosen. For ``"auto"`` this is ``"wired_table_v2"``
        when the classifier reports ``'wired'``, else ``"lineless_table"``.

    Raises:
        ValueError: if ``table_engine_type`` is not a known engine name
            (previously this silently returned ``None``, which made the
            caller fail with an opaque unpacking error).
        KeyError: for ``"pp_table"`` when no PPStructure engine exists for
            the given model pair.
    """
    if table_engine_type == "RapidTable(SLANet)":
        return rapid_table_engine, table_engine_type
    if table_engine_type == "RapidTable(SLANet-plus)":
        return SLANet_plus_table_Engine, table_engine_type
    if table_engine_type == "wired_table_v1":
        return wired_table_engine_v1, table_engine_type
    if table_engine_type == "wired_table_v2":
        # Debug trace; the message reads "using v2 wired table".
        print("使用v2 wired table")
        return wired_table_engine_v2, table_engine_type
    if table_engine_type == "lineless_table":
        return lineless_table_engine, table_engine_type
    if table_engine_type == "pp_table":
        # Report the engine name like every other branch (this used to return 0).
        return pp_engine_dict[f"{det_model}_{rec_model}"], table_engine_type
    if table_engine_type == "auto":
        # The classifier also returns an elapsed time, which is not needed here.
        cls, _elapse = table_cls(img)
        if cls == 'wired':
            return wired_table_engine_v2, "wired_table_v2"
        return lineless_table_engine, "lineless_table"
    raise ValueError(f"unknown table engine type: {table_engine_type!r}")


def _cell_polygons_to_boxes(polygons):
    """Keep elements 0, 1, 4, 5 of each cell polygon as a 4-value box.

    Presumably each polygon is 8 flat values (x1, y1, ..., x4, y4), so this
    keeps two opposite corners -- TODO confirm against the engines' output.
    """
    return [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]


def process_image(img, table_engine_type, det_model, rec_model):
    """Run table recognition on one image and build the UI outputs.

    Args:
        img: the image as delivered by the Gradio input; normalized via
            ``img_loader``.
        table_engine_type: engine name from ``table_engine_list``.
        det_model: OCR detection model label (a key of ``det_model_dir``).
        rec_model: OCR recognition model label (a key of ``rec_model_dir``).

    Returns:
        ``(complete_html, table_boxes_img, ocr_boxes_img, all_elapse)``: the
        recognized table as formatted HTML, the image with table-cell boxes
        drawn, the image with OCR text boxes drawn, and a timing summary.

    Raises:
        ValueError: propagated from ``select_table_model`` for an unknown
            engine type.
        TypeError: if the selected engine is not one this function knows how
            to call (previously an ``UnboundLocalError`` further down).
    """
    img = img_loader(img)
    start = time.time()
    table_engine, table_type = select_table_model(img, table_engine_type, det_model, rec_model)

    if isinstance(table_engine, PPStructure):
        # PPStructure runs OCR and table recognition itself; take everything from its result.
        result = table_engine(img, return_ocr_result_in_table=True)
        table_res = result[0]['res']
        html = table_res['html']
        polygons = _cell_polygons_to_boxes(table_res['cell_bbox'])
        ocr_boxes = table_res['boxes']
        all_elapse = f"- `table all cost: {time.time() - start:.5f}`"
    else:
        # The other engines take an external OCR result.
        ocr_engine = select_ocr_model(det_model, rec_model)
        ocr_res, ocr_infer_elapse = ocr_engine(img)
        # NOTE(review): RapidOCR appears to return (None, None) when it detects no
        # text -- confirm for the installed version. Guard so a text-free image does
        # not crash here while unpacking timings or iterating results.
        if ocr_infer_elapse:
            det_cost, cls_cost, rec_cost = ocr_infer_elapse
            ocr_cost = det_cost + cls_cost + rec_cost
        else:
            ocr_cost = 0.0
        ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res or []]

        if isinstance(table_engine, RapidTable):
            html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
            polygons = _cell_polygons_to_boxes(polygons)
        elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
            html, table_rec_elapse, polygons, _, _ = table_engine(img, ocr_result=ocr_res)
        else:
            # Fail loudly instead of hitting unbound html/polygons/table_rec_elapse below.
            raise TypeError(f"unsupported table engine: {type(table_engine).__name__}")

        sum_elapse = time.time() - start
        all_elapse = f"- table_type: {table_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {ocr_cost:.5f}"

    # Convert BGR -> RGB for display (assumes img_loader yields BGR -- confirm).
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Draw on copies so the two overlays do not bleed into each other.
    table_boxes_img = plot_rec_box(img.copy(), polygons)
    ocr_boxes_img = plot_rec_box(img.copy(), ocr_boxes)
    complete_html = format_html(html)

    return complete_html, table_boxes_img, ocr_boxes_img, all_elapse


def main():
    """Build the table-recognition Gradio UI and launch it.

    Left column ("Options" tab): image input, example picker, engine and
    OCR-model dropdowns, a Run button and an elapsed-time box. Right column:
    rendered HTML, table-box overlay and OCR-box overlay. Clicking Run calls
    ``process_image`` with the current selections.
    """
    det_choices = list(det_model_dir)
    rec_choices = list(rec_model_dir)

    # Lets wide table HTML scroll horizontally instead of wrapping.
    scroll_css = """
        .scrollable-container {
            overflow-x: auto;
            white-space: nowrap;
        }
    """

    with gr.Blocks(css=scroll_css) as demo:
        with gr.Row():  # two-column layout
            with gr.Tab("Options"):
                with gr.Column(variant="panel", scale=1):  # sidebar, width ratio 1
                    image_in = gr.Image(
                        label="Upload or Select Image",
                        sources="upload",
                        value="images/lineless3.jpg",
                    )

                    # Example-image picker; fn simply echoes the chosen image path.
                    examples_widget = gr.Examples(
                        examples=example_images,
                        inputs=image_in,
                        fn=lambda x: x,
                        outputs=image_in,
                        cache_examples=True,
                    )

                    engine_choice = gr.Dropdown(
                        table_engine_list,
                        label="Select Recognition Table Engine",
                        value=table_engine_list[0],
                    )
                    det_choice = gr.Dropdown(
                        det_choices,
                        label="Select OCR Detection Model",
                        value=det_choices[0],
                    )
                    rec_choice = gr.Dropdown(
                        rec_choices,
                        label="Select OCR Recognition Model",
                        value=rec_choices[0],
                    )

                    run_btn = gr.Button("Run")
                    gr.Markdown("# Elapsed Time")
                    elapsed_box = gr.Text(label="")  # shows the timing summary string
            with gr.Column(scale=2):  # right-hand column
                # Markdown headings label each output component.
                gr.Markdown("# Html Render")
                html_view = gr.HTML(label="", elem_classes="scrollable-container")
                gr.Markdown("# Table Boxes")
                table_boxes_view = gr.Image(label="")
                gr.Markdown("# OCR Boxes")
                ocr_boxes_view = gr.Image(label="")

        run_btn.click(
            fn=process_image,
            inputs=[image_in, engine_choice, det_choice, rec_choice],
            outputs=[html_view, table_boxes_view, ocr_boxes_view, elapsed_box],
        )

    demo.launch()


if __name__ == "__main__":
    main()