Ayush kumar
Initial PDF heading extractor app
7fd6367
import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3TokenizerFast,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Config,
)
import pytesseract
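# Note: pytesseract is only a wrapper; the Tesseract OCR binary itself must be
# available on the host (for example via an apt package such as `tesseract-ocr`,
# listed in packages.txt when running as a Hugging Face Space).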
# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Constants
NUM_LABELS = 5 # 0: regular text, 1: title, 2: H1, 3: H2, 4: H3
def create_student_model(num_labels=NUM_LABELS):
    """Create a distilled version of LayoutLMv3"""
    student_config = LayoutLMv3Config(
        hidden_size=384,         # vs 768 original
        num_attention_heads=6,   # vs 12 original
        intermediate_size=1536,  # vs 3072 original
        num_hidden_layers=8,     # vs 12 original
        num_labels=num_labels,
    )
    model = LayoutLMv3ForTokenClassification(student_config)
    return model
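# Illustrative sanity check (not part of the app flow): the configuration above
# should yield far fewer parameters than the full "microsoft/layoutlmv3-base" model.
#   student = create_student_model()
#   print(f"{sum(p.numel() for p in student.parameters()):,} parameters")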
def load_model():
    """Load the model and components"""
    print("Creating model components...")
    # Create feature extractor
    feature_extractor = LayoutLMv3FeatureExtractor(
        do_resize=True,
        size=224,
        apply_ocr=False,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )
    # Create tokenizer (the fast tokenizer is required for word_ids() later)
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
    # Create student model
    model = create_student_model(num_labels=NUM_LABELS)
    model.to(device)
    model.eval()
    # For demo purposes, we'll use random weights.
    # In production, you would load your trained weights here.
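    # A minimal sketch of what loading a checkpoint could look like
    # ("student_layoutlmv3.pt" is a placeholder path, not a shipped file):
    #   state_dict = torch.load("student_layoutlmv3.pt", map_location=device)
    #   model.load_state_dict(state_dict)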
print("Model components created successfully!")
return model, feature_extractor, tokenizer
def perform_ocr(image):
    """Extract text and bounding boxes from image using OCR"""
    try:
        # Convert PIL image to numpy array
        img_array = np.array(image)
        # Get OCR data
        ocr_data = pytesseract.image_to_data(img_array, output_type=pytesseract.Output.DICT)
        words = []
        boxes = []
        confidences = ocr_data['conf']
        img_width, img_height = image.size
        for i in range(len(ocr_data['text'])):
            # Confidence may be reported as an int, float, or numeric string
            # depending on the pytesseract/Tesseract version, so coerce via float()
            if float(confidences[i]) > 30:  # Filter low-confidence detections
                word = ocr_data['text'][i].strip()
                if word:  # Only add non-empty words
                    x, y, w, h = (ocr_data['left'][i], ocr_data['top'][i],
                                  ocr_data['width'][i], ocr_data['height'][i])
                    # Normalize coordinates to the 0-1 range
                    normalized_box = [
                        x / img_width,
                        y / img_height,
                        (x + w) / img_width,
                        (y + h) / img_height,
                    ]
                    words.append(word)
                    boxes.append(normalized_box)
        return words, boxes
    except Exception as e:
        print(f"OCR failed: {e}")
        # Fall back to placeholder words/boxes so the demo still runs
        return ["sample", "text"], [[0, 0, 0.5, 0.1], [0.5, 0, 1.0, 0.1]]
def extract_headings_from_image(image, model, feature_extractor, tokenizer):
    """Extract headings from uploaded image using the model"""
    try:
        # Perform OCR to get words and boxes
        words, boxes = perform_ocr(image)
        if not words:
            return {"ERROR": ["No text found in image"]}
        # Prepare inputs for the model
        # Process image
        pixel_values = feature_extractor(image, return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(device)
        # Process text and boxes (limit to first 512 tokens)
        max_words = min(len(words), 500)  # Leave room for special tokens
        words = words[:max_words]
        boxes = boxes[:max_words]
        # Convert boxes to the format expected by LayoutLMv3 (0-1000 scale)
        scaled_boxes = []
        for box in boxes:
            scaled_box = [
                int(box[0] * 1000),
                int(box[1] * 1000),
                int(box[2] * 1000),
                int(box[3] * 1000),
            ]
            scaled_boxes.append(scaled_box)
        # Tokenize
        encoding = tokenizer(
            words,
            boxes=scaled_boxes,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Move to device
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        bbox = encoding["bbox"].to(device)
        # Run inference
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                bbox=bbox,
                pixel_values=pixel_values,
            )
        # Get predictions
        predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()[0]
        # Map predictions back to words
        word_ids = encoding.word_ids(batch_index=0)
        # Extract headings by label
        headings = {"TITLE": [], "H1": [], "H2": [], "H3": []}
        label_map = {0: "TEXT", 1: "TITLE", 2: "H1", 3: "H2", 4: "H3"}
        current_heading = {"text": "", "level": None}
        previous_word_id = None
        for word_id, pred in zip(word_ids, predictions):
            # Skip special tokens and repeated sub-word tokens so each word is
            # added only once (the first sub-token's prediction is used)
            if word_id is None or word_id == previous_word_id:
                continue
            previous_word_id = word_id
            if word_id >= len(words):
                continue
            predicted_label = label_map.get(pred, "TEXT")
            if predicted_label != "TEXT":
                if current_heading["level"] == predicted_label:
                    # Continue building the current heading
                    current_heading["text"] += " " + words[word_id]
                else:
                    # Save the previous heading if it exists
                    if current_heading["text"] and current_heading["level"]:
                        headings[current_heading["level"]].append(current_heading["text"].strip())
                    # Start a new heading
                    current_heading = {"text": words[word_id], "level": predicted_label}
            else:
                # Save the current heading when we hit regular text
                if current_heading["text"] and current_heading["level"]:
                    headings[current_heading["level"]].append(current_heading["text"].strip())
                current_heading = {"text": "", "level": None}
        # Save the final heading
        if current_heading["text"] and current_heading["level"]:
            headings[current_heading["level"]].append(current_heading["text"].strip())
        # Remove empty lists and return
        headings = {k: v for k, v in headings.items() if v}
        if not headings:
            return {"INFO": ["No headings detected - this might be a model training issue"]}
        return headings
    except Exception as e:
        return {"ERROR": [f"Processing failed: {str(e)}"]}
# Load model (this will happen when the Space starts)
print("Loading model...")
model, feature_extractor, tokenizer = load_model()
print("Model loaded successfully!")
def process_document(image):
    """Main function to process an uploaded document image"""
    if image is None:
        return "Please upload an image"
    print("Processing uploaded image...")
    # Extract headings
    headings = extract_headings_from_image(image, model, feature_extractor, tokenizer)
    # Format output
    result = "## Extracted Document Structure:\n\n"
    if "ERROR" in headings:
        result += f"❌ **Error:** {headings['ERROR'][0]}\n"
        return result
    if "INFO" in headings:
        result += f"ℹ️ **Info:** {headings['INFO'][0]}\n"
        return result
    # Display found headings
    for level, texts in headings.items():
        result += f"**{level}:**\n"
        for text in texts:
            if level == "TITLE":
                result += f"# {text}\n"
            elif level == "H1":
                result += f"## {text}\n"
            elif level == "H2":
                result += f"### {text}\n"
            elif level == "H3":
                result += f"#### {text}\n"
        result += "\n"
    if not any(headings.values()):
        result += "⚠️ No headings were detected in this image.\n\n"
        result += "**Possible reasons:**\n"
        result += "- The model needs training on actual data\n"
        result += "- The image quality is too low\n"
        result += "- The document doesn't contain clear headings\n"
    return result
# Create Gradio interface
demo = gr.Interface(
    fn=process_document,
    inputs=gr.Image(type="pil", label="Upload Document Image"),
    outputs=gr.Markdown(label="Extracted Headings"),
    title="📄 PDF Heading Extractor",
    description="""
    Upload an image of a document to extract its heading hierarchy.

    **Note:** This is a demo version using an untrained model.
    The actual model would need to be trained on DocLayNet data for accurate results.
    """,
    examples=None,
    allow_flagging="never",
)
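# Quick way to exercise the pipeline without the UI (illustrative;
# "sample_page.png" is a placeholder path, not a bundled file):
#   img = Image.open("sample_page.png").convert("RGB")
#   print(process_document(img))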
if __name__ == "__main__":
    demo.launch()