# Spaces: Running on Zero
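
Below is a complete `app.py` for a ZeroGPU Space that loads three OCR models (Dots.OCR, olmOCR-2-7B-1025, and DeepSeek-OCR) and serves them from a single Gradio interface. On ZeroGPU hardware, `import spaces` must run before anything that might initialize CUDA, and GPU work happens inside a function decorated with `@spaces.GPU`.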
```python
# CRITICAL: Import spaces FIRST before any CUDA-related packages
import spaces
import os

# Now import other packages
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModel,
    AutoModelForImageTextToText,
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer
)
from threading import Thread
import time

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Dots.OCR
MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH_D,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
).eval()

# Load olmOCR-2-7B-1025 (non-FP8 version for simplicity)
MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
# olmOCR-2 is Qwen2.5-VL-based, so it is loaded with the image-text-to-text
# auto class; plain AutoModel would return the backbone without a
# generation head and model.generate would fail later
model_m = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto"
).eval()

# Load DeepSeek-OCR
MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
model_ds = AutoModel.from_pretrained(
    MODEL_ID_DS,
    attn_implementation="sdpa",
    trust_remote_code=True,
    use_safetensors=True,
    device_map="auto"
).eval().to(torch.bfloat16)
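
# ZeroGPU: a GPU is attached only while a @spaces.GPU-decorated function
# runs; this is the decorator the top-of-file spaces import exists for
@spaces.GPU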
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float, resolution_mode: str):
    """
    Generates responses using the selected model for image input.
    Yields raw text and Markdown-formatted text.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    # Handle DeepSeek-OCR separately due to different API
    if model_name == "DeepSeek-OCR":
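        # Resolution presets from the DeepSeek-OCR reference: fixed square
        # sizes for Tiny/Small/Base/Large, while "Gundam" tiles 640 px
        # crops over a 1024 px global view for dense documents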
        resolution_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True}
        }
        config = resolution_configs[resolution_mode]
        temp_image_path = "/tmp/temp_ocr_image.jpg"
        image.save(temp_image_path)
        if not text:
            text = "Free OCR."
        prompt_ds = f"<image>\n{text}"
        try:
            result = model_ds.infer(
                tokenizer_ds,
                prompt=prompt_ds,
                image_file=temp_image_path,
                output_path="/tmp",
                base_size=config["base_size"],
                image_size=config["image_size"],
                crop_mode=config["crop_mode"],
                test_compress=True,
                save_results=False
            )
            yield result, result
        except Exception as e:
            yield f"Error: {str(e)}", f"Error: {str(e)}"
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)
        return

    # Handle other models with standard API
    if model_name == "olmOCR-2-7B-1025":
        processor = processor_m
        model = model_m
    elif model_name == "Dots.OCR":
        processor = processor_d
        model = model_d
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return

    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": text if text else "Perform OCR on this image."},
        ]
    }]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True
    ).to(device)
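    # Stream partial output to the UI; passing the processor works because
    # Qwen2-VL-style processors forward decode() to their tokenizer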
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
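    # model.generate blocks until completion, so it runs in a worker
    # thread while this generator drains the streamer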
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer

# Image examples
image_examples = [
    ["OCR the content perfectly.", "examples/3.jpg"],
    ["Perform OCR on the image.", "examples/1.jpg"],
    ["Extract the contents. [page].", "examples/2.jpg"],
]

# CSS styling
css = """
.gradio-container {
    max-width: 1400px;
    margin: auto;
}
.model-selector {
    font-size: 16px;
}
"""

# Build Gradio interface
with gr.Blocks(css=css, title="Multi-Model OCR Space") as demo:
    gr.Markdown(
        """
        # 🔍 Multi-Model OCR Comparison Space
        Compare three state-of-the-art OCR models:
        - **Dots.OCR**: Lightweight and efficient OCR
        - **olmOCR-2-7B-1025**: Advanced OCR for math, tables, and complex layouts (scores 82.4 on olmOCR-Bench)
        - **DeepSeek-OCR**: Context-compression OCR (~97% decoding precision at up to 10× compression)
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=["Dots.OCR", "olmOCR-2-7B-1025", "DeepSeek-OCR"],
                value="olmOCR-2-7B-1025",
                label="Select OCR Model",
                elem_classes=["model-selector"]
            )
            resolution_selector = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam"],
                value="Gundam",
                label="DeepSeek-OCR Resolution Mode",
                info="Only applies to DeepSeek-OCR. Gundam mode recommended.",
                visible=False
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                value="Perform OCR on this image.",
                label="Prompt",
                lines=2
            )
            with gr.Accordion("Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=256,
                    maximum=8192,
                    value=2048,
                    step=256,
                    label="Max New Tokens"
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top K"
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty"
                )
            submit_btn = gr.Button("🚀 Extract Text", variant="primary")
            clear_btn = gr.ClearButton()
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Extracted Text",
                lines=20,
                show_copy_button=True
            )
            output_markdown = gr.Markdown(label="Formatted Output")

    gr.Examples(
        examples=image_examples,
        inputs=[text_input, image_input],
        label="Example Images"
    )

    # Show/hide resolution selector based on model
    def update_resolution_visibility(model_name):
        return gr.update(visible=(model_name == "DeepSeek-OCR"))

    model_selector.change(
        fn=update_resolution_visibility,
        inputs=[model_selector],
        outputs=[resolution_selector]
    )

    # Event handlers
    submit_btn.click(
        fn=generate_image,
        inputs=[
            model_selector,
            text_input,
            image_input,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
            resolution_selector
        ],
        outputs=[output_text, output_markdown]
    )
    clear_btn.add([image_input, text_input, output_text, output_markdown])

    gr.Markdown(
        """
        ### Model Strengths:
        - **Dots.OCR**: Fast and lightweight, great for simple documents and quick processing
        - **olmOCR-2-7B-1025**: Best for complex documents with tables, LaTeX equations, multi-column layouts, and handwritten text
        - **DeepSeek-OCR**: Excellent for markdown conversion, table extraction, and efficient context compression (roughly 10× fewer vision tokens than the text they encode)

        ### Tips:
        - Upload clear, well-lit images for best results
        - Use olmOCR for academic papers and technical documents
        - Use DeepSeek for efficient processing of large document batches
        - Adjust temperature for more creative or conservative outputs
        """
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
```
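
The Space also needs a `requirements.txt` next to `app.py`. A minimal sketch; the exact set is an assumption (left unpinned here, `accelerate` backs `device_map="auto"`, and DeepSeek-OCR's `trust_remote_code` module may pull in extra dependencies of its own):

```text
spaces
gradio
torch
transformers
accelerate
pillow
```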