File size: 6,240 Bytes
fffbbeb
 
ae5ac9c
 
 
03a594f
 
 
be84858
 
 
 
 
 
8dc2d5d
47bf6ec
be84858
 
ae5ac9c
47bf6ec
 
 
 
 
 
 
ae5ac9c
 
 
 
 
 
fffbbeb
be84858
 
 
 
 
47bf6ec
be84858
 
 
fffbbeb
 
 
 
1402288
8dc2d5d
1402288
 
8dc2d5d
 
1402288
ae5ac9c
8dc2d5d
ae5ac9c
 
8dc2d5d
 
ae5ac9c
1402288
8dc2d5d
 
03a594f
 
 
 
 
 
8dc2d5d
 
03a594f
1402288
8dc2d5d
 
be84858
 
 
 
 
 
 
 
 
 
 
 
 
7f16c12
 
be84858
 
 
 
 
 
 
 
 
8dc2d5d
 
fffbbeb
 
 
 
 
 
 
1402288
bf6c79b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fffbbeb
d4225f7
2c07938
fffbbeb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from PyPDF2 import PdfReader
import gradio as gr
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.datamodel.base_models import InputFormat
from paddleocr import PPStructureV3
from pdf2image import convert_from_path
import numpy as np
import torch
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
from pathlib import Path
import time
import os

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Precision used when loading SmolDocling: bfloat16 on GPU, float32 on CPU.
smoldocling_dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# On CPU, let torch use every available core (os.cpu_count() may return None,
# hence the fallback to a single thread).
if DEVICE == "cpu":
    torch.set_num_threads(os.cpu_count() or 1)

# Docling converter for PDF inputs. enable_remote_services=True allows Docling
# to call remote model services; NOTE(review): nothing visible here configures
# such a service, so confirm whether this flag is actually needed.
pipeline_options = PdfPipelineOptions(enable_remote_services=True)
converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
    }
)

# PaddleOCR PP-StructureV3 document-parsing pipeline with default settings.
pipeline = PPStructureV3()

SMOLDOCLING_MODEL_ID = "ds4sd/SmolDocling-256M-preview"
processor = AutoProcessor.from_pretrained(SMOLDOCLING_MODEL_ID)


def _load_smoldocling_model():
    """Load SmolDocling onto DEVICE, preferring FlashAttention-2 on CUDA.

    FlashAttention-2 requires the optional ``flash_attn`` package. When
    transformers rejects it (ImportError if the package is missing, ValueError
    if its compatibility checks fail), fall back to the eager attention
    implementation instead of crashing the whole app at import time.
    """

    def _load(attn_implementation):
        # Single place for the from_pretrained arguments shared by both paths.
        return AutoModelForVision2Seq.from_pretrained(
            SMOLDOCLING_MODEL_ID,
            torch_dtype=smoldocling_dtype,
            _attn_implementation=attn_implementation,
        )

    if DEVICE != "cuda":
        return _load("eager").to(DEVICE)
    try:
        loaded = _load("flash_attention_2")
    except (ImportError, ValueError) as err:
        print(f"FlashAttention-2 unavailable ({err}); falling back to eager attention.")
        loaded = _load("eager")
    return loaded.to(DEVICE)


model = _load_smoldocling_model()

def get_pdf_page_count(pdf_path):
    """Return how many pages the PDF at *pdf_path* has, as read by PyPDF2."""
    pages = PdfReader(pdf_path).pages
    return len(pages)

def get_page_image(pdf_path, page_num):
    """Render a single PDF page to an image and time the rendering.

    Args:
        pdf_path: Path to the PDF file.
        page_num: 1-based page number. It is coerced to int because the value
            may arrive from the UI slider as a float.

    Returns:
        ``(page_image, runtime)``: the first image pdf2image returns for the
        page, and the elapsed time formatted as ``"<seconds> s"`` with two
        decimals.

    Raises:
        ValueError: if pdf2image produced no image for the requested page
            (e.g. the page number is past the end of the document).
    """
    # perf_counter is monotonic, so the measurement is immune to system clock
    # adjustments (unlike time.time()).
    start = time.perf_counter()
    page = int(page_num)
    images = convert_from_path(pdf_path, first_page=page, last_page=page)
    if not images:
        raise ValueError(f"Could not render page {page} of {pdf_path}")
    page_image = images[0]
    runtime = time.perf_counter() - start
    return page_image, f"{runtime:.2f} s"

