File size: 9,406 Bytes
d748bf5
 
 
317c295
4769339
317c295
6c71924
ac7b15a
 
4769339
317c295
 
 
 
4769339
 
 
 
 
317c295
4769339
 
981daf7
 
 
 
 
4769339
 
317c295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4769339
 
317c295
 
 
4769339
ac7b15a
317c295
 
 
ac7b15a
317c295
 
 
ac7b15a
317c295
 
ac7b15a
317c295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72386ad
6c47f29
ac7b15a
317c295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac7b15a
317c295
 
 
 
 
 
 
 
 
 
 
 
 
ac7b15a
317c295
ac7b15a
317c295
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import os
os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')

from os import getcwd, path, environ
import deepdoctection as dd
from deepdoctection.dataflow.serialize import DataFromList

import gradio as gr


_DD_ONE = "conf_dd_one.yaml"
_DETECTIONS = ["table", "ocr"]

# Register the custom Detectron2 layout model with deepdoctection's ModelCatalog so that it can be
# resolved by name. Weights and configs are fetched from the Hugging Face repo named by the HF_REPO
# environment variable.
# NOTE(review): if HF_REPO is unset, hf_repo_id is None and the weight download below will
# presumably fail — confirm the deployment always sets it.
dd.ModelCatalog.register(
    "layout/model_final_inf_only.pt",
    dd.ModelProfile(
        name="layout/model_final_inf_only.pt",
        description="Detectron2 layout detection model trained on private datasets",
        config="dd/d2/layout/CASCADE_RCNN_R_50_FPN_GN.yaml",
        size=[274632215],
        tp_model=False,
        hf_repo_id=environ.get("HF_REPO"),
        hf_model_name="model_final_inf_only.pt",
        hf_config_file=["Base-RCNN-FPN.yaml", "CASCADE_RCNN_R_50_FPN_GN.yaml"],
        categories={
            "1": dd.LayoutType.text,
            "2": dd.LayoutType.title,
            "3": dd.LayoutType.list,
            "4": dd.LayoutType.table,
            "5": dd.LayoutType.figure,
        },
    ),
)

# Set up the configuration. Models are created once at module level so that they are not
# re-loaded every time the input changes.
cfg = dd.set_config_by_yaml(path.join(getcwd(), _DD_ONE))
cfg.freeze(freezed=False)
cfg.DEVICE = "cpu"  # device handed to every Detectron2 detector built below
cfg.freeze()


def _load_d2_detector(config_key, weights_key):
    """Build a Detectron2 detector from the catalog entries named by the two keys.

    Resolves the config path, downloads the weights and configs if needed, and takes the category
    mapping from the model profile registered under ``weights_key``.

    :param config_key: ModelCatalog key of the Detectron2 config file.
    :param weights_key: ModelCatalog key of the model weights / profile.
    :return: a ``dd.D2FrcnnDetector`` running on ``cfg.DEVICE``.
    :raises RuntimeError: if the profile has no categories or the weights could not be resolved.
    """
    config_path = dd.ModelCatalog.get_full_path_configs(config_key)
    weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(weights_key)
    categories = dd.ModelCatalog.get_profile(weights_key).categories
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if categories is None:
        raise RuntimeError(f"Model profile {weights_key!r} does not define any categories")
    if weights_path is None:
        raise RuntimeError(f"Could not resolve weights for model {weights_key!r}")
    return dd.D2FrcnnDetector(config_path, weights_path, categories, device=cfg.DEVICE)


# layout detector
d_layout = _load_d2_detector(cfg.CONFIG.D2LAYOUT, cfg.WEIGHTS.D2LAYOUT)

# cell detector
d_cell = _load_d2_detector(cfg.CONFIG.D2CELL, cfg.WEIGHTS.D2CELL)

# row/column detector
d_item = _load_d2_detector(cfg.CONFIG.D2ITEM, cfg.WEIGHTS.D2ITEM)

# word detector
det = dd.DoctrTextlineDetector()

# text recognizer
rec = dd.DoctrTextRecognizer()


