import numpy as np
import gradio as gr
import os
from transformers import AutoModel, AutoTokenizer
import torch
from PIL import Image
import warnings
import re

# Suppress warnings
warnings.simplefilter("ignore")

# Retrieve Hugging Face token
hf_token = os.getenv("HF_TOKEN")

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, use_auth_token=hf_token)
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True,
                                  device_map='cuda' if torch.cuda.is_available() else 'cpu',
                                  use_safetensors=True, pad_token_id=tokenizer.eos_token_id,
                                  use_auth_token=hf_token)
model = model.eval()

# Global variable to store OCR result
ocr_result = ""

# Perform OCR on an uploaded image
def perform_ocr(image):
    global ocr_result
    # Convert the numpy array to a PIL image
    pil_image = Image.fromarray(image)
    # Save the image temporarily
    image_file = "temp_image.png"
    pil_image.save(image_file)
    # Perform OCR with the model
    with torch.no_grad():
        ocr_result = model.chat(tokenizer, image_file, ocr_type='ocr')
    # Remove the temporary image file
    os.remove(image_file)
    return ocr_result

# Highlight the search term with a different color (light blue)
def highlight_text(text, query):
    # Use regex to wrap each case-insensitive match in a styled span
    pattern = re.compile(re.escape(query), re.IGNORECASE)
    highlighted_text = pattern.sub(
        lambda m: f'<span style="background-color: lightblue;">{m.group(0)}</span>', text
    )
    return highlighted_text

# Search within the OCR result, highlight matches, and return the modified text
def search_text(query):
    # If no query is provided, return the original OCR result
    if not query:
        return ocr_result, "No matches found."
    # Highlight the searched term in the OCR text
    highlighted_result = highlight_text(ocr_result, query)
    # Split the OCR result into lines and collect those containing the query
    lines = ocr_result.split('\n')
    matching_lines = [line for line in lines if query.lower() in line.lower()]
    if matching_lines:
        return highlighted_result, '\n'.join(matching_lines)  # Highlighted text and matched lines
    else:
        return highlighted_result, "No matches found."

# Set up Gradio interface
with gr.Blocks() as demo:
    # Section for uploading an image and getting OCR results
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="numpy", label="Upload Image")
            ocr_output = gr.HTML(label="OCR Output")  # HTML output so the highlighted text renders
            ocr_button = gr.Button("Run OCR")

    # Section for searching within the OCR result
    with gr.Row():
        with gr.Column():
            search_input = gr.Textbox(label="Search Text")
            search_output = gr.HTML(label="Search Result")  # Separate output for search matches
            search_button = gr.Button("Search in OCR Text")

    # Define button actions
    ocr_button.click(perform_ocr, inputs=image_input, outputs=ocr_output)
    search_button.click(search_text, inputs=search_input, outputs=[ocr_output, search_output])

# Launch the Gradio interface
demo.launch(share=True)