| import gradio as gr |
| import torch |
| import json |
| import spaces |
| import os |
| from PIL import Image |
| from transformers import AutoModelForCausalLM, AutoProcessor |
| from transformers.processing_utils import ProcessorMixin |
| from qwen_vl_utils import process_vision_info |
| from huggingface_hub import login |
|
|
| |
# Hugging Face Hub repo id of the OCR / layout-parsing model served by this app.
MODEL_PATH = "rednote-hilab/dots.ocr"


# Optional Hugging Face token (Space secret / env var), used for authenticated
# downloads; the app still works without it for public repos.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("Authenticating with Hugging Face token...")
    login(token=HF_TOKEN, add_to_git_credential=False)


# Lazily-initialized singletons, populated by load_model() on first use so the
# heavy download/initialization happens inside the GPU-allocated request.
model = None
processor = None
|
|
def load_model():
    """Lazily load the dots.ocr model and processor (module-level singletons).

    On the first call this downloads/initializes the model and processor and
    stores them in the module-level ``model`` and ``processor`` globals;
    subsequent calls return the cached objects immediately.

    Returns:
        tuple: ``(model, processor)`` ready for inference.
    """
    global model, processor
    if model is None:
        print(f"Loading model weights from {MODEL_PATH}...")

        # Prefer FlashAttention2 when the package is installed; otherwise fall
        # back to the stock eager attention implementation.
        try:
            import flash_attn
            attn_implementation = "flash_attention_2"
            print("Using FlashAttention2 for faster inference")
        except ImportError:
            attn_implementation = "eager"
            print("FlashAttention2 not available, using default attention")

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
            attn_implementation=attn_implementation
        )
        print("Model loaded successfully.")


        print(f"Loading processor from {MODEL_PATH}...")

        # Temporarily relax ProcessorMixin's attribute validation: loading this
        # processor can pass a ``video_processor`` of None, which the stock
        # check rejects.  The patch skips validation for exactly that case and
        # delegates everything else to the original check.
        # NOTE(review): this monkeypatches a private transformers internal
        # (check_argument_for_proper_class) — may break across library versions.
        _original_check = ProcessorMixin.check_argument_for_proper_class

        def _patched_check(self, attribute_name, value):
            # Only a missing video processor is waved through.
            if attribute_name == "video_processor" and value is None:
                return
            return _original_check(self, attribute_name, value)

        ProcessorMixin.check_argument_for_proper_class = _patched_check

        try:
            processor = AutoProcessor.from_pretrained(
                MODEL_PATH,
                trust_remote_code=True,
                token=HF_TOKEN
            )
            print("Processor loaded successfully.")
        finally:
            # Always restore the original check, even if loading raised.
            ProcessorMixin.check_argument_for_proper_class = _original_check

    return model, processor
|
|
| |
# Task prompts offered in the UI dropdown.  The "Custom" entry is an empty
# placeholder: its actual text comes from the custom-prompt textbox at request
# time (see process_image).  These strings are sent verbatim to the model.
PROMPTS = {
    "Full Layout + OCR (English)": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.""",

    "OCR Only": """Please extract all text from the image in reading order. Format the output as plain text, preserving the original structure as much as possible.""",

    "Layout Detection Only": """Please detect all layout elements in the image and output their bounding boxes and categories. Format: [{"bbox": [x1, y1, x2, y2], "category": "category_name"}]""",

    "Custom": ""
}
|
|
@spaces.GPU(duration=120)
def process_image(image, prompt_type, custom_prompt):
    """Run the dots.ocr model on a single document image.

    Args:
        image: PIL image supplied by the Gradio image component.
        prompt_type: Key into PROMPTS selecting the processing mode.
        custom_prompt: Free-form prompt text; used only when prompt_type is
            "Custom" and the text is non-empty.

    Returns:
        str: The decoded model output (pretty-printed when it parses as
        JSON), or an ``"Error: ..."`` message on failure.
    """
    try:
        # Lazy singleton load — happens inside the GPU-allocated context.
        current_model, current_processor = load_model()

        # Guard against None from Gradio before calling .strip(); fall back
        # to the canned prompt for the selected mode.
        if prompt_type == "Custom" and custom_prompt and custom_prompt.strip():
            prompt = custom_prompt
        else:
            prompt = PROMPTS[prompt_type]

        # Qwen-VL-style chat message: one image followed by the text prompt.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        text = current_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = current_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        inputs = inputs.to("cuda")

        with torch.no_grad():
            generated_ids = current_model.generate(
                **inputs,
                max_new_tokens=24000,
                # do_sample must be enabled for temperature/top_p to take
                # effect; without it generate() decodes greedily and
                # silently ignores both parameters.
                do_sample=True,
                temperature=0.1,
                top_p=0.9,
            )

        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = current_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        # Pretty-print valid JSON output; pass anything else through as-is.
        try:
            parsed_json = json.loads(output_text)
            output_text = json.dumps(parsed_json, ensure_ascii=False, indent=2)
        except json.JSONDecodeError:
            pass

        return output_text

    except Exception as e:
        # Top-level UI boundary: surface the error in the output textbox
        # instead of crashing the Gradio worker.
        return f"Error: {str(e)}"
|
|
| |
# Build the Gradio UI.  Fix: the original header strings contained the
# mojibake character "๐" (an encoding-mangled emoji); replaced with real emoji.
with gr.Blocks(title="dots.ocr - Multilingual Document OCR") as demo:
    gr.Markdown("""
    # 📄 dots.ocr - Multilingual Document Layout Parsing

    Upload a document image and get OCR results with layout detection.
    This space uses the [dots.ocr](https://github.com/rednote-hilab/dots.ocr) model.

    **Features:**
    - Multilingual support
    - Layout detection (tables, formulas, text, etc.)
    - Reading order preservation
    - Formula extraction (LaTeX format)
    - Table extraction (HTML format)
    """)

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Document Image",
                height=400
            )

            prompt_type = gr.Dropdown(
                choices=list(PROMPTS.keys()),
                value="Full Layout + OCR (English)",
                label="Prompt Type",
                info="Select the type of processing you want"
            )

            # Hidden unless "Custom" is selected (see toggle below).
            custom_prompt = gr.Textbox(
                label="Custom Prompt (used when 'Custom' is selected)",
                placeholder="Enter your custom prompt here...",
                lines=5,
                visible=False
            )

            submit_btn = gr.Button("Process Document", variant="primary", size="lg")

        # Right column: model output.
        with gr.Column():
            output_text = gr.Textbox(
                label="OCR Result",
                lines=25,
                show_copy_button=True
            )

    # Show the custom-prompt textbox only when the "Custom" mode is picked.
    def toggle_custom_prompt(choice):
        return gr.update(visible=(choice == "Custom"))

    prompt_type.change(
        fn=toggle_custom_prompt,
        inputs=[prompt_type],
        outputs=[custom_prompt]
    )

    submit_btn.click(
        fn=process_image,
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text]
    )

    gr.Markdown("## 📚 Examples")
    gr.Examples(
        examples=[
            ["examples/example1.jpg", "Full Layout + OCR (English)", ""],
            ["examples/example2.jpg", "OCR Only", ""],
        ],
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text],
        fn=process_image,
        cache_examples=False,
    )
|
|
if __name__ == "__main__":
    # Start the Gradio server when run as a script (not when imported).
    demo.launch()
|
|