def build_gradio_analyzer(table, table_ref, ocr):
    """Assemble a deepdoctection pipeline from the pre-loaded module-level detectors.

    The three switches are first written into the shared ``cfg`` (unfrozen for the update and
    frozen again afterwards). The pipeline is then composed from that config.

    :param table: add cell and row/column detection followed by table segmentation.
    :param table_ref: additionally add table segmentation refinement (only honoured when
                      ``table`` is set).
    :param ocr: add word detection, text recognition, word-to-layout matching and text ordering.
    :return: a ``dd.DoctectionPipe`` that runs the selected components in order.
    """
    # NOTE(review): this mutates the module-level cfg that every request shares.
    cfg.freeze(freezed=False)
    cfg.TAB, cfg.TAB_REF, cfg.OCR = table, table_ref, ocr
    cfg.freeze()

    # Layout detection always runs first.
    components = [dd.ImageLayoutService(d_layout, to_image=True, crop_image=True)]

    if cfg.TAB:
        seg = cfg.SEGMENTATION
        # The threshold family (IOU vs IOA) follows the configured assignment rule.
        iou_rule = seg.ASSIGNMENT_RULE in ["iou"]
        components.append(dd.SubImageLayoutService(d_cell, dd.LayoutType.table, {1: 6}, True))
        components.append(dd.SubImageLayoutService(d_item, dd.LayoutType.table, {1: 7, 2: 8}, True))
        components.append(
            dd.TableSegmentationService(
                seg.ASSIGNMENT_RULE,
                seg.IOU_THRESHOLD_ROWS if iou_rule else seg.IOA_THRESHOLD_ROWS,
                seg.IOU_THRESHOLD_COLS if iou_rule else seg.IOA_THRESHOLD_COLS,
                seg.FULL_TABLE_TILING,
                seg.REMOVE_IOU_THRESHOLD_ROWS,
                seg.REMOVE_IOU_THRESHOLD_COLS,
            )
        )
        if cfg.TAB_REF:
            components.append(dd.TableSegmentationRefinementService())

    if cfg.OCR:
        matching = cfg.WORD_MATCHING
        text_blocks = [dd.LayoutType.title, dd.LayoutType.text, dd.LayoutType.list]
        components.append(dd.ImageLayoutService(det, to_image=True, crop_image=True))
        components.append(dd.TextExtractionService(rec, extract_from_roi="WORD"))
        components.append(
            dd.MatchingService(
                parent_categories=matching.PARENTAL_CATEGORIES,
                child_categories=dd.LayoutType.word,
                matching_rule=matching.RULE,
                threshold=matching.IOU_THRESHOLD if matching.RULE in ["iou"] else matching.IOA_THRESHOLD,
            )
        )
        components.append(
            dd.TextOrderService(
                text_container=dd.LayoutType.word,
                floating_text_block_names=list(text_blocks),
                text_block_names=[
                    *text_blocks,
                    dd.LayoutType.cell,
                    dd.CellType.header,
                    dd.CellType.body,
                ],
            )
        )

    return dd.DoctectionPipe(pipeline_component_list=components)


def prepare_output(dp, add_table, add_ocr):
    """Turn an analyzed page into the four values displayed by the Gradio UI.

    :param dp: the analyzed page yielded by the deepdoctection pipeline.
    :param add_table: if true, return the HTML of the first detected table (``None`` if there is none).
    :param add_ocr: if true, list the layout items by ``reading_order``. Items whose reading order
                    is ``None`` are placed last and keep their relative order.
    :return: ``(visualisation, text_summary, table_html, json_dict)``, where ``json_dict`` is
             ``dp.as_dict()`` without its ``"image"`` entry.
    """
    out = dp.as_dict()
    # The raw image is dropped from the JSON output; tolerate it already being absent.
    out.pop("image", None)

    layout_items = dp.items
    if add_ocr:
        # sorted() leaves dp.items untouched. An item with reading_order None would make a plain
        # integer sort raise TypeError, so such items are pushed to the end instead.
        layout_items = sorted(
            layout_items,
            key=lambda item: (item.reading_order is None, item.reading_order or 0),
        )
    layout_items_str = "".join(f"\n {item.layout_type}: {item.text}" for item in layout_items)

    html = None
    if add_table:
        # Only the first table is rendered in the UI, so only its HTML is computed.
        html = next((table.html for table in dp.tables), None)

    return dp.viz(show_table_structure=False), layout_items_str, html, out


