"""Gradio demo: token classification of RFP PDF pages with a fine-tuned LayoutLMv3.

One page of an uploaded PDF is rendered to an image, and its words and word
boxes are extracted with PyMuPDF. A LayoutLMv3 token-classification model then
tags each word. The interface shows the annotated page image and an HTML table
of the tagged text.
"""

import functools
import html
import io
import os

import fitz  # PyMuPDF
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont  # noqa: F401 (ImageFont kept for optional label text)
from transformers import AutoModelForTokenClassification, AutoProcessor

# Hugging Face access token used to download the fine-tuned model.
MODEL_KEY = os.getenv("MODEL_KEY")

MODEL_NAME = "karida/LayoutLMv3_RFP"
PROCESSOR_NAME = "microsoft/layoutlmv3-base"
ERROR_IMAGE_PATH = "error_image.png"
MAX_WORDS_MESSAGE = (
    "Error: in the current version of the model, the maximum number of words "
    "per page is 512."
)

# Display name for each one-letter tag built in `interference`
# (the first letter of the lower-cased model label).
TAG_TO_ROW_LABEL = {"q": "Header", "a": "Section", "h": "Number"}

# Text colour per display name in the HTML table. "Heder" is the misspelled
# name earlier versions emitted; it is kept so such rows still render in blue.
ROW_LABEL_COLORS = {
    "Header": "blue",
    "Heder": "blue",
    "Section": "green",
    "Number": "red",
}

# Outline colour per model label on the annotated page image.
LABEL_TO_BOX_COLOR = {
    "question": "blue",
    "answer": "green",
    "header": "red",
    "other": "violet",
}


def extract_data_from_pdf(pdf_path, page_number=0):
    """
    Extracts image, words, and bounding boxes from a specified page of a PDF.

    Args:
    - pdf_path (str): Path to the PDF file.
    - page_number (int): Page number to extract data from (0-indexed). Floats
      (as produced by a Gradio "number" input) are truncated to int.

    Returns:
    - image: A PIL image of the rendered page.
    - words: A list of words found on the page.
    - boxes: A list of (x0, y0, x1, y1) tuples, one per word, in the page's
      coordinate space as reported by PyMuPDF.
    """
    # The context manager closes the document even if rendering or text
    # extraction raises (e.g. an out-of-range page number).
    with fitz.open(pdf_path) as doc:
        page = doc.load_page(int(page_number))

        pix = page.get_pixmap()
        image = Image.open(io.BytesIO(pix.tobytes("png")))

        words = []
        boxes = []
        # Each entry is (x0, y0, x1, y1, word, ...); the trailing fields are unused.
        for x0, y0, x1, y1, text, *_ in page.get_text("words"):
            words.append(text)
            boxes.append((x0, y0, x1, y1))

    return image, words, boxes


def merge_pairs_v2(pairs):
    """Merge consecutive [tag, text] pairs that share the same tag.

    Texts of adjacent pairs with equal tags are joined with a single space.
    Non-adjacent pairs with the same tag stay separate.

    >>> merge_pairs_v2([["q", "a"], ["q", "b"], ["a", "c"], ["q", "d"]])
    [['q', 'a b'], ['a', 'c'], ['q', 'd']]
    """
    if not pairs:
        return []
    merged = [pairs[0]]
    for current in pairs[1:]:
        last = merged[-1]
        if last[0] == current[0]:
            # Replace (not mutate) so the caller's input lists are untouched.
            merged[-1] = [last[0], last[1] + " " + current[1]]
        else:
            merged.append(current)
    return merged


def create_pretty_table(data):
    """Render [label, text] rows as an HTML table.

    When using this function in Gradio, set the output type to 'html'.

    Args:
    - data (iterable): rows whose first item is a display label ("Header",
      "Section", "Number", ...) and whose second item is the row's text.

    Returns:
    - str: an HTML table. Each row is coloured by its label, with black for
      unknown labels. Labels and text are HTML-escaped because they originate
      from the uploaded PDF.
    """
    rows = []
    for row in data:
        label, text = str(row[0]), str(row[1])
        color = ROW_LABEL_COLORS.get(label, "black")
        rows.append(
            "<tr style='color: {};'><td><b>---{}---</b></td><td>{}</td></tr>".format(
                color, html.escape(label), html.escape(text)
            )
        )
    return "<table>" + "".join(rows) + "</table>"


