Spaces:
Running
Running
#!/usr/bin/env python3 | |
import argparse | |
import os | |
from pathlib import Path | |
from PIL import Image, ImageDraw | |
from docling_core.types.doc import DoclingDocument, ImageRefMode | |
from docling_core.types.doc.document import DocTagsDocument | |
import torch | |
from transformers import AutoProcessor, AutoModelForVision2Seq | |
from transformers.image_utils import load_image | |
import sys | |
from pdf2image import convert_from_path | |
import tempfile | |
import json | |
import matplotlib.pyplot as plt | |
from pprint import pprint | |
import base64 | |
from dotenv import load_dotenv | |
import openai | |
from azure.ai.documentintelligence import DocumentIntelligenceClient | |
from azure.core.credentials import AzureKeyCredential | |
from smoldocling.overlays import generate_azure_overlay_html, generate_docling_overlay | |
from PIL import Image | |
import requests | |
from io import BytesIO | |
# Inference device: the GPU when torch reports CUDA as available, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Runs python-dotenv at import time, before any function reads os.environ.
load_dotenv()
def load_model(verbose=True):
    """Load the SmolDocling checkpoint together with its processor.

    Args:
        verbose: When true, print a progress message before loading.

    Returns:
        A ``(model, processor)`` tuple. The model has been moved to ``DEVICE``.
    """
    if verbose:
        print("Loading Smoldocling model...")
    model_path = "ds4sd/SmolDocling-256M-preview"
    processor = AutoProcessor.from_pretrained(model_path)
    # Half precision is only requested for the GPU (the original target was a
    # T4). On the CPU many kernels are slow or unavailable in float16, so the
    # weights are loaded in float32 there instead.
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    model = AutoModelForVision2Seq.from_pretrained(
        model_path,
        torch_dtype=dtype,
    ).to(DEVICE)
    return model, processor
def run_model(model, processor, image, prompt="Convert this page to docling.", verbose=True):
    """Generate DocTags for one page image.

    Args:
        model: Generation model; must provide ``generate``.
        processor: Matching processor; used for the chat template, for
            building the model inputs and for decoding.
        image: The page image handed to the processor.
        prompt: Instruction text placed after the image in the user turn.
        verbose: When true, print a progress message before generating.

    Returns:
        The decoded continuation (special tokens kept) with leading
        whitespace removed.
    """
    # A single user turn: an image placeholder followed by the instruction.
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ],
    }
    chat_text = processor.apply_chat_template([user_turn], add_generation_prompt=True)
    model_inputs = processor(
        text=chat_text,
        images=[image],
        return_tensors="pt",
        truncation=True,
    ).to(DEVICE)

    if verbose:
        print("Generating text...")
    output_ids = model.generate(**model_inputs, max_new_tokens=8192)

    # Drop the leading prompt-length columns so only newly generated ids are decoded.
    prompt_columns = model_inputs.input_ids.shape[1]
    new_token_ids = output_ids[:, prompt_columns:]
    decoded = processor.batch_decode(new_token_ids, skip_special_tokens=False)
    return decoded[0].lstrip()
def extract_text_from_document(image_path, model, processor, output_format="html", verbose=True):
    """Convert a single document image to HTML or to a JSON-ready dict.

    Args:
        image_path: Path of the image file to process.
        model: Model forwarded to ``run_model``.
        processor: Processor forwarded to ``run_model``.
        output_format: ``"json"`` yields a dict; any other value yields HTML.
        verbose: When true, print progress to stdout and errors to stderr.

    Returns:
        A dict (``"json"``) or an HTML string, or ``None`` when any step
        raised. Failures are only reported when ``verbose`` is true.
    """
    try:
        # The context manager releases the file handle once the export is
        # done; the image has to stay open until then because the document
        # is built from it.
        with Image.open(image_path) as image:
            if verbose:
                print(f"Processing {image_path}")
                print(f"Image mode: {image.mode}")
                print(f"Image size: {image.size}")

            output = run_model(model, processor, image, verbose=verbose)

            # Apply the same clean-up process_pdf applies to each page: drop
            # the end-of-utterance marker and rewrite chart tags as OTSL tags.
            doctags = output.replace("<end_of_utterance>", "").strip()
            if "<chart>" in doctags:
                doctags = doctags.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")

            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(
                [doctags],
                [image],
            )
            doc = DoclingDocument(name=Path(image_path).stem).load_from_doctags(doctags_doc)

            if output_format != "json":
                return doc.export_to_html(image_mode=ImageRefMode.EMBEDDED)

            doc_dict = doc.export_to_dict()
            # Remove the picture image URIs from the exported dict.
            for picture in doc_dict.get("pictures", []):
                picture_image = picture.get("image")
                if isinstance(picture_image, dict):
                    picture_image.pop("uri", None)
            return doc_dict
    except Exception as e:
        if verbose:
            print(f"Error processing 1: {image_path}: {str(e)}", file=sys.stderr)
        return None
