File size: 10,258 Bytes
230c9a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
==================================
代码实现
==================================

PDF-Extract-Kit项目的核心代码实现在pdf_extract_kit目录下,该路径下包含下述几个模块:

- configs: 特定模块的配置文件,如 ``pdf_extract_kit/configs/unimernet.yaml`` ,如果本身配置简单,建议放在 ``repo_root/configs````yaml`` 文件中的 ``model_config`` 里进行定义,方便用户修改。

- dataset: 自定义的 ``ImageDataset`` 类,用于加载和预处理图像数据。它支持多种输入类型,并且可以对图像进行统一的预处理操作(如调整大小、转换为张量等),以便于后续的模型推理加速。

- evaluation: 模型结果评测模块,支持多种任务类型评测,如 ``布局检测````公式检测````公式识别`` 等等,方便用户对不同任务、不同模型进行公平对比。

- registry: ``Registry`` 类是一个通用的注册表类,提供了注册、获取和列出注册项的功能。用户可以使用该类创建不同类型的注册表,例如任务注册表、模型注册表等。

- tasks: 最核心的任务模块,包含了许多不同类型的任务,如 ``布局检测````公式检测````公式识别`` 等等,用户添加新任务和新模型一般仅需要在这里进行代码添加。


.. note::
    基于上述的模块化设计,用户拓展新模块一般只需要在tasks里实现自己的新任务类及对应模型(更多情况下仅需要实现对应模型,任务已经定义好),然后在registry里注册即可。


下面我们以添加基于 ``YOLO````布局检测`` 模型为例,介绍如何添加新任务和新模型.

任务定义及注册
==============

首先我们在 ``tasks`` 下添加一个 ``layout_detection`` 目录,然后在该目录下添加一个 ``task.py`` 文件用于定义布局检测任务类,具体如下:

.. code-block:: python

    from pdf_extract_kit.registry.registry import TASK_REGISTRY
    from pdf_extract_kit.tasks.base_task import BaseTask


    @TASK_REGISTRY.register("layout_detection")
    class LayoutDetectionTask(BaseTask):
        def __init__(self, model):
            super().__init__(model)

        def predict_images(self, input_data, result_path):
            """

            Predict layouts in images.



            Args:

                input_data (str): Path to a single image file or a directory containing image files.

                result_path (str): Path to save the prediction results.



            Returns:

                list: List of prediction results.

            """
            images = self.load_images(input_data)
            # Perform detection
            return self.model.predict(images, result_path)

        def predict_pdfs(self, input_data, result_path):
            """

            Predict layouts in PDF files.



            Args:

                input_data (str): Path to a single PDF file or a directory containing PDF files.

                result_path (str): Path to save the prediction results.



            Returns:

                list: List of prediction results.

            """
            pdf_images = self.load_pdf_images(input_data)
            # Perform detection
            return self.model.predict(list(pdf_images.values()), result_path, list(pdf_images.keys()))

可以看到,任务定义包含下面几个要点:

* 使用 ``@TASK_REGISTRY.register("layout_detection")`` 语法直接将布局任务类注册到 ``TASK_REGISTRY`` 下 ;
* ``__init__`` 初始化函数传入 ``model`` , 具体参考 ``BaseTask`` 类
* 实现推理函数,这里考虑到布局检测通常会处理图像类及PDF文件,所以提供了两个函数 ``predict_images````predict_pdfs`` ,方便用户灵活选择。

模型定义及注册
==============

接下来我们实现具体模型,在task下面新建models目录,并添加yolo.py用于YOLO模型定义,具体定义如下:

.. code-block:: python

    import os
    import cv2
    import torch
    from torch.utils.data import DataLoader, Dataset
    from ultralytics import YOLO
    from pdf_extract_kit.registry import MODEL_REGISTRY
    from pdf_extract_kit.utils.visualization import  visualize_bbox
    from pdf_extract_kit.dataset.dataset import ImageDataset
    import torchvision.transforms as transforms


    @MODEL_REGISTRY.register('layout_detection_yolo')
    class LayoutDetectionYOLO:
        def __init__(self, config):
            """

            Initialize the LayoutDetectionYOLO class.



            Args:

                config (dict): Configuration dictionary containing model parameters.

            """
            # Mapping from class IDs to class names
            self.id_to_names = {
                0: 'title', 
                1: 'plain text',
                2: 'abandon', 
                3: 'figure', 
                4: 'figure_caption', 
                5: 'table', 
                6: 'table_caption', 
                7: 'table_footnote', 
                8: 'isolate_formula', 
                9: 'formula_caption'
            }

            # Load the YOLO model from the specified path
            self.model = YOLO(config['model_path'])

            # Set model parameters
            self.img_size = config.get('img_size', 1280)
            self.pdf_dpi = config.get('pdf_dpi', 200)
            self.conf_thres = config.get('conf_thres', 0.25)
            self.iou_thres = config.get('iou_thres', 0.45)
            self.visualize = config.get('visualize', False)
            self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
            self.batch_size = config.get('batch_size', 1)

        def predict(self, images, result_path, image_ids=None):
            """

            Predict layouts in images.



            Args:

                images (list): List of images to be predicted.

                result_path (str): Path to save the prediction results.

                image_ids (list, optional): List of image IDs corresponding to the images.



            Returns:

                list: List of prediction results.

            """
            results = []
            for idx, image in enumerate(images):
                result = self.model.predict(image, imgsz=self.img_size, conf=self.conf_thres, iou=self.iou_thres, verbose=False)[0]
                if self.visualize:
                    if not os.path.exists(result_path):
                        os.makedirs(result_path)
                    boxes = result.__dict__['boxes'].xyxy
                    classes = result.__dict__['boxes'].cls
                    vis_result = visualize_bbox(image, boxes, classes, self.id_to_names)

                    # Determine the base name of the image
                    if image_ids:
                        base_name = image_ids[idx]
                    else:
                        base_name = os.path.basename(image)
                    
                    result_name = f"{base_name}_MFD.png"
                    
                    # Save the visualized result                
                    cv2.imwrite(os.path.join(result_path, result_name), vis_result)
                results.append(result)
            return results


可以看到,模型定义包含下面几个要点:

* 使用 ``@MODEL_REGISTRY.register('layout_detection_yolo')`` 语法直接将yolo布局模型注册到 ``MODEL_REGISTRY`` 下;
* 初始化函数需要实现:
    + id_to_names的类别映射,用于可视化展示
    + 模型参数配置
    + 模型初始化
* 模型推理函数需要实现多种类型的模型推理:这里支持图像列表和PIL.Image类,可以方便用户直接基于图像路径或者图像流进行推理。

实现上述类定义后,将 ``LayoutDetectionYOLO`` 添加到 ``layout_detection`` 任务下 ``__init__.py````__all__`` 中即可。

.. code-block:: python

    from pdf_extract_kit.tasks.layout_detection.models.yolo import LayoutDetectionYOLO
    from pdf_extract_kit.registry.registry import MODEL_REGISTRY


    __all__ = [
        "LayoutDetectionYOLO",
    ]


.. note:: 
    对于同一个任务,我们支持多种模型,用户具体选择哪个可以根据评测结果进行选择,结合模型 ``精度````速度````场景适配程度`` 进行选择。


实现了任务和模型后,可以在 repo_root/scripts下添加脚本程序 ``layout_detection.py``

示例脚本
==============

.. code-block:: python

    import os
    import sys
    import os.path as osp
    import argparse

    sys.path.append(osp.join(os.path.dirname(os.path.abspath(__file__)), '..'))
    from pdf_extract_kit.utils.config_loader import load_config, initialize_tasks_and_models
    import pdf_extract_kit.tasks  # 确保所有任务模块被导入

    TASK_NAME = 'layout_detection'


    def parse_args():
        parser = argparse.ArgumentParser(description="Run a task with a given configuration file.")
        parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')
        return parser.parse_args()

    def main(config_path):
        config = load_config(config_path)
        task_instances = initialize_tasks_and_models(config)

        # get input and output path from config
        input_data = config.get('inputs', None)
        result_path = config.get('outputs', 'outputs'+'/'+TASK_NAME)

        # layout_detection_task
        model_layout_detection = task_instances[TASK_NAME]

        # for image detection
        detection_results = model_layout_detection.predict_images(input_data, result_path)

        # for pdf detection
        # detection_results = model_layout_detection.predict_pdfs(input_data, result_path)

        # print(detection_results)
        print(f'The predicted results can be found at {result_path}')


    if __name__ == "__main__":
        args = parse_args()
        main(args.config)

支持类型拓展
==============


批处理拓展
==============