File size: 10,195 Bytes
3b5e13f
 
 
 
 
 
 
 
 
 
 
8570e66
fcd5742
3b5e13f
 
 
9e87980
3b5e13f
 
 
9e87980
3b5e13f
 
 
8570e66
 
3b5e13f
 
 
 
 
 
 
 
 
 
 
2a39c5a
 
 
 
fcd5742
 
 
3b5e13f
 
8570e66
3b5e13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8570e66
9e87980
8570e66
 
3b5e13f
9e87980
3b5e13f
 
9e87980
3b5e13f
9e87980
3b5e13f
 
 
 
 
 
9e87980
 
3b5e13f
 
fcd5742
3b5e13f
 
9e87980
3b5e13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e87980
3b5e13f
 
fcd5742
 
 
 
 
 
 
 
 
3b5e13f
 
9e87980
3b5e13f
 
 
 
 
 
 
 
 
 
 
 
 
c887e15
 
 
 
 
fcd5742
 
 
 
 
 
 
 
c887e15
fcd5742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b5e13f
c887e15
 
fcd5742
c887e15
 
 
 
fcd5742
c887e15
 
 
fcd5742
c887e15
 
 
 
fcd5742
 
 
 
c887e15
 
 
 
 
 
 
 
9e87980
3b5e13f
 
 
 
 
 
 
 
 
 
fcd5742
3b5e13f
 
 
 
 
 
 
fcd5742
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import time

import cv2
import gradio as gr
from lineless_table_rec import LinelessTableRecognition
from paddleocr import PPStructure
from rapid_table import RapidTable
from rapidocr_onnxruntime import RapidOCR
from table_cls import TableCls
from wired_table_rec import WiredTableRecognition

from utils import plot_rec_box, LoadImage, format_html, box_4_2_poly_to_box_4_1

img_loader = LoadImage()

# Table-structure model file used by the SLANet RapidTable engine and by PPStructure.
table_rec_path = "models/table_rec/ch_ppstructure_mobile_v2_SLANet.onnx"

# OCR detection / recognition model files, keyed by the labels offered in the UI dropdowns.
det_model_dir = {
    "mobile_det": "models/ocr/ch_PP-OCRv4_det_infer.onnx",
}

rec_model_dir = {
    "mobile_rec": "models/ocr/ch_PP-OCRv4_rec_infer.onnx",
}

# Engine names offered in the UI; the first entry ("auto") is the default choice.
table_engine_list = [
    "auto",
    "RapidTable(SLANet)",
    "RapidTable(SLANet-plus)",
    "wired_table_v2",
    "pp_table",
    "wired_table_v1",
    "lineless_table",
]

# Example image paths offered in the UI example picker.
example_images = [
    "images/wired1.png",
    "images/wired2.png",
    "images/wired3.png",
    "images/lineless1.png",
    "images/wired4.jpg",
    "images/lineless2.png",
    "images/wired5.jpg",
    "images/lineless4.jpg",
    "images/wired7.jpg",
    "images/wired9.jpg",
]

# Table engines are constructed once at import time and reused by every request.
rapid_table_engine = RapidTable(model_path=table_rec_path)
# Built with RapidTable's default model; selected via the "RapidTable(SLANet-plus)" option.
SLANet_plus_table_Engine = RapidTable()
wired_table_engine_v1 = WiredTableRecognition(version="v1")
wired_table_engine_v2 = WiredTableRecognition(version="v2")
lineless_table_engine = LinelessTableRecognition()
# Wired/lineless classifier used by the "auto" engine option.
table_cls = TableCls()

# One RapidOCR engine and one PPStructure pipeline per (detection, recognition)
# model pair, both keyed as "<det_label>_<rec_label>".
ocr_engine_dict = {}
pp_engine_dict = {}
for det_model, det_model_path in det_model_dir.items():
    for rec_model, rec_model_path in rec_model_dir.items():
        key = "_".join((det_model, rec_model))
        ocr_engine_dict[key] = RapidOCR(
            det_model_path=det_model_path,
            rec_model_path=rec_model_path,
        )
        pp_engine_dict[key] = PPStructure(
            layout=False,
            show_log=False,
            table=True,
            use_onnx=True,
            table_model_dir=table_rec_path,
            det_model_dir=det_model_path,
            rec_model_dir=rec_model_path,
        )


def select_ocr_model(det_model, rec_model):
    """Return the pre-built RapidOCR engine for a detection/recognition label pair.

    Raises:
        KeyError: If no engine was registered in ``ocr_engine_dict`` for the pair.
    """
    engine_key = f"{det_model}_{rec_model}"
    return ocr_engine_dict[engine_key]