def _strip_page_images(doc_dict):
    """Set the ``image`` entry of every page under ``doc_dict["pages"]`` to None."""
    pages = doc_dict.get("pages")
    if not pages:
        return
    # Iterating a mapping yields its keys, so a plain ``for page in pages``
    # never reaches the page entries when "pages" is keyed by page number.
    # Walk the values in that case; a plain list is walked directly.
    entries = pages.values() if isinstance(pages, dict) else pages
    for page in entries:
        if isinstance(page, dict) and "image" in page:
            page["image"] = None


def _clean_doctags(raw_output):
    """Drop the end-of-utterance marker and rewrite chart tags as OTSL tags."""
    cleaned = raw_output.replace("<end_of_utterance>", "").strip()
    if "<chart>" in cleaned:
        cleaned = cleaned.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
    return cleaned


def _document_from_doctags(name, doctags, images):
    """Build a DoclingDocument named *name* from parallel DocTags and images."""
    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(doctags, images)
    doc = DoclingDocument(name=name)
    # extract_text_from_document uses the value returned by load_from_doctags,
    # so prefer it here as well and only fall back to the instance when the
    # call returns nothing.
    # NOTE(review): confirm the return contract against the installed
    # docling-core release.
    loaded = doc.load_from_doctags(doctags_doc)
    return doc if loaded is None else loaded


def _dump_debug_page(stem, page_no, doctags, image, output_dir, verbose):
    """Write one page as ``<stem>_p<page_no>.json`` inside *output_dir*."""
    doc_page = _document_from_doctags(f"{stem}_p{page_no}", [doctags], [image])
    doc_dict_page = doc_page.export_to_dict()
    _strip_page_images(doc_dict_page)
    page_json_path = Path(output_dir) / f"{stem}_p{page_no}.json"
    with open(page_json_path, 'w', encoding='utf-8') as f:
        json.dump(doc_dict_page, f, ensure_ascii=False, indent=2)
    if verbose:
        print(f"[DEBUG] Dumped page {page_no} JSON to {page_json_path}")


def _export_document(doc, stem, output_dir, output_format, verbose):
    """Export *doc*; return the export when *output_dir* is None, else write it."""
    as_json = output_format == "json"
    if as_json:
        payload = doc.export_to_dict()
        _strip_page_images(payload)
    else:
        payload = doc.export_to_html(image_mode=ImageRefMode.EMBEDDED)

    if output_dir is None:
        return payload

    output_path = Path(output_dir) / f"{stem}.{'json' if as_json else 'html'}"
    with open(output_path, 'w', encoding='utf-8') as f:
        if as_json:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        else:
            f.write(payload)
    if verbose:
        print(f"\nSuccessfully saved combined output to {output_path}")
    return None


