# NOTE: removed web-viewer residue (Spaces banner, file-size line, git hashes,
# line-number gutter) that was prepended by the page this file was scraped from
# and is not part of the module source.
import torch
from PIL import ImageDraw, ImageFont, Image
from transformers import AutoModelForTokenClassification, AutoProcessor
import fitz # PyMuPDF
import io
import os
# Hugging Face access token (read from the environment) used to download the
# fine-tuned model weights in `interference`.
MODEL_KEY = os.getenv("MODEL_KEY")
def extract_data_from_pdf(pdf_path, page_number=0):
    """
    Extract a page image, its words, and word bounding boxes from a PDF.

    Args:
        pdf_path (str): Path to the PDF file.
        page_number (int): Page to extract data from (0-indexed).

    Returns:
        tuple: (image, words, boxes) where
            - image (PIL.Image.Image): rendered image of the page,
            - words (list[str]): words found on the page,
            - boxes (list[tuple]): (x0, y0, x1, y1) per word, page coords.
    """
    doc = fitz.open(pdf_path)
    try:
        page = doc.load_page(page_number)
        # Render the page to PNG bytes and load them as a PIL image.
        pix = page.get_pixmap()
        image = Image.open(io.BytesIO(pix.tobytes("png")))
        # Each entry from get_text("words") is
        # (x0, y0, x1, y1, word, block_no, line_no, word_no).
        words = []
        boxes = []
        for word in page.get_text("words"):
            words.append(word[4])
            boxes.append(word[:4])  # (x0, y0, x1, y1)
    finally:
        # BUG FIX: the original only closed the document on the success path,
        # leaking the handle if load_page/get_pixmap raised.
        doc.close()
    return image, words, boxes
def merge_pairs_v2(pairs):
    """Collapse runs of consecutive [key, value] pairs sharing a key.

    Adjacent pairs whose first element is equal are merged into a single
    pair whose second element is the space-joined concatenation of their
    values. Non-adjacent duplicates are left separate.
    """
    result = []
    for current in pairs:
        if result and result[-1][0] == current[0]:
            # Same key as the previous entry: extend its value string.
            result[-1] = [current[0], result[-1][1] + " " + current[1]]
        else:
            result.append(current)
    return result
def create_pretty_table(data):
    """Render [label, text] rows as a simple HTML fragment.

    Rows labeled "Heder" are colored blue, "Section" green, anything
    else black. (The "Heder" spelling is intentional here — it matches
    the label strings produced upstream in this file.)
    """
    label_colors = {"Heder": "blue", "Section": "green"}
    parts = ["<div>"]
    for row in data:
        label, text = row[0], row[1]
        # Unknown labels fall back to black.
        chosen = label_colors.get(label, "black")
        parts.append(
            "<p style='color:{};'>---{}---</p>{}".format(chosen, label, text)
        )
    parts.append("</div>")
    return "".join(parts)
# When using this function in Gradio, set the output type to 'html'
def _load_model_and_processor():
    """Load (and memoize) the fine-tuned model plus its processor.

    Loading is expensive (network download + weight init), so the pair is
    cached on the function object instead of being rebuilt per request,
    as the original code did.
    """
    if not hasattr(_load_model_and_processor, "_cache"):
        model = AutoModelForTokenClassification.from_pretrained(
            "karida/LayoutLMv3_RFP",
            # NOTE(review): `use_auth_token` is deprecated in newer
            # transformers releases in favor of `token=` — confirm the
            # pinned version before switching.
            use_auth_token=MODEL_KEY,
        )
        processor = AutoProcessor.from_pretrained(
            "microsoft/layoutlmv3-base", apply_ocr=False
        )
        _load_model_and_processor._cache = (model, processor)
    return _load_model_and_processor._cache


def interference(example, page_number=0):
    """Run LayoutLMv3 token classification on one PDF page.

    Args:
        example: Path to the PDF file (Gradio "file" input).
        page_number (int): Page to analyze (0-indexed).

    Returns:
        tuple: (annotated PIL image, HTML table of labeled text).
    """
    image, words, boxes = extract_data_from_pdf(example, page_number)
    # The processor expects integer pixel boxes.
    boxes = [list(map(int, box)) for box in boxes]

    model, processor = _load_model_and_processor()
    encoding = processor(image, words, boxes=boxes, return_tensors="pt")

    # Prediction (inference only — no gradients needed).
    with torch.no_grad():
        outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    model_words = encoding.word_ids()

    token_boxes = encoding.bbox.squeeze().tolist()
    true_predictions = [model.config.id2label[pred] for pred in predictions]
    true_boxes = token_boxes

    # Draw annotations on the image.
    draw = ImageDraw.Draw(image)

    def iob_to_label(label):
        # Strip the IOB prefix ("B-"/"I-"); an empty remainder (the "O"
        # tag) maps to "other".
        label = label[2:]
        return "other" if not label else label.lower()

    label2color = {
        "question": "blue",
        "answer": "green",
        "header": "orange",
        "other": "violet",
    }

    table = []
    seen_word_ids = set()
    for prediction, box, model_word in zip(
        true_predictions, true_boxes, model_words
    ):
        predicted_label = iob_to_label(prediction)
        draw.rectangle(box, outline=label2color[predicted_label], width=2)
        # BUG FIX: the original tested `if model_word`, which is falsy for
        # word index 0 and silently dropped the first word on the page.
        # Special tokens yield None word ids and are still skipped.
        if (
            model_word is not None
            and model_word not in seen_word_ids
            and predicted_label != "other"
        ):
            seen_word_ids.add(model_word)
            # Keep only the label's first character ("q"/"a"/"h") as the
            # merge key, matching the original grouping scheme.
            table.append([predicted_label[0], words[model_word]])

    values = merge_pairs_v2(table)
    values = [
        ["Heder", x[1]] if x[0] == "q" else ["Section", x[1]] for x in values
    ]
    table = create_pretty_table(values)
    return image, table
import gradio as gr
# Legend shown above the interface; colors match those drawn on the image.
description_text = """
<p>
Heading - <span style='color: blue;'>shown in blue</span><br>
Section - <span style='color: green;'>shown in green</span><br>
other - (ignored)<span style='color: violet;'>shown in violet</span>
</p>
"""
# Choices offered to users when flagging a prediction in the Gradio UI.
flagging_options = ["great example", "bad example"]
# Inputs: a PDF file and a page number; outputs: the annotated page image
# and the HTML table produced by `interference`.
iface = gr.Interface(
    fn=interference,
    inputs=["file", "number"],
    outputs=["image", "html"],
    # examples=[["output.pdf", 1]],
    description=description_text,
    flagging_options=flagging_options,
)
# iface.save(".")
if __name__ == "__main__":
    iface.launch()