import os
import time
import torch
import spaces
import warnings
import tempfile
import sys
from io import StringIO
from contextlib import contextmanager
from threading import Thread
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    AutoModel,
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer
)
from huggingface_hub import snapshot_download
from qwen_vl_utils import process_vision_info

# Suppress the warning about uninitialized weights
warnings.filterwarnings('ignore', message='Some weights.*were not initialized')

# Try importing Qwen3VL if available
try:
    from transformers import Qwen3VLForConditionalGeneration
except ImportError:
    Qwen3VLForConditionalGeneration = None
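
# Qwen3VLForConditionalGeneration only exists in recent transformers releases;
# when it is missing, the Chandra-OCR load below is skipped instead of crashing.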

MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
CACHE_DIR = os.getenv("HF_CACHE_DIR", "./models")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Initial Device: {device}")
print(f"CUDA Available: {torch.cuda.is_available()}")

# Load Chandra-OCR
try:
    MODEL_ID_V = "datalab-to/chandra"
    processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
    if Qwen3VLForConditionalGeneration:
        model_v = Qwen3VLForConditionalGeneration.from_pretrained(
            MODEL_ID_V,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto"
        ).eval()
        print("✓ Chandra-OCR loaded")
    else:
        model_v = None
        print("✗ Chandra-OCR: Qwen3VL not available")
except Exception as e:
    model_v = None
    processor_v = None
    print(f"✗ Chandra-OCR: Failed to load - {str(e)}")

# Load Nanonets-OCR2-3B
try:
    MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
    processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
    model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_X,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto"
    ).eval()
    print("✓ Nanonets-OCR2-3B loaded")
except Exception as e:
    model_x = None
    processor_x = None
    print(f"✗ Nanonets-OCR2-3B: Failed to load - {str(e)}")

# Load olmOCR-2-7B-1025
try:
    MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
    processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
    model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_M,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto"
    ).eval()
    print("✓ olmOCR-2-7B-1025 loaded")
except Exception as e:
    model_m = None
    processor_m = None
    print(f"✗ olmOCR-2-7B-1025: Failed to load - {str(e)}")

@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float):
    """
    Generate a response with the selected model for an image input.

    Yields raw text and Markdown-formatted text as generation streams in.
    The function is decorated with @spaces.GPU so that it runs on a GPU when
    one is available in Hugging Face Spaces.

    Args:
        model_name: Name of the OCR model to use
        text: Prompt text for the model
        image: PIL Image object to process
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter
        repetition_penalty: Penalty for repeating tokens

    Yields:
        tuple: (raw_text, markdown_text)
    """
    # Device will be cuda when the @spaces.GPU decorator is active
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Select model and processor based on model_name
    if model_name == "olmOCR-2-7B-1025":
        if model_m is None:
            yield "olmOCR-2-7B-1025 is not available.", "olmOCR-2-7B-1025 is not available."
            return
        processor = processor_m
        model = model_m
    elif model_name == "Nanonets-OCR2-3B":
        if model_x is None:
            yield "Nanonets-OCR2-3B is not available.", "Nanonets-OCR2-3B is not available."
            return
        processor = processor_x
        model = model_x
    elif model_name == "Chandra-OCR":
        if model_v is None:
            yield "Chandra-OCR is not available.", "Chandra-OCR is not available."
            return
        processor = processor_v
        model = model_v
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return

    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    try:
        # Prepare messages in chat format
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text},
            ]
        }]

        # Apply chat template with fallback
        try:
            prompt_full = processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        except Exception as template_error:
            # Fallback: use the raw prompt without a chat template
            print(f"Chat template error: {template_error}. Using fallback prompt.")
            prompt_full = f"{text}"

        # Process text and image inputs together
        inputs = processor(
            text=[prompt_full],
            images=[image],
            return_tensors="pt",
            padding=True
        ).to(device)
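
        # Note: the {"type": "image"} entry in the message expands, via the chat
        # template, into the model's image placeholder tokens, which the processor
        # call above pairs with the PIL image passed as images=[image].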

        # Set up streaming generation
        streamer = TextIteratorStreamer(
            processor.tokenizer if hasattr(processor, 'tokenizer') else processor,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
        }

        # Start generation in a separate thread so the streamer can be consumed here
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the results as they arrive
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer, buffer

        # Ensure the generation thread completes
        thread.join()
    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        print(f"Full error: {e}")
        import traceback
        traceback.print_exc()
        yield error_msg, error_msg
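

# Illustrative usage outside Gradio (a minimal sketch; "sample.png" is a
# hypothetical local file). generate_image is a generator, so it can be
# consumed directly for a quick smoke test of a loaded model:
#
#     img = Image.open("sample.png")
#     for raw, md in generate_image("Nanonets-OCR2-3B",
#                                   "Extract all text from this image.",
#                                   img, 512, 0.7, 0.9, 50, 1.1):
#         pass
#     print(raw)  # the last yielded value holds the full accumulated output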

# Gradio interface (runs when this file is executed as a script)
if __name__ == "__main__":
    import gradio as gr

    # Build the dropdown from the models that loaded successfully
    available_models = []
    if model_m is not None:
        available_models.append("olmOCR-2-7B-1025")
        print(" Added: olmOCR-2-7B-1025")
    if model_x is not None:
        available_models.append("Nanonets-OCR2-3B")
        print(" Added: Nanonets-OCR2-3B")
    if model_v is not None:
        available_models.append("Chandra-OCR")
        print(" Added: Chandra-OCR")

    if not available_models:
        print("ERROR: No models were loaded successfully!")
        sys.exit(1)

    print(f"\n✓ Available models for dropdown: {', '.join(available_models)}")

    with gr.Blocks(title="Multi-Model OCR") as demo:
        gr.Markdown("# 🔍 Multi-Model OCR Application")
        gr.Markdown("Upload an image and select a model to extract text. Models run on GPU via Hugging Face Spaces.")

        with gr.Row():
            with gr.Column():
                model_selector = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0] if available_models else None,
                    label="Select OCR Model"
                )
                image_input = gr.Image(type="pil", label="Upload Image")
                text_input = gr.Textbox(
                    value="Extract all text from this image.",
                    label="Prompt",
                    lines=2
                )
                with gr.Accordion("Advanced Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=MAX_MAX_NEW_TOKENS,
                        value=DEFAULT_MAX_NEW_TOKENS,
                        step=1,
                        label="Max New Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top P"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top K"
                    )
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Repetition Penalty"
                    )
                submit_btn = gr.Button("Extract Text", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Extracted Text", lines=20)
                output_markdown = gr.Markdown(label="Formatted Output")

        gr.Markdown("""
### Available Models:
- **olmOCR-2-7B-1025**: Allen AI's OCR model
- **Nanonets-OCR2-3B**: Nanonets OCR model
- **Chandra-OCR**: Datalab OCR model
""")

        submit_btn.click(
            fn=generate_image,
            inputs=[
                model_selector,
                text_input,
                image_input,
                max_tokens,
                temperature,
                top_p,
                top_k,
                repetition_penalty
            ],
            outputs=[output_text, output_markdown]
        )
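
    # Note: generate_image is a generator, so Gradio streams each yield into the
    # outputs; on older Gradio releases this may require enabling the queue
    # (demo.queue()) before launch, while newer releases queue events by default.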

    # share=True creates a public link when running locally; on Hugging Face
    # Spaces the app is already served publicly and the flag is effectively ignored.
    demo.launch(share=True)