def select_table_model(img, table_engine_type, det_model, rec_model):
    """Pick the table-recognition engine for the requested engine type.

    Args:
        img: Loaded image. Only used for ``"auto"``, where it is passed to the
            wired/lineless table classifier.
        table_engine_type: One of the engine names offered in the UI.
        det_model: OCR detection model label (selects the PPStructure pipeline).
        rec_model: OCR recognition model label (selects the PPStructure pipeline).

    Returns:
        ``(engine, resolved_type)``. For ``"auto"`` the resolved type names the
        concrete engine chosen by the classifier; otherwise it equals
        ``table_engine_type``.

    Raises:
        ValueError: If ``table_engine_type`` is not a known engine name.
    """
    if table_engine_type == "RapidTable(SLANet)":
        return rapid_table_engine, table_engine_type
    if table_engine_type == "RapidTable(SLANet-plus)":
        return SLANet_plus_table_Engine, table_engine_type
    if table_engine_type == "wired_table_v1":
        return wired_table_engine_v1, table_engine_type
    if table_engine_type == "wired_table_v2":
        # Debug trace (message reads "using v2 wired table").
        print("使用v2 wired table")
        return wired_table_engine_v2, table_engine_type
    if table_engine_type == "lineless_table":
        return lineless_table_engine, table_engine_type
    if table_engine_type == "pp_table":
        # Report the real engine name (not a placeholder) so callers can display it.
        return pp_engine_dict[f"{det_model}_{rec_model}"], table_engine_type
    if table_engine_type == "auto":
        # Classify first: wired tables go to the v2 wired engine, everything
        # else to the lineless engine.
        cls, _ = table_cls(img)
        if cls == "wired":
            return wired_table_engine_v2, "wired_table_v2"
        return lineless_table_engine, "lineless_table"
    # Fail loudly instead of returning None, which made the caller crash with
    # an opaque tuple-unpacking error.
    raise ValueError(f"Unknown table engine type: {table_engine_type!r}")


def _cell_polygons_to_boxes(polygons):
    """Reduce each flat cell polygon to ``[p[0], p[1], p[4], p[5]]``.

    Assumes each polygon is a flat 8-value quadrilateral (x1, y1, ..., x4, y4),
    so indices 0,1 and 4,5 are opposite corners -- TODO confirm against the
    engine output.
    """
    return [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]


def process_image(img, table_engine_type, det_model, rec_model, small_box_cut_enhance):
    """Run table recognition on one image and build the UI outputs.

    Args:
        img: Image from the Gradio input; normalized by ``img_loader``.
        table_engine_type: Engine name offered in the UI (see ``select_table_model``).
        det_model: OCR detection model label.
        rec_model: OCR recognition model label.
        small_box_cut_enhance: When False, the wired/lineless engines are called
            with ``morph_close``, ``more_h_lines``, ``more_v_lines`` and
            ``extend_line`` disabled.

    Returns:
        ``(complete_html, table_boxes_img, ocr_boxes_img, all_elapse)``: the
        formatted table HTML, the RGB image with table-cell boxes drawn, the RGB
        image with OCR boxes drawn, and a timing summary string.

    Raises:
        TypeError: If the selected engine is not a supported engine class.
    """
    img = img_loader(img)
    start = time.perf_counter()
    table_engine, table_type = select_table_model(img, table_engine_type, det_model, rec_model)
    ocr_engine = select_ocr_model(det_model, rec_model)

    if isinstance(table_engine, PPStructure):
        # PPStructure runs its own OCR; its results come back inside the table result.
        result = table_engine(img, return_ocr_result_in_table=True)
        table_res = result[0]["res"]
        html = table_res["html"]
        polygons = _cell_polygons_to_boxes(table_res["cell_bbox"])
        ocr_boxes = table_res["boxes"]
        all_elapse = f"- `table all cost: {time.perf_counter() - start:.5f}`"
    else:
        ocr_res, ocr_infer_elapse = ocr_engine(img)
        # NOTE(review): RapidOCR appears to return (None, None) when it finds no
        # text; guard so a blank image does not crash here -- confirm upstream.
        ocr_cost = sum(ocr_infer_elapse) if ocr_infer_elapse else 0.0
        ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res or []]

        if isinstance(table_engine, RapidTable):
            html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
            polygons = _cell_polygons_to_boxes(polygons)
        elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
            # Call the engine exactly once (it used to be invoked twice, with the
            # first result discarded). With enhancement off, switch off the
            # extra options to avoid excessive cutting.
            if small_box_cut_enhance:
                engine_kwargs = {}
            else:
                engine_kwargs = {
                    "morph_close": False,
                    "more_h_lines": False,
                    "more_v_lines": False,
                    "extend_line": False,
                }
            html, table_rec_elapse, polygons, _, _ = table_engine(
                img, ocr_result=ocr_res, **engine_kwargs
            )
        else:
            raise TypeError(f"Unsupported table engine: {type(table_engine).__name__}")

        sum_elapse = time.perf_counter() - start
        all_elapse = (
            f"- table_type: {table_type}\n"
            f" - table all cost: {sum_elapse:.5f}\n"
            f" - table rec cost: {table_rec_elapse:.5f}\n"
            f" - ocr cost: {ocr_cost:.5f}"
        )

    # The loaded image is treated as BGR; convert to RGB for display.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    table_boxes_img = plot_rec_box(img.copy(), polygons)
    ocr_boxes_img = plot_rec_box(img.copy(), ocr_boxes)
    complete_html = format_html(html)

    return complete_html, table_boxes_img, ocr_boxes_img, all_elapse