def analyze_image(img, pdf, attributes):
    """Run the demo pipeline on the uploaded image, or else on the uploaded PDF.

    :param img: array from the ``gr.Image(type='numpy')`` input, or ``None``. Its last axis is
                reversed before analysis (presumably RGB -> BGR channel order — confirm).
    :param pdf: file object from the ``gr.File`` input (its ``.name`` is used as the path), or ``None``.
    :param attributes: the selected entries of ``_DETECTIONS`` ("table" and/or "ocr").
    :return: the 4-tuple produced by ``prepare_output`` for the first analyzed page.
    :raises ValueError: if neither an image nor a PDF was provided, or no page could be read.
    """
    # Validate the input before mutating the shared configuration or building a pipeline.
    if img is None and not pdf:
        raise ValueError("Please upload an image or a PDF document.")

    add_table = _DETECTIONS[0] in attributes
    add_ocr = _DETECTIONS[1] in attributes

    # Table refinement is switched on together with table detection in this demo.
    analyzer = build_gradio_analyzer(add_table, add_table, add_ocr)

    if img is not None:
        # Wrap the in-memory image in a dd.Image and feed it through a one-element dataflow.
        image = dd.Image(file_name="input.png", location="")
        image.image = img[:, :, ::-1]
        df = analyzer.analyze(dataset_dataflow=DataFromList(lst=[image]))
    else:
        # At most 3 datapoints are requested; only the first one is consumed below.
        df = analyzer.analyze(path=pdf.name, max_datapoints=3)

    df.reset_state()
    dp = next(iter(df), None)
    if dp is None:
        raise ValueError("No page could be extracted from the uploaded document.")

    return prepare_output(dp, add_table, add_ocr)


demo = gr.Blocks(css="scrollbar.css")

with demo:
    # Header and description
    with gr.Box():
        gr.Markdown("<h1><center>deepdoctection - A Document AI Package</center></h1>")
        gr.Markdown("<strong>deep</strong>doctection is a Python library that orchestrates document extraction"
                    " and document layout analysis tasks using deep learning models. It does not implement models"
                    " but enables you to build pipelines using highly acknowledged libraries for object detection,"
                    " OCR and selected NLP tasks and provides an integrated framework for fine-tuning, evaluating"
                    " and running models.\n This pipeline consists of a stack of models powered by <strong>Detectron2"
                    "</strong> for layout analysis and table recognition and <strong>DocTr</strong> for OCR.")

    # Inputs: image or PDF upload, example images and extraction options
    with gr.Box():
        gr.Markdown("<h2><center>Upload a document and choose settings</center></h2>")
        with gr.Row():
            with gr.Column():
                with gr.Tab("Image upload"):
                    with gr.Column():
                        inputs = gr.Image(type='numpy', label="Original Image")
                with gr.Tab("PDF upload (only the first page will be processed)"):
                    with gr.Column():
                        inputs_pdf = gr.File(label="PDF")
            with gr.Column():
                gr.Examples(
                    examples=[path.join(getcwd(), "sample_1.jpg"), path.join(getcwd(), "sample_2.png")],
                    inputs=inputs,
                )

        with gr.Row():
            # All optional extractions are selected by default.
            tok_input = gr.CheckboxGroup(
                _DETECTIONS, value=_DETECTIONS, label="Additional extractions", interactive=True)
        with gr.Row():
            btn = gr.Button("Run model", variant="primary")

    # Outputs
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown("<h2><center>Text output</center></h2>")
                gr.Markdown("Will only show contiguous text from text blocks, titles and lists")
                image_text = gr.Textbox()
                gr.Markdown("<h2><center>First table</center></h2>")
                html = gr.HTML()
                gr.Markdown("<h2><center>JSON output</center></h2>")
                # Named json_output so it does not shadow the stdlib `json` module name.
                json_output = gr.JSON()
            with gr.Column():
                gr.Markdown("<h2><center>Layout detection</center></h2>")
                image_output = gr.Image(type="numpy", label="Output Image")

    # The outputs list must follow the order of the tuple returned by analyze_image:
    # (visualisation, text, table html, json).
    btn.click(
        fn=analyze_image,
        inputs=[inputs, inputs_pdf, tok_input],
        outputs=[image_output, image_text, html, json_output],
    )

demo.launch()