import io
import os

import fitz  # PyMuPDF
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForTokenClassification, AutoProcessor

# Hugging Face access token for the fine-tuned model, read from the environment
MODEL_KEY = os.getenv("MODEL_KEY")


def extract_data_from_pdf(pdf_path, page_number=0):
    """
    Extracts image, words, and bounding boxes from a specified page of a PDF.

    Args:
    - pdf_path (str): Path to the PDF file.
    - page_number (int): Page number to extract data from (0-indexed).

    Returns:
    - image: An image of the specified page.
    - words: A list of words found on the page.
    - boxes: A list of bounding boxes corresponding to the words.
    """
    # Open the PDF
    doc = fitz.open(pdf_path)
    page = doc.load_page(page_number)

    # Render the page to an image (ensure RGB for the processor)
    pix = page.get_pixmap()
    image_bytes = pix.tobytes("png")
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Extract words and their bounding boxes
    words = []
    boxes = []
    for word in page.get_text("words"):
        words.append(word[4])
        boxes.append(word[:4])  # (x0, y0, x1, y1)

    doc.close()
    return image, words, boxes
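
# A minimal usage sketch ("sample.pdf" is a hypothetical path):
#   image, words, boxes = extract_data_from_pdf("sample.pdf", page_number=0)
#   print(words[0], boxes[0])  # first word and its (x0, y0, x1, y1) in PDF points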


def merge_pairs_v2(pairs):
    """Merge consecutive [label, text] pairs that share the same label,
    concatenating their text fields."""
    if not pairs:
        return []

    merged = [pairs[0]]
    for current in pairs[1:]:
        last = merged[-1]
        if last[0] == current[0]:
            # Same label as the previous pair: append the text
            merged[-1] = [last[0], last[1] + " " + current[1]]
        else:
            merged.append(current)

    return merged
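
# For example, consecutive pairs sharing a label are merged:
#   merge_pairs_v2([["h", "1."], ["h", "Scope"], ["q", "Deadline"]])
#   -> [["h", "1. Scope"], ["q", "Deadline"]]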


def create_pretty_table(data):
    """Render [label, text] rows as a color-coded HTML block."""
    table = "<div>"
    for row in data:
        color = (
            "blue"
            if row[0] == "Header"
            else "green"
            if row[0] == "Section"
            else "black"
        )
        table += "<p style='color:{};'>---{}---</p>{}".format(
            color, row[0], row[1]
        )
    table += "</div>"
    return table


# When using this function in Gradio, set the output type to 'html'
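# For example, create_pretty_table([["Header", "1. Scope"]]) returns:
#   <div><p style='color:blue;'>---Header---</p>1. Scope</div>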


def inference(example, page_number=0):
    """Run LayoutLMv3 on one PDF page; return annotated image + HTML table."""
    image, words, boxes = extract_data_from_pdf(example, int(page_number))
    width, height = image.size

    # LayoutLMv3 expects word boxes normalized to a 0-1000 scale
    boxes = [
        [
            int(1000 * box[0] / width),
            int(1000 * box[1] / height),
            int(1000 * box[2] / width),
            int(1000 * box[3] / height),
        ]
        for box in boxes
    ]

    # Load the fine-tuned model and base processor; OCR is disabled because
    # the words and boxes come from the PDF text layer
    model = AutoModelForTokenClassification.from_pretrained(
        "karida/LayoutLMv3_RFP",
        use_auth_token=MODEL_KEY
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/layoutlmv3-base", apply_ocr=False
    )
    encoding = processor(image, words, boxes=boxes, return_tensors="pt")

    # Prediction
    with torch.no_grad():
        outputs = model(**encoding)

    logits = outputs.logits
    predictions = logits.argmax(-1).squeeze().tolist()
    model_words = encoding.word_ids()

    # Map class ids to labels; token boxes go back to pixel coordinates
    token_boxes = encoding.bbox.squeeze().tolist()
    true_predictions = [model.config.id2label[pred] for pred in predictions]
    true_boxes = [
        [width * b[0] / 1000, height * b[1] / 1000,
         width * b[2] / 1000, height * b[3] / 1000]
        for b in token_boxes
    ]
    # Draw annotations on the image
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    def iob_to_label(label):
        # Strip the IOB prefix, e.g. "B-HEADER" -> "header"; a bare "O" -> "other"
        label = label[2:]
        return "other" if not label else label.lower()

    label2color = {
        "question": "blue",
        "answer": "green",
        "header": "orange",
        "other": "violet",
    }

    # print(len(true_predictions), len(true_boxes), len(model_words))

    table = []
    ids = set()

    for prediction, box, model_word in zip(
        true_predictions, true_boxes, model_words
    ):
        predicted_label = iob_to_label(prediction)
        draw.rectangle(box, outline=label2color[predicted_label], width=2)
        # draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
        # `is not None` keeps word id 0 (falsy); rows are keyed by the label's first letter
        if model_word is not None and model_word not in ids and predicted_label != "other":
            ids.add(model_word)
            table.append([predicted_label[0], words[model_word]])

    values = merge_pairs_v2(table)
    # Map the single-letter keys back to display labels
    values = [
        ["Header", x[1]] if x[0] == "q" else ["Section", x[1]] for x in values
    ]
    table = create_pretty_table(values)
    return image, table
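
# Standalone usage sketch (file name and page number are hypothetical):
#   annotated_image, html_table = inference("sample.pdf", page_number=0)
#   annotated_image.save("annotated.png")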



description_text = """
<p>
    Header - <span style='color: blue;'>shown in blue</span><br>
    Section - <span style='color: green;'>shown in green</span><br>
    Other - <span style='color: violet;'>shown in violet</span> (ignored in the table)
</p>
"""

flagging_options = ["great example", "bad example"]


iface = gr.Interface(
    fn=inference,
    inputs=["file", "number"],
    outputs=["image", "html"],
    # examples=[["output.pdf", 1]],
    description=description_text,
    flagging_options=flagging_options,
)
# iface.save(".")
if __name__ == "__main__":
    iface.launch()