@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Load the model and processor once and reuse them across requests."""
    # NOTE(review): newer transformers releases deprecate `use_auth_token` in
    # favour of `token=`; switch once the pinned version supports it.
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME, use_auth_token=MODEL_KEY
    )
    processor = AutoProcessor.from_pretrained(PROCESSOR_NAME, apply_ocr=False)
    return model, processor


def _iob_to_label(label):
    """Strip the 2-char IOB prefix ("B-"/"I-") and lower-case; empty -> "other"."""
    label = label[2:]
    return "other" if not label else label.lower()


def _error_image():
    """Return the configured error image, or a blank image if it is missing."""
    try:
        return Image.open(ERROR_IMAGE_PATH)
    except OSError:
        return Image.new("RGB", (400, 100), "white")


def _error_html(message):
    """Wrap an error message in escaped, red HTML for the 'html' output."""
    return "<div style='color: red;'>{}</div>".format(html.escape(message))


def interference(example, page_number=0):
    """Classify the words of one PDF page and return an annotated view.

    Args:
    - example: the uploaded PDF, either as a path string or as an object with a
      `.name` path attribute (older Gradio file-wrapper objects).
    - page_number: 0-indexed page. A float from a Gradio "number" input is
      truncated; None means page 0.

    Returns:
    - (image, html): the page image with coloured outlines per predicted
      label, and an HTML table of the merged Header/Section/Number text.
      On failure, returns an error image and an HTML error message instead
      of raising.
    """
    try:
        pdf_path = getattr(example, "name", example)
        page_index = 0 if page_number is None else int(page_number)

        image, words, boxes = extract_data_from_pdf(pdf_path, page_index)
        if not words:
            # Nothing to classify (e.g. a scanned page with no text layer).
            return image, create_pretty_table([])

        # NOTE(review): LayoutLMv3 documents boxes normalized to a 0-1000
        # scale. Here raw PDF coordinates are passed unnormalized, which also
        # lets the returned token boxes be drawn directly on the rendered
        # image. Confirm this matches how the fine-tuned model was trained
        # before normalizing.
        boxes = [list(map(int, box)) for box in boxes]

        model, processor = _load_model_and_processor()
        encoding = processor(image, words, boxes=boxes, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**encoding)

        # Index batch element 0 instead of squeeze() so a single-token
        # sequence still yields a list rather than a scalar.
        predictions = outputs.logits.argmax(-1)[0].tolist()
        token_boxes = encoding.bbox[0].tolist()
        word_ids = encoding.word_ids()
        token_labels = [model.config.id2label[pred] for pred in predictions]

        draw = ImageDraw.Draw(image)
        tagged = []
        seen_words = set()
        for token_label, box, word_id in zip(token_labels, token_boxes, word_ids):
            if word_id is None:
                # Special tokens map to no word; skip them.
                continue
            predicted_label = _iob_to_label(token_label)
            draw.rectangle(
                box,
                outline=LABEL_TO_BOX_COLOR.get(predicted_label, "violet"),
                width=2,
            )
            # Record each word once, using its first non-"other" sub-token.
            # `is not None` above matters: word id 0 is valid but falsy.
            if predicted_label != "other" and word_id not in seen_words:
                seen_words.add(word_id)
                tagged.append([predicted_label[0], words[word_id]])

        rows = [
            [TAG_TO_ROW_LABEL[tag], text]
            for tag, text in merge_pairs_v2(tagged)
            if tag in TAG_TO_ROW_LABEL
        ]
        return image, create_pretty_table(rows)

    except IndexError as e:
        # Per the message below, this is how an over-long page (more words
        # than the model accepts) has been observed to fail.
        return _error_image(), _error_html(f"{MAX_WORDS_MESSAGE} (IndexError: {e})")
    except Exception as e:
        # UI boundary: report any other failure to the user instead of raising.
        return _error_image(), _error_html(f"An error occurred: {e}")


description_text = """

Information Retrieval for Request For Proposal documents with use of the LayoutLM for Token Classification.
Classified tokens:
- Number - shown in red
- Header - shown in blue
- Section - shown in green
- other (ignored) - shown in violet

"""

flagging_options = ["great example", "bad example"]

iface = gr.Interface(
    fn=interference,
    inputs=["file", "number"],
    outputs=["image", "html"],
    description=description_text,
    flagging_options=flagging_options,
)

if __name__ == "__main__":
    iface.launch()