# Multi-Model-OCR / app.py
"""
OCR Application with Multiple Models including DeepSeek OCR
Fixed version with @spaces.GPU decorator for Hugging Face Spaces
"""
import os
import time
import torch
import spaces
from threading import Thread
from PIL import Image
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer
)
# Try importing Qwen3VL if available
try:
from transformers import Qwen3VLForConditionalGeneration
except ImportError:
Qwen3VLForConditionalGeneration = None
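# Qwen3VLForConditionalGeneration only ships in recent transformers releases;
# on older versions Chandra-OCR is simply reported as unavailable below.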
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
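# On ZeroGPU Spaces, CUDA is not visible at import time; a GPU only appears
# inside functions decorated with @spaces.GPU, so this startup check will
# usually report "cpu" even when the Space has a GPU attached.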
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Initial Device: {device}")
print(f"CUDA Available: {torch.cuda.is_available()}")
# Load Chandra-OCR
try:
MODEL_ID_V = "datalab-to/chandra"
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
if Qwen3VLForConditionalGeneration:
model_v = Qwen3VLForConditionalGeneration.from_pretrained(
MODEL_ID_V,
trust_remote_code=True,
torch_dtype=torch.float16
).eval()
print("✓ Chandra-OCR loaded")
else:
model_v = None
print("✗ Chandra-OCR: Qwen3VL not available")
except Exception as e:
model_v = None
processor_v = None
print(f"✗ Chandra-OCR: Failed to load - {str(e)}")
# Load Nanonets-OCR2-3B
try:
MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID_X,
trust_remote_code=True,
torch_dtype=torch.float16
).eval()
print("✓ Nanonets-OCR2-3B loaded")
except Exception as e:
model_x = None
processor_x = None
print(f"✗ Nanonets-OCR2-3B: Failed to load - {str(e)}")
# Load Dots.OCR - will be moved to GPU when needed
try:
MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
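    # flash_attention_2 requires the flash-attn package; if it is missing,
    # from_pretrained raises and the except below marks Dots.OCR unavailable.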
model_d = AutoModelForCausalLM.from_pretrained(
MODEL_PATH_D,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
trust_remote_code=True
).eval()
print("✓ Dots.OCR loaded")
except Exception as e:
model_d = None
processor_d = None
print(f"✗ Dots.OCR: Failed to load - {str(e)}")
# Load olmOCR-2-7B-1025
try:
MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID_M,
trust_remote_code=True,
torch_dtype=torch.float16
).eval()
print("✓ olmOCR-2-7B-1025 loaded")
except Exception as e:
model_m = None
processor_m = None
print(f"✗ olmOCR-2-7B-1025: Failed to load - {str(e)}")
# Load DeepSeek-OCR. Note: DeepSeek-OCR is a custom architecture, not a
# Qwen2.5-VL variant, so it is loaded via AutoModel/AutoTokenizer with
# trust_remote_code (per its model card); a Qwen class would fail to load it.
try:
    MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
    processor_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
    model_ds = AutoModel.from_pretrained(
        MODEL_ID_DS,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    ).eval()
    print("✓ DeepSeek-OCR loaded")
except Exception as e:
    model_ds = None
    processor_ds = None
    print(f"✗ DeepSeek-OCR: Failed to load - {str(e)}")
@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
max_new_tokens: int, temperature: float, top_p: float,
top_k: int, repetition_penalty: float):
"""
Generates responses using the selected model for image input.
Yields raw text and Markdown-formatted text.
This function is decorated with @spaces.GPU to ensure it runs on GPU
when available in Hugging Face Spaces.
Args:
model_name: Name of the OCR model to use
text: Prompt text for the model
image: PIL Image object to process
max_new_tokens: Maximum number of tokens to generate
temperature: Sampling temperature
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
repetition_penalty: Penalty for repeating tokens
Yields:
tuple: (raw_text, markdown_text)
"""
# Device will be cuda when @spaces.GPU decorator activates
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Registry of loaded models keyed by the names shown in the UI; entries
    # whose load failed at import time are None.
    loaded = {
        "olmOCR-2-7B-1025": (model_m, processor_m),
        "Nanonets-OCR2-3B": (model_x, processor_x),
        "Chandra-OCR": (model_v, processor_v),
        "Dots.OCR": (model_d, processor_d),
        "DeepSeek-OCR": (model_ds, processor_ds),
    }
    if model_name not in loaded:
        yield "Invalid model selected.", "Invalid model selected."
        return
    model, processor = loaded[model_name]
    if model is None:
        msg = f"{model_name} is not available."
        yield msg, msg
        return
    model = model.to(device)
if image is None:
yield "Please upload an image.", "Please upload an image."
return
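    # DeepSeek-OCR does not use the Qwen chat-template/streamer pipeline below;
    # its model card documents a custom infer() entry point that takes the
    # tokenizer and an image *file*. The call signature and return-value
    # handling here follow that README and are an assumption -- verify against
    # the upstream repo if its remote code changes.
    if model_name == "DeepSeek-OCR":
        try:
            import tempfile
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                image.save(tmp.name)
                result = model.infer(processor, prompt=f"<image>\n{text}", image_file=tmp.name)
            yield str(result), str(result)
        except Exception as e:
            error_msg = f"Error during DeepSeek-OCR inference: {str(e)}"
            yield error_msg, error_msg
        return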
try:
# Prepare messages in chat format
messages = [{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": text},
]
}]
# Apply chat template
prompt_full = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Process inputs
inputs = processor(
text=[prompt_full],
images=[image],
return_tensors="pt",
padding=True
).to(device)
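        # MAX_INPUT_TOKEN_LENGTH is otherwise unused; enforce it here as a
        # guard. Truncating input_ids would desync the image placeholder
        # tokens, so refuse oversized prompts instead of slicing them.
        if inputs["input_ids"].shape[-1] > MAX_INPUT_TOKEN_LENGTH:
            msg = f"Input exceeds {MAX_INPUT_TOKEN_LENGTH} tokens; please shorten the prompt."
            yield msg, msg
            return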
# Setup streaming generation
streamer = TextIteratorStreamer(
processor,
skip_prompt=True,
skip_special_tokens=True
)
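        # TextIteratorStreamer expects a tokenizer; Qwen-style processors
        # forward decode()/batch_decode() to their inner tokenizer, which is
        # why passing the processor directly works here.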
generation_kwargs = {
**inputs,
"streamer": streamer,
"max_new_tokens": max_new_tokens,
"do_sample": True,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
}
# Start generation in separate thread
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
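        # Exceptions raised inside model.generate occur on the worker thread
        # and will not reach the except below; consider passing timeout= to
        # TextIteratorStreamer so a crashed generation cannot block this loop.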
# Stream the results
buffer = ""
for new_text in streamer:
buffer += new_text
            # skip_special_tokens above usually strips this token already; keep
            # the replace as a safety net for models that emit it verbatim
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)  # brief pause so Gradio renders incremental updates
yield buffer, buffer
# Ensure thread completes
thread.join()
except Exception as e:
error_msg = f"Error during generation: {str(e)}"
yield error_msg, error_msg
# Example usage for Gradio interface
if __name__ == "__main__":
import gradio as gr
# Determine available models
available_models = []
if model_m is not None:
available_models.append("olmOCR-2-7B-1025")
if model_x is not None:
available_models.append("Nanonets-OCR2-3B")
if model_v is not None:
available_models.append("Chandra-OCR")
if model_d is not None:
available_models.append("Dots.OCR")
if model_ds is not None:
available_models.append("DeepSeek-OCR")
if not available_models:
print("ERROR: No models were loaded successfully!")
        raise SystemExit(1)
print(f"\n✓ Available models: {', '.join(available_models)}")
with gr.Blocks(title="Multi-Model OCR") as demo:
gr.Markdown("# 🔍 Multi-Model OCR Application")
gr.Markdown("Upload an image and select a model to extract text. Models run on GPU via Hugging Face Spaces.")
with gr.Row():
with gr.Column():
model_selector = gr.Dropdown(
choices=available_models,
value=available_models[0] if available_models else None,
label="Select OCR Model"
)
image_input = gr.Image(type="pil", label="Upload Image")
text_input = gr.Textbox(
value="Extract all text from this image.",
label="Prompt",
lines=2
)
with gr.Accordion("Advanced Settings", open=False):
max_tokens = gr.Slider(
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
value=DEFAULT_MAX_NEW_TOKENS,
step=1,
label="Max New Tokens"
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.9,
step=0.05,
label="Top P"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
value=50,
step=1,
label="Top K"
)
repetition_penalty = gr.Slider(
minimum=1.0,
maximum=2.0,
value=1.1,
step=0.1,
label="Repetition Penalty"
)
submit_btn = gr.Button("Extract Text", variant="primary")
with gr.Column():
output_text = gr.Textbox(label="Extracted Text", lines=20)
output_markdown = gr.Markdown(label="Formatted Output")
gr.Markdown("""
### Available Models:
- **olmOCR-2-7B-1025**: Allen AI's OCR model
- **Nanonets-OCR2-3B**: Nanonets OCR model
- **Chandra-OCR**: Datalab OCR model
- **Dots.OCR**: Stranger Vision OCR model
- **DeepSeek-OCR**: DeepSeek AI's OCR model
""")
submit_btn.click(
fn=generate_image,
inputs=[
model_selector,
text_input,
image_input,
max_tokens,
temperature,
top_p,
top_k,
repetition_penalty
],
outputs=[output_text, output_markdown]
)
    # Streaming (generator) outputs go through Gradio's queue; it is enabled
    # by default in Gradio 4.x, but calling it explicitly also covers 3.x.
    demo.queue().launch()