def get_docling_ocr(pdf_path, page_num):
    """Convert one PDF page to Markdown with Docling and time the conversion.

    Args:
        pdf_path: Path to the PDF file.
        page_num: 1-based page number. It is coerced to int because the value
            may arrive from the UI slider as a float.

    Returns:
        ``(markdown, runtime)``: the Markdown export of the converted page,
        and the elapsed time formatted as ``"<seconds> s"`` with two decimals.
    """
    # Monotonic clock for measuring elapsed time.
    start = time.perf_counter()
    page = int(page_num)
    # The same page is passed as both bounds so only that page is converted
    # (assumes Docling's page_range is 1-based and inclusive; confirm).
    result = converter.convert(pdf_path, page_range=(page, page))
    markdown_text_docling = result.document.export_to_markdown()
    runtime = time.perf_counter() - start
    return markdown_text_docling, f"{runtime:.2f} s"

def get_paddle_ocr(pdf_path, page_num):
    """Run PaddleOCR's PP-StructureV3 pipeline on one PDF page and time it.

    The page is rendered with ``get_page_image`` (rendering time is included
    in the reported runtime). It is then parsed by the module-level
    ``pipeline``, and the Markdown of every result is merged with
    ``pipeline.concatenate_markdown_pages``.

    Args:
        pdf_path: Path to the PDF file.
        page_num: 1-based page number to process.

    Returns:
        ``(markdown, runtime)``: the merged Markdown text, and the elapsed time
        formatted as ``"<seconds> s"`` with two decimals.
    """
    start = time.perf_counter()
    page_image, _ = get_page_image(pdf_path, page_num)
    # .convert("RGB") guarantees a 3-channel RGB array. PaddleOCR/PaddleX treat
    # raw ndarray inputs as OpenCV-style BGR, so reverse the channel axis to
    # avoid swapping red and blue. ascontiguousarray removes the negative
    # stride that the [::-1] view introduces.
    rgb = np.asarray(page_image.convert("RGB"))
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    output = pipeline.predict(input=bgr)
    markdown_list = [res.markdown for res in output]
    markdown_text_paddleOCR = pipeline.concatenate_markdown_pages(markdown_list)
    runtime = time.perf_counter() - start
    return markdown_text_paddleOCR, f"{runtime:.2f} s"

def get_smoldocling_ocr(pdf_path, page_num, max_new_tokens=1500):
    """Convert one PDF page to Markdown with SmolDocling and time it.

    Steps:
    1. Render the page with ``get_page_image``.
    2. Prompt SmolDocling with "Convert this page to docling." to generate
       DocTags.
    3. Build a ``DoclingDocument`` from those DocTags and export it as
       Markdown.

    Rendering time is included in the reported runtime.

    Args:
        pdf_path: Path to the PDF file.
        page_num: 1-based page number to convert.
        max_new_tokens: Upper bound on generated tokens. Generation stops at
            this length, so dense pages may need more than the default 1500.

    Returns:
        ``(markdown, runtime)``: the exported Markdown, and the elapsed time
        formatted as ``"<seconds> s"`` with two decimals.
    """
    start = time.perf_counter()
    page_image, _ = get_page_image(pdf_path, page_num)
    image = load_image(page_image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Convert this page to docling."},
            ],
        },
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)

    # Greedy decoding (no sampling, single beam) for deterministic output.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
            temperature=1.0,
        )

    # The generated sequence starts with the prompt tokens; keep only the
    # newly generated part.
    prompt_length = inputs.input_ids.shape[1]
    trimmed_generated_ids = generated_ids[:, prompt_length:]
    # Special tokens are kept on purpose: presumably they carry the DocTags
    # markup that DocTagsDocument parses below.
    doctags = processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=False,
    )[0].lstrip()

    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
    doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
    markdown_text_smoldocling = doc.export_to_markdown()
    runtime = time.perf_counter() - start
    return markdown_text_smoldocling, f"{runtime:.2f} s"