def process_pdf(pdf_path, model, processor, output_dir, output_format="html", debug=False, verbose=True):
    """Process a PDF by rendering it to page images and converting every page.

    Args:
        pdf_path: Path of the PDF file.
        model: Model forwarded to ``run_model``.
        processor: Processor forwarded to ``run_model``.
        output_dir: Directory for the combined output, or ``None`` to return
            the result instead of writing it.
        output_format: ``"json"`` for a dict / JSON file; any other value
            produces HTML.
        debug: When true and ``output_dir`` is set, also dump each page as
            ``<stem>_p<N>.json``.
        verbose: When true, print progress to stdout and errors to stderr.

    Returns:
        The exported dict or HTML string when ``output_dir`` is ``None`` and
        at least one page succeeded; ``None`` in every other case (output
        written to disk, no pages, or an error).
    """
    try:
        if verbose:
            print(f"\nProcessing PDF: {pdf_path}")
        stem = Path(pdf_path).stem

        with tempfile.TemporaryDirectory() as temp_dir:
            if verbose:
                print("Converting PDF to images...")
            # TODO: Review this. It's not working when the PDF is large.
            images = convert_from_path(
                pdf_path,
                output_folder=temp_dir,
                first_page=1,
                fmt="png",
            )
            if not images:
                if verbose:
                    print(f"No pages found in PDF: {pdf_path}", file=sys.stderr)
                return None

            try:
                all_doctags = []
                all_images = []
                # The rendered images are used directly; writing each one out
                # again and re-opening it only duplicated the disk work.
                for i, image in enumerate(images, start=1):
                    if verbose:
                        print(f"\nProcessing page {i}")
                    try:
                        if verbose:
                            print(f"Image mode: {image.mode}")
                            print(f"Image size: {image.size}")
                        output = run_model(model, processor, image, verbose=verbose)
                        cleaned_output = _clean_doctags(output)
                        all_doctags.append(cleaned_output)
                        all_images.append(image)
                        if verbose:
                            print(f"Successfully processed page {i}")
                        if debug and output_dir is not None:
                            _dump_debug_page(stem, i, cleaned_output, image, output_dir, verbose)
                    except Exception as e:
                        # A failing page is reported and skipped; the others
                        # still end up in the combined document.
                        if verbose:
                            print(f"Error processing page {i}: {str(e)}", file=sys.stderr)

                if not (all_doctags and all_images):
                    if verbose:
                        print("No pages were successfully processed", file=sys.stderr)
                    return None

                doc = _document_from_doctags(stem, all_doctags, all_images)
                return _export_document(doc, stem, output_dir, output_format, verbose)
            finally:
                # Close the page images before the temporary directory that
                # backs them is removed.
                for image in images:
                    image.close()
    except Exception as e:
        if verbose:
            print(f"Error processing PDF {pdf_path}: {str(e)}", file=sys.stderr)
def process_files(input_files, output_dir, output_format="html", debug=False, verbose=True):
    """Process several input files and produce output in the chosen format.

    PDFs are handed to ``process_pdf``; every other file is treated as an
    image and handed to ``extract_text_from_document``.

    Args:
        input_files: Paths of the images and/or PDFs to process.
        output_dir: Directory for result files (created when missing), or
            ``None`` to collect the results in memory instead.
        output_format: ``"html"`` or ``"json"``.
        debug: Forwarded to ``process_pdf``.
        verbose: When true, print progress to stdout and errors to stderr.

    Returns:
        A list with one entry per successfully processed file when
        ``output_dir`` is ``None``; otherwise ``None``.
    """
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    results = []
    if not input_files:
        # Nothing to process, so skip loading the model altogether.
        return results if output_dir is None else None

    model, processor = load_model(verbose=verbose)
    for input_file in input_files:
        try:
            input_path = Path(input_file)
            if input_path.suffix.lower() == '.pdf':
                pdf_result = process_pdf(
                    input_file, model, processor, output_dir,
                    output_format=output_format, debug=debug, verbose=verbose,
                )
                # process_pdf returns ONE dict or HTML string for the whole
                # PDF, so it is appended; extend() would add the dict's keys
                # or the string's individual characters.
                if output_dir is None and pdf_result:
                    results.append(pdf_result)
                continue

            if verbose:
                print(f"\nProcessing: {input_file}")
            converted = extract_text_from_document(
                input_path, model, processor, output_format=output_format, verbose=verbose
            )
            if not converted:
                if verbose:
                    print(f"Failed to process {input_file}", file=sys.stderr)
                continue

            if output_dir is None:
                results.append(converted)
            else:
                output_path = Path(output_dir) / f"{input_path.stem}.{output_format}"
                if verbose:
                    print(f"Output will be saved to: {output_path}")
                with open(output_path, 'w', encoding='utf-8') as f:
                    if output_format == "json":
                        json.dump(converted, f, ensure_ascii=False, indent=2)
                    elif output_format == "html":
                        f.write(converted)
            if verbose:
                print(f"Successfully processed {input_file}")
        except Exception as e:
            # One bad file must not stop the remaining ones.
            if verbose:
                print(f"Error processing 2 {input_file}: {str(e)}", file=sys.stderr)

    if output_dir is None:
        return results
