# Multi-Model-OCR / app.py
# CRITICAL: Import spaces FIRST before any CUDA-related packages
import spaces
import os
# Now import other packages
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from threading import Thread
import time
# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
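# NOTE: the models below are placed with device_map="auto" (accelerate decides
# layer placement); `device` is only used to move tokenized inputs. On a
# single-GPU Space the two coincide.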
# Load Dots.OCR
MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH_D,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
).eval()
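# Assumption: "dots.ocr-base-fix" is a community-patched mirror of
# rednote-hilab/dots.ocr whose remote code is adjusted for newer
# transformers releases (inferred from the repo name).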
# Load olmOCR-2-7B-1025 (non-FP8 version for simplicity)
MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = AutoModel.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
).eval()
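# olmOCR-2-7B-1025 is, per its model card, a Qwen2.5-VL-7B finetune, so its
# AutoProcessor provides both the chat template and the image preprocessing
# used in generate_image() below.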
# Load DeepSeek-OCR
MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
model_ds = AutoModel.from_pretrained(
    MODEL_ID_DS,
    attn_implementation="sdpa",
    trust_remote_code=True,
    use_safetensors=True,
    device_map="auto",
).eval().to(torch.bfloat16)
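# DeepSeek-OCR is not driven through generate(); its trust_remote_code
# checkpoint ships a custom model.infer() helper that takes an image *file
# path* plus resolution settings, which is why it gets its own branch below.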
@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float, resolution_mode: str):
    """
    Generates responses using the selected model for image input.
    Yields raw text and Markdown-formatted text.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return
    # Handle DeepSeek-OCR separately due to its different API
    if model_name == "DeepSeek-OCR":
        resolution_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
        }
        config = resolution_configs[resolution_mode]
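        # "Gundam" is DeepSeek-OCR's dynamic-resolution mode (640-px local
        # crops over a 1024-px global view); per the model card it suits
        # dense documents better than the fixed square-resize modes.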
        temp_image_path = "/tmp/temp_ocr_image.jpg"
        # Convert to RGB first so RGBA/palette uploads can be saved as JPEG
        image.convert("RGB").save(temp_image_path)
        if not text:
            text = "Free OCR."
        prompt_ds = f"<image>\n{text}"
        try:
            result = model_ds.infer(
                tokenizer_ds,
                prompt=prompt_ds,
                image_file=temp_image_path,
                output_path="/tmp",
                base_size=config["base_size"],
                image_size=config["image_size"],
                crop_mode=config["crop_mode"],
                test_compress=True,
                save_results=False,
            )
            yield result, result
        except Exception as e:
            yield f"Error: {str(e)}", f"Error: {str(e)}"
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)
        return
    # Handle other models with standard API
    if model_name == "olmOCR-2-7B-1025":
        processor = processor_m
        model = model_m
    elif model_name == "Dots.OCR":
        processor = processor_d
        model = model_d
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": text if text else "Perform OCR on this image."},
        ],
    }]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
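    # The {"type": "image"} placeholder is resolved in two steps: the chat
    # template inserts the model's image tokens into the prompt string, and
    # the processor call below pairs those tokens with the actual pixels.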
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(device)
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
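    # TextIteratorStreamer decouples generation from the UI: generate() runs
    # in a background thread while this generator yields partial text as it
    # arrives. Passing the processor works because multimodal processors
    # forward decode() to their tokenizer.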
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        # Fall back to greedy decoding at temperature 0; sampling with
        # temperature=0.0 would raise a ValueError in transformers
        "do_sample": temperature > 0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer
    thread.join()
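# Illustrative (not used by the UI): the generator can also be driven
# directly, e.g. for a quick smoke test --
#   img = Image.open("examples/1.jpg")
#   for raw, md in generate_image("Dots.OCR", "OCR this.", img,
#                                 1024, 0.7, 0.9, 50, 1.1, "Gundam"):
#       print(raw)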
# Image examples
image_examples = [
    ["OCR the content perfectly.", "examples/3.jpg"],
    ["Perform OCR on the image.", "examples/1.jpg"],
    ["Extract the contents. [page].", "examples/2.jpg"],
]
# CSS styling
css = """
.gradio-container {
max-width: 1400px;
margin: auto;
}
.model-selector {
font-size: 16px;
}
"""
# Build Gradio interface
with gr.Blocks(css=css, title="Multi-Model OCR Space") as demo:
    gr.Markdown(
        """
        # 🔍 Multi-Model OCR Comparison Space
        Compare three state-of-the-art OCR models:
        - **Dots.OCR**: lightweight and efficient OCR
        - **olmOCR-2-7B-1025**: advanced OCR for math, tables, and complex layouts (82.4 on olmOCR-Bench)
        - **DeepSeek-OCR**: context-compression OCR (~97% decoding precision at ~10× token compression)
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=["Dots.OCR", "olmOCR-2-7B-1025", "DeepSeek-OCR"],
                value="olmOCR-2-7B-1025",
                label="Select OCR Model",
                elem_classes=["model-selector"],
            )
            resolution_selector = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam"],
                value="Gundam",
                label="DeepSeek-OCR Resolution Mode",
                info="Only applies to DeepSeek-OCR. Gundam mode recommended.",
                visible=False,
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                value="Perform OCR on this image.",
                label="Prompt",
                lines=2,
            )
            with gr.Accordion("Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=256,
                    maximum=8192,
                    value=2048,
                    step=256,
                    label="Max New Tokens",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top K",
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty",
                )
            submit_btn = gr.Button("🚀 Extract Text", variant="primary")
            clear_btn = gr.ClearButton()
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Extracted Text",
                lines=20,
                show_copy_button=True,
            )
            output_markdown = gr.Markdown(label="Formatted Output")
    gr.Examples(
        examples=image_examples,
        inputs=[text_input, image_input],
        label="Example Images",
    )
    # Show/hide resolution selector based on model
    def update_resolution_visibility(model_name):
        return gr.update(visible=(model_name == "DeepSeek-OCR"))

    model_selector.change(
        fn=update_resolution_visibility,
        inputs=[model_selector],
        outputs=[resolution_selector],
    )
    # Event handlers
    submit_btn.click(
        fn=generate_image,
        inputs=[
            model_selector,
            text_input,
            image_input,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
            resolution_selector,
        ],
        outputs=[output_text, output_markdown],
    )
    clear_btn.add([image_input, text_input, output_text, output_markdown])
    gr.Markdown(
        """
        ### Model Strengths
        - **Dots.OCR**: fast and lightweight; great for simple documents and quick processing
        - **olmOCR-2-7B-1025**: best for complex documents with tables, LaTeX equations, multi-column layouts, and handwritten text
        - **DeepSeek-OCR**: excels at Markdown conversion and table extraction, compressing pages into roughly 10× fewer vision tokens
        ### Tips
        - Upload clear, well-lit images for best results
        - Use olmOCR for academic papers and technical documents
        - Use DeepSeek-OCR for efficient processing of large document batches
        - Lower the temperature for more deterministic output; raise it only if results look repetitive
        """
    )
if __name__ == "__main__":
    demo.queue(max_size=20).launch()