# Header text shown at the top of the UI.
title = "OCR Arena"
description = (
    "A simple Gradio interface to extract text from PDFs and compare OCR models"
)

# Sample PDFs offered as examples; each example row is a one-element list
# holding the value for the single PDF input.
_EXAMPLE_PDFS = ("data/amazon-10-k-2024.pdf", "data/goog-10-k-2023.pdf")
examples = [[path] for path in _EXAMPLE_PDFS]

# Build the UI:
# - a PDF picker;
# - a dynamically rendered page slider with Clear/Submit buttons;
# - a two-row results grid comparing the rendered page with three OCR outputs.
with gr.Blocks(theme=gr.themes.Glass()) as demo:
    gr.Markdown(f"# {title}\n{description}")
    with gr.Column():
        pdf = gr.File(label="Input PDFs", file_types=[".pdf"])

        # Re-rendered from the current `pdf` value: shows either a placeholder
        # or a page slider sized to the document, plus the action buttons.
        @gr.render(inputs=pdf)
        def show_slider(pdf_path):
            """Render the page selector and buttons for the selected PDF."""
            if pdf_path is None:
                # No file selected: a placeholder replaces the slider and no
                # buttons are created.
                page_num = gr.Markdown("## No Input Provided")
            else:
                page_count = get_pdf_page_count(pdf_path)
                # 1-based page selector spanning the whole document.
                page_num = gr.Slider(1, page_count, value=1, step=1, label="Page Number")

                with gr.Row():
                    # Resets both the uploaded file and the page slider.
                    clear_btn = gr.ClearButton(components=[pdf, page_num])
                    submit_btn = gr.Button("Submit", variant='primary')

                # Submit runs the steps one after another:
                #   1. render the page;
                #   2. Docling;
                #   3. PaddleOCR;
                #   4. SmolDocling.
                # Each step fills its own output/runtime pair.
                # The output components are created further down in this
                # Blocks context and are resolved as module globals when this
                # function runs. NOTE(review): this relies on gr.render
                # invoking show_slider only after the full layout is built.
                submit_btn.click(get_page_image, inputs=[pdf, page_num], outputs=[original, original_runtime]).then(    
                    get_docling_ocr, inputs=[pdf, page_num], outputs=[docling_ocr_out, docling_ocr_runtime]).then(
                    get_paddle_ocr, inputs=[pdf, page_num], outputs=[paddle_ocr_out, paddle_ocr_runtime]).then(
                    get_smoldocling_ocr, inputs=[pdf, page_num], outputs=[smoldocling_ocr_out, smoldocling_ocr_runtime])

    # Results grid:
    #   row 1: rendered page | Docling;
    #   row 2: PaddleOCR | SmolDocling.
    # Each cell pairs an output with a read-only timing box.
    with gr.Column():
        with gr.Row():
            with gr.Column():
                original = gr.Image(width=640, height=640, label="Original Page", interactive=False)
                original_runtime = gr.Textbox(label="Image Extraction Time", type="text", interactive=False)
            with gr.Column():
                docling_ocr_out = gr.Textbox(label="Docling OCR Output", type="text", interactive=False)
                docling_ocr_runtime = gr.Textbox(label="Docling OCR Time", type="text", interactive=False)
        with gr.Row():
            with gr.Column():
                paddle_ocr_out = gr.Textbox(label="Paddle OCR Output", type="text", interactive=False)
                paddle_ocr_runtime = gr.Textbox(label="Paddle OCR Time", type="text", interactive=False)
            with gr.Column():
                smoldocling_ocr_out = gr.Textbox(label="SmolDocling OCR Output", type="text", interactive=False)
                smoldocling_ocr_runtime = gr.Textbox(label="SmolDocling OCR Time", type="text", interactive=False)

    # Clickable sample PDFs that populate the `pdf` input.
    examples_obj = gr.Examples(examples=examples, inputs=[pdf])

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse `demo` or the OCR helpers) does not also
# launch the UI.
if __name__ == "__main__":
    demo.launch()