def visualize_doc(doc_path, page_num=0):
    """Show a document page with the bounding boxes from its JSON annotation.

    The annotation path is derived from ``doc_path``: every ``"input"`` in the
    path becomes ``"output"`` and the extension is replaced by ``.json``
    (images) or ``_p<page_num + 1>.json`` (PDFs).

    Args:
        doc_path (str): Path to the input document file (PDF or image).
        page_num (int): Zero-based page to visualize for PDFs (default 0).

    Raises:
        IndexError: If the PDF has no page at ``page_num``.
        FileNotFoundError: If the derived JSON file does not exist.
    """
    is_pdf = doc_path.lower().endswith('.pdf')

    if is_pdf:
        # Render only the requested page instead of the whole document;
        # pdf2image page numbers are 1-based.
        wanted = page_num + 1
        images = convert_from_path(doc_path, first_page=wanted, last_page=wanted)
        image = images[0]
    else:
        # convert() returns a new in-memory image, so the file can be closed.
        with Image.open(doc_path) as source:
            image = source.convert("RGB")

    # Swap the extension generically so that .jpg/.PNG/... inputs resolve to
    # their JSON file too, not only lower-case ".png".
    base, _ = os.path.splitext(doc_path.replace("input", "output"))
    json_path = f"{base}_p{page_num + 1}.json" if is_pdf else f"{base}.json"

    # The JSON files are written as UTF-8, so read them the same way.
    with open(json_path, "r", encoding="utf-8") as f:
        doc = json.load(f)

    # Collect boxes from texts, pictures and tables, in that order. The second
    # tuple element is the label used when an item carries none.
    bboxes = []
    labels = []
    for section, default_label in (("texts", ""), ("pictures", "picture"), ("tables", "")):
        for item in doc.get(section, []):
            for prov in item.get("prov", []):
                # NOTE(review): for PDFs only entries with page_no == 1 are
                # kept, presumably because each `_p<N>.json` holds a single
                # page - confirm. Tables are now filtered like the rest.
                if is_pdf and prov.get("page_no") != 1:
                    continue
                bbox = prov.get("bbox")
                if bbox:
                    bboxes.append([bbox["l"], bbox["t"], bbox["r"], bbox["b"]])
                    labels.append(item.get("label", default_label))

    # Draw the boxes, each with its label and coordinates just above it.
    draw = ImageDraw.Draw(image)
    for (l, t, r, b), label in zip(bboxes, labels):
        draw.rectangle([l, t, r, b], outline="red", width=2)
        if label:
            draw.text((l, t-10), f"{label} ({l:.1f}, {t:.1f}, {r:.1f}, {b:.1f})", fill="red")

    plt.figure(figsize=(10, 12))
    plt.imshow(image)
    plt.axis("off")
    plt.show()
def stitch_text_from_json(json_path, gpt_fix=False):
    """Stitch the text fragments of a DoclingDocument JSON file into plain text.

    The fragments are emitted in the order given by ``body.children``,
    descending into referenced groups; references to anything other than
    texts and groups are skipped. The result is printed and returned.

    Args:
        json_path: Path of the JSON file to read (UTF-8).
        gpt_fix: When true, send the stitched text to the OpenAI chat API to
            repair line breaks and hyphenation. Requires ``OPENAI_API_KEY``;
            on any failure the original stitched text is used instead.

    Returns:
        The stitched (and optionally cleaned) text, fragments joined by
        newlines.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        doc = json.load(f)
    texts = doc.get('texts', [])
    groups = doc.get('groups', [])
    body = doc.get('body', {})

    # Lookup tables. Groups are indexed by position first (as texts are), so a
    # group without "self_ref" no longer raises KeyError; an explicit
    # "self_ref" is registered as well.
    texts_by_ref = {f"#/texts/{i}": t for i, t in enumerate(texts)}
    groups_by_ref = {f"#/groups/{i}": g for i, g in enumerate(groups)}
    groups_by_ref.update({g['self_ref']: g for g in groups if g.get('self_ref')})

    def extract_texts(children, ancestors=()):
        result = []
        for child in children:
            ref = child.get('$ref')
            if ref is None:
                continue
            if ref.startswith('#/texts/'):
                text_obj = texts_by_ref.get(ref)
                if text_obj:
                    text = text_obj.get('text', '')
                    if text:
                        result.append(text)
            elif ref.startswith('#/groups/'):
                # A group that is its own ancestor would recurse forever.
                if ref in ancestors:
                    continue
                group_obj = groups_by_ref.get(ref)
                if group_obj:
                    result.extend(
                        extract_texts(group_obj.get('children', []), ancestors + (ref,))
                    )
        return result

    final_text = '\n'.join(extract_texts(body.get('children', [])))

    # Without text there is nothing for the model to repair, so no API call.
    if not gpt_fix or not final_text.strip():
        print(final_text)
        return final_text

    try:
        api_key = os.environ.get('OPENAI_API_KEY')
        if not api_key:
            print("OPENAI_API_KEY not set. Printing original stitched text.", file=sys.stderr)
            print(final_text)
            return final_text
        client = openai.OpenAI(api_key=api_key)
        prompt = (
            "You are a helpful assistant. "
            "The following text was extracted from a document and may contain odd line breaks, hyphenated words split across lines, or other OCR artifacts. "
            "Please rewrite the text as clean, readable prose, fixing line breaks, joining hyphenated words, and correcting obvious errors, but do not add or remove content.\n\n"
            f"Text to fix:\n\n{final_text}\n\nCleaned text:"
        )
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=4096,
            temperature=0.0,
        )
        cleaned_text = response.choices[0].message.content.strip()
        print(cleaned_text)
        return cleaned_text
    except Exception as e:
        # Best effort: any API problem falls back to the unmodified text.
        print(f"[GPT-fix error] {e}. Printing original stitched text.", file=sys.stderr)
        print(final_text)
        return final_text
def extract_with_azure(input_files, output_dir, output_format="json", verbose=True):
    """Analyze files with Azure Document Intelligence and save the JSON results.

    Each file is sent base64-encoded to the ``prebuilt-layout`` model and the
    result is written to ``<output_dir>/<file stem>.json``.

    Args:
        input_files: Paths of the files to analyze.
        output_dir: Directory for the JSON results (created when missing).
        output_format: Accepted but not used by this function.
        verbose: When true, print the path of every saved result.

    Returns:
        None. When the endpoint or key environment variable is missing, a
        message is printed to stderr and nothing else happens.
    """
    env = os.environ
    endpoint = env.get("AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT")
    key = env.get("AZURE_DOCUMENT_INTELLIGENCE_KEY")
    if not (endpoint and key):
        print("Azure endpoint/key not set. Set AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT and AZURE_DOCUMENT_INTELLIGENCE_KEY in your environment.", file=sys.stderr)
        return

    client = DocumentIntelligenceClient(endpoint, AzureKeyCredential(key))
    os.makedirs(output_dir, exist_ok=True)

    for input_file in input_files:
        with open(input_file, "rb") as source:
            encoded = base64.b64encode(source.read()).decode("utf-8")

        # Blocks until the service has finished analyzing this file.
        analysis = client.begin_analyze_document(
            model_id="prebuilt-layout",
            body={"base64Source": encoded},
        ).result()

        output_path = Path(output_dir) / (Path(input_file).stem + ".json")
        with open(output_path, "w", encoding="utf-8") as out_f:
            json.dump(analysis.as_dict(), out_f, ensure_ascii=False, indent=2)
        if verbose:
            print(f"Azure baseline output saved to {output_path}")
def _build_parser():
    """Create the command-line parser with one sub-parser per command."""
    parser = argparse.ArgumentParser(
        description="Process document images and PDFs using Smoldocling and generate HTML or JSON outputs"
    )
    # A command is mandatory: every code path after parsing reads attributes
    # that only exist on a sub-parser's namespace.
    subparsers = parser.add_subparsers(dest="command", required=True)

    # Main processing
    parser_main = subparsers.add_parser("process", help="Process images or PDFs to HTML/JSON")
    parser_main.add_argument(
        'input_files', nargs='+', help='One or more input files (images or PDFs) to process'
    )
    parser_main.add_argument(
        '-o', '--output-dir', default='output', help='Output directory for result files (default: output)'
    )
    parser_main.add_argument(
        '--format', choices=['html', 'json'], default='html', help='Output format: html or json (default: html)'
    )
    parser_main.add_argument(
        '--debug', action='store_true', help='Enable debug mode: dump each PDF page as a separate JSON file.'
    )

    # Overlay HTML subcommand
    parser_overlay = subparsers.add_parser("overlay-html", help="Generate HTML overlay from PNG and JSON")
    parser_overlay.add_argument('image_file', help='Source PNG image file')
    parser_overlay.add_argument('json_file', help='Extracted JSON file with bounding boxes')
    parser_overlay.add_argument('-o', '--output', help='Output HTML file (default: <image_file>_overlay.html)')

    # Stitch text subcommand
    parser_stitch = subparsers.add_parser("stitch-text", help="Stitch together text fragments from a JSON file and print as plain text")
    parser_stitch.add_argument('json_file', help='Extracted JSON file to stitch')
    parser_stitch.add_argument('--gpt-fix', action='store_true', help='Send stitched text to GPT to fix line breaks and hyphenation')

    # Azure baseline subcommand
    parser_azure = subparsers.add_parser(
        "azure-baseline", help="Extract content using Azure Document Intelligence for baseline comparison"
    )
    parser_azure.add_argument(
        'input_files', nargs='+', help='One or more input files (images or PDFs) to process with Azure Document Intelligence'
    )
    parser_azure.add_argument(
        '-o', '--output-dir', default='output_azure', help='Output directory for Azure baseline result files (default: output_azure)'
    )
    parser_azure.add_argument(
        '--format', choices=['json'], default='json', help='Output format: json (default: json)'
    )

    # Azure overlay HTML subcommand
    parser_azure_overlay = subparsers.add_parser("azure-overlay-html", help="Generate HTML overlay for Azure Document Intelligence output (words)")
    parser_azure_overlay.add_argument('--image', required=True, help='Path to scanned image file')
    parser_azure_overlay.add_argument('--json', required=True, help='Path to Azure Document Intelligence JSON file')
    parser_azure_overlay.add_argument('--output', required=True, help='Path to output HTML file')
    return parser


def main(argv=None):
    """Command-line entry point: parse the arguments and run the chosen command.

    Args:
        argv: Argument list without the program name; ``None`` (the default)
            makes argparse read ``sys.argv``.

    Raises:
        SystemExit: With status 2 on a usage error (including a missing
            command) and status 1 when ``process`` is given no existing file.
    """
    args = _build_parser().parse_args(argv)

    if args.command == "overlay-html":
        output_html = args.output or (os.path.splitext(args.image_file)[0] + "_overlay.html")
        generate_docling_overlay(args.image_file, args.json_file, output_html)
        return
    if args.command == "stitch-text":
        stitch_text_from_json(args.json_file, gpt_fix=getattr(args, 'gpt_fix', False))
        return
    if args.command == "azure-baseline":
        extract_with_azure(
            args.input_files,
            args.output_dir,
            output_format=args.format,
            verbose=True
        )
        return
    if args.command == "azure-overlay-html":
        generate_azure_overlay_html(args.image, args.json, args.output)
        return

    # Remaining command: process. Keep only the inputs that exist on disk.
    valid_files = []
    for file_path in args.input_files:
        if not os.path.exists(file_path):
            print(f"Warning: File not found: {file_path}", file=sys.stderr)
        else:
            valid_files.append(file_path)
    if not valid_files:
        print("Error: No valid input files provided", file=sys.stderr)
        sys.exit(1)
    process_files(valid_files, args.output_dir, output_format=args.format, debug=args.debug)


if __name__ == '__main__':
    main()