import gradio as gr
import torch
import numpy as np
import pytesseract
from PIL import Image
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3Tokenizer,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Config,
)

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Constants
NUM_LABELS = 5  # 0: regular text, 1: title, 2: H1, 3: H2, 4: H3

def create_student_model(num_labels=NUM_LABELS):
    """Create a distilled (smaller) version of LayoutLMv3."""
    student_config = LayoutLMv3Config(
        hidden_size=384,         # vs 768 in the base model
        num_attention_heads=6,   # vs 12
        intermediate_size=1536,  # vs 3072
        num_hidden_layers=8,     # vs 12
        num_labels=num_labels,
    )
    model = LayoutLMv3ForTokenClassification(student_config)
    return model
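
# Rough sanity check (illustrative addition, not part of the original app): the
# student config halves the hidden size and trims heads, FFN width, and depth,
# so its parameter count should land well below the roughly 125M of
# microsoft/layoutlmv3-base. One way to verify locally:
#     n_params = sum(p.numel() for p in create_student_model().parameters())
#     print(f"Student model: {n_params / 1e6:.1f}M parameters")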

def load_model():
    """Create the model and its preprocessing components."""
    print("Creating model components...")

    # Feature extractor; OCR is disabled because we run pytesseract ourselves
    feature_extractor = LayoutLMv3FeatureExtractor(
        do_resize=True,
        size=224,
        apply_ocr=False,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )

    # Tokenizer from the base checkpoint
    tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")

    # Student model; for demo purposes it keeps its random initialization.
    # In production, you would load your trained weights here.
    model = create_student_model(num_labels=NUM_LABELS)
    model.to(device)

    print("Model components created successfully!")
    return model, feature_extractor, tokenizer
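
# Sketch of the production weight-loading step mentioned above (an assumption,
# not part of this demo): if weights had been saved with
# torch.save(model.state_dict(), "student_model.pt"), restoring them would be:
#     state_dict = torch.load("student_model.pt", map_location=device)
#     model.load_state_dict(state_dict)
#     model.eval()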

def perform_ocr(image):
    """Extract words and normalized bounding boxes from an image using OCR."""
    try:
        # Convert PIL image to numpy array
        img_array = np.array(image)

        # Get word-level OCR data from Tesseract
        ocr_data = pytesseract.image_to_data(img_array, output_type=pytesseract.Output.DICT)

        words = []
        boxes = []
        img_width, img_height = image.size
        confidences = ocr_data["conf"]

        for i in range(len(ocr_data["text"])):
            # Confidence values may come back as strings or floats depending on
            # the pytesseract version, so parse with float() before comparing
            if float(confidences[i]) > 30:  # Filter out low-confidence detections
                word = ocr_data["text"][i].strip()
                if word:  # Only keep non-empty words
                    x, y, w, h = (ocr_data["left"][i], ocr_data["top"][i],
                                  ocr_data["width"][i], ocr_data["height"][i])
                    # Normalize coordinates to the [0, 1] range
                    normalized_box = [
                        x / img_width,
                        y / img_height,
                        (x + w) / img_width,
                        (y + h) / img_height,
                    ]
                    words.append(word)
                    boxes.append(normalized_box)

        return words, boxes
    except Exception as e:
        print(f"OCR failed: {e}")
        return ["sample", "text"], [[0, 0, 0.5, 0.1], [0.5, 0, 1.0, 0.1]]

def extract_headings_from_image(image, model, feature_extractor, tokenizer):
    """Extract headings from an uploaded image using the model."""
    try:
        # Perform OCR to get words and their bounding boxes
        words, boxes = perform_ocr(image)
        if not words:
            return {"ERROR": ["No text found in image"]}

        # Process the image into pixel values
        pixel_values = feature_extractor(image, return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(device)

        # Cap the word count so the sequence fits within 512 tokens
        max_words = min(len(words), 500)  # Leave room for special tokens
        words = words[:max_words]
        boxes = boxes[:max_words]

        # Convert boxes to the format expected by LayoutLMv3 (0-1000 integer scale)
        scaled_boxes = []
        for box in boxes:
            scaled_boxes.append([int(coord * 1000) for coord in box])
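
        # The 0-1000 scale matches LayoutLMv3's pretraining convention for its
        # layout embeddings; coordinates outside that range typically raise an
        # embedding index error rather than failing gracefully.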

        # Tokenize words together with their boxes
        encoding = tokenizer(
            words,
            boxes=scaled_boxes,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Move tensors to the device
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        bbox = encoding["bbox"].to(device)

        # Run inference
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                bbox=bbox,
                pixel_values=pixel_values,
            )

        # Per-token label predictions
        predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()[0]

        # Map token predictions back to words
        word_ids = encoding.word_ids(batch_index=0)

        # Group consecutive same-label words into headings
        headings = {"TITLE": [], "H1": [], "H2": [], "H3": []}
        label_map = {0: "TEXT", 1: "TITLE", 2: "H1", 3: "H2", 4: "H3"}

        current_heading = {"text": "", "level": None}
        previous_word_id = None
        for word_id, pred in zip(word_ids, predictions):
            # Skip special tokens, and take only the first sub-token of each
            # word so multi-token words are not duplicated in the output
            if word_id is None or word_id == previous_word_id or word_id >= len(words):
                previous_word_id = word_id
                continue
            previous_word_id = word_id

            predicted_label = label_map.get(pred, "TEXT")
            if predicted_label != "TEXT":
                if current_heading["level"] == predicted_label:
                    # Continue building the current heading
                    current_heading["text"] += " " + words[word_id]
                else:
                    # Save the previous heading if one exists
                    if current_heading["text"] and current_heading["level"]:
                        headings[current_heading["level"]].append(current_heading["text"].strip())
                    # Start a new heading
                    current_heading = {"text": words[word_id], "level": predicted_label}
            else:
                # Save the current heading when we hit regular text
                if current_heading["text"] and current_heading["level"]:
                    headings[current_heading["level"]].append(current_heading["text"].strip())
                current_heading = {"text": "", "level": None}

        # Save the final heading
        if current_heading["text"] and current_heading["level"]:
            headings[current_heading["level"]].append(current_heading["text"].strip())

        # Drop empty levels
        headings = {k: v for k, v in headings.items() if v}
        if not headings:
            return {"INFO": ["No headings detected - this might be a model training issue"]}
        return headings

    except Exception as e:
        return {"ERROR": [f"Processing failed: {e}"]}

# Load the model once at startup (this runs when the Space boots)
print("Loading model...")
model, feature_extractor, tokenizer = load_model()
print("Model loaded successfully!")

def process_document(image):
    """Main entry point: process an uploaded document image."""
    if image is None:
        return "Please upload an image"

    print("Processing uploaded image...")

    # Extract headings
    headings = extract_headings_from_image(image, model, feature_extractor, tokenizer)

    # Format the output as Markdown
    result = "## Extracted Document Structure:\n\n"

    if "ERROR" in headings:
        result += f"❌ **Error:** {headings['ERROR'][0]}\n"
        return result

    if "INFO" in headings:
        result += f"ℹ️ **Info:** {headings['INFO'][0]}\n"
        return result

    # Render each detected heading at its Markdown level
    markdown_prefix = {"TITLE": "#", "H1": "##", "H2": "###", "H3": "####"}
    for level, texts in headings.items():
        result += f"**{level}:**\n"
        for text in texts:
            result += f"{markdown_prefix[level]} {text}\n"
        result += "\n"

    if not any(headings.values()):
        result += "⚠️ No headings were detected in this image.\n\n"
        result += "**Possible reasons:**\n"
        result += "- The model needs training on actual data\n"
        result += "- The image quality is too low\n"
        result += "- The document doesn't contain clear headings\n"

    return result

# Create the Gradio interface
demo = gr.Interface(
    fn=process_document,
    inputs=gr.Image(type="pil", label="Upload Document Image"),
    outputs=gr.Markdown(label="Extracted Headings"),
    title="📄 PDF Heading Extractor",
    description="""
    Upload an image of a document to extract its heading hierarchy.

    **Note:** This is a demo version using an untrained model.
    The actual model would need to be trained on DocLayNet data for accurate results.
    """,
    examples=None,
    allow_flagging="never",
)
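
# Note: allow_flagging is the Gradio 3.x/4.x parameter name; newer Gradio
# releases rename it to flagging_mode (an assumption worth checking against
# the pinned Gradio version).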

if __name__ == "__main__":
    demo.launch()
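
# For local debugging, demo.launch(share=True) exposes a temporary public URL;
# on Spaces the plain launch() above is sufficient.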