def main():
    """Build the TableStructureRec Gradio demo page and launch it.

    The page has a header (project title link and badge row), an "Options"
    column and a results column. The Options column holds the image input,
    example images, engine/model dropdowns, the box-cutting checkbox, the Run
    button and the elapsed-time box. The results column holds the rendered
    table HTML, the table-cell boxes image and the OCR boxes image. Clicking
    "Run" calls ``process_image`` with the selected options and fills the
    result widgets; ``demo.launch()`` then starts the app.
    """
    # Dropdown choices are the labels of the configured OCR model dictionaries.
    det_models_labels = list(det_model_dir.keys())
    rec_models_labels = list(rec_model_dir.keys())

    # Custom CSS: horizontal scrolling for the rendered table HTML and centered
    # header badges (the non-English comment inside the CSS string means
    # "adjust spacing").
    with gr.Blocks(css="""
        .scrollable-container {
            overflow-x: auto;
            white-space: nowrap;
        }
        .header-links {
            text-align: center;
        }
        .header-links a {
            display: inline-block;
            text-align: center;
            margin-right: 10px;  /* 调整间距 */
        }
    """) as demo:
        # Page title linking to the project repository.
        gr.HTML(
            "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/TableStructureRec?tab=readme-ov-file'>TableStructureRec</a></h1>"
        )
        # Badge row: Python version, OS, PyPI, download counts, SemVer, code style, license.
        gr.HTML('''
                                        <div class="header-links">
                                          <a href=""><img src="https://img.shields.io/badge/Python->=3.6,<3.12-aff.svg"></a>
                                          <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Mac%2C%20Win-pink.svg"></a>
                                          <a href="https://pypi.org/project/lineless-table-rec/"><img alt="PyPI" src="https://img.shields.io/pypi/v/lineless-table-rec"></a>
                                          <a href="https://pepy.tech/project/lineless-table-rec"><img src="https://static.pepy.tech/personalized-badge/lineless-table-rec?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Lineless"></a>
                                          <a href="https://pepy.tech/project/wired-table-rec"><img src="https://static.pepy.tech/personalized-badge/wired-table-rec?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Wired"></a>
                                          <a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
                                          <a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
                                          <a href="https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE"><img alt="GitHub" src="https://img.shields.io/badge/license-Apache 2.0-blue"></a>
                                        </div>
                                        ''')
        with gr.Row():  # two-column layout
            with gr.Tab("Options"):
                with gr.Column(variant="panel", scale=1):  # sidebar, width ratio 1
                    img_input = gr.Image(label="Upload or Select Image", sources="upload", value="images/lineless3.jpg")

                    # Example image picker (all examples shown on a single page)
                    examples = gr.Examples(
                        examples=example_images,
                        examples_per_page=len(example_images),
                        inputs=img_input,
                        fn=lambda x: x,  # identity: returns its input unchanged
                        outputs=img_input,
                        cache_examples=False
                    )

                    table_engine_type = gr.Dropdown(table_engine_list, label="Select Recognition Table Engine",
                                                    value=table_engine_list[0])
                    small_box_cut_enhance = gr.Checkbox(
                        label="Box Cutting Enhancement (Disable to avoid excessive cutting, Enable to reduce missed cutting)",
                        value=True
                    )
                    det_model = gr.Dropdown(det_models_labels, label="Select OCR Detection Model",
                                            value=det_models_labels[0])
                    rec_model = gr.Dropdown(rec_models_labels, label="Select OCR Recognition Model",
                                            value=rec_models_labels[0])

                    run_button = gr.Button("Run")
                    gr.Markdown("# Elapsed Time")
                    elapse_text = gr.Text(label="")  # plain-text box for the timing string returned by process_image
            with gr.Column(scale=2):  # right (results) column, width ratio 2
                # Markdown headings separate the result components
                gr.Markdown("# Html Render")
                html_output = gr.HTML(label="", elem_classes="scrollable-container")
                gr.Markdown("# Table Boxes")
                table_boxes_output = gr.Image(label="")
                gr.Markdown("# OCR Boxes")
                ocr_boxes_output = gr.Image(label="")

        # Inputs are listed in process_image's parameter order; outputs match
        # the order of its 4-tuple return value.
        run_button.click(
            fn=process_image,
            inputs=[img_input, table_engine_type, det_model, rec_model, small_box_cut_enhance],
            outputs=[html_output, table_boxes_output, ocr_boxes_output, elapse_text]
        )

    demo.launch()


if __name__ == "__main__":
    # Start the demo only when executed as a script, not when imported.
    main()