Spaces:
Sleeping
Sleeping
| import torch | |
| from PIL import Image | |
| from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer | |
| import os | |
| import numpy as np | |
| from typing import Union | |
# --- Configuration ---
# Checkpoint directory to load. Defaults to the local training output; set the
# OCR_MODEL_PATH environment variable to point at a different checkpoint
# without editing the script.
MODEL_PATH = os.environ.get("OCR_MODEL_PATH", "./ocr_model_output/checkpoint-441")
class OCRInference:
    """Run OCR inference with a trained ``VisionEncoderDecoderModel`` checkpoint.

    The checkpoint directory must contain the model. If it lacks an image
    processor or tokenizer, they are loaded from ``encoder_id`` / ``decoder_id``
    instead and, best-effort, saved into the checkpoint directory so the next
    load finds them locally.
    """

    # Beam-search settings used for every ``generate()`` call unless the caller
    # overrides them via ``perform_inference(..., **generate_kwargs)``.
    DEFAULT_GENERATION_KWARGS = {
        "max_length": 64,
        "early_stopping": True,
        # NOTE(review): forbids repeating any token trigram; confirm this is
        # desired for OCR text that may legitimately repeat (e.g. "1 0 0 1 0 0").
        "no_repeat_ngram_size": 3,
        "length_penalty": 2.0,
        "num_beams": 4,
    }

    def __init__(
        self,
        model_path: str,
        encoder_id: str = "google/vit-base-patch16-224-in21k",
        decoder_id: str = "prajjwal1/bert-tiny",
        device: Union[str, torch.device, None] = None,
    ):
        """Load the model, image processor and tokenizer, and configure generation.

        Args:
            model_path (str): Path to the trained model checkpoint directory.
            encoder_id (str): Model ID to load the image processor from when
                the checkpoint does not contain one.
            decoder_id (str): Model ID to load the tokenizer from when the
                checkpoint does not contain one.
            device (str | torch.device | None): Optional device (e.g. "cuda")
                to move the model to. ``None`` leaves the model where
                ``from_pretrained`` placed it.
        """
        print(f"Loading model from: {model_path}")
        self.model = VisionEncoderDecoderModel.from_pretrained(model_path)
        if device is not None:
            self.model.to(device)

        self.image_processor = self._load_component(
            AutoImageProcessor,
            model_path,
            encoder_id,
            "Image processor not found locally. Loading from encoder ID and saving.",
        )
        self.tokenizer = self._load_component(
            AutoTokenizer,
            model_path,
            decoder_id,
            "Tokenizer not found locally. Loading from decoder ID and saving.",
        )

        # --- Set special tokens and generation parameters ---
        token_ids = self._special_token_ids(self.tokenizer)
        for key, value in token_ids.items():
            setattr(self.model.config, key, value)
        # Recorded on the config as before; assigning it does not resize any
        # embedding matrices.
        self.model.config.vocab_size = self.tokenizer.vocab_size

        # Generation settings go on ``generation_config`` (which ``generate()``
        # consults) rather than ``model.config``. They are also passed
        # explicitly in ``perform_inference`` so none of them is dropped.
        self.generation_kwargs = dict(self.DEFAULT_GENERATION_KWARGS)
        generation_config = getattr(self.model, "generation_config", None)
        if generation_config is not None:
            for key, value in {**token_ids, **self.generation_kwargs}.items():
                setattr(generation_config, key, value)
        print("Model, image processor, and tokenizer loaded.")

    @staticmethod
    def _load_component(auto_cls, model_path: str, fallback_id: str, missing_message: str):
        """Load ``auto_cls`` from ``model_path``, else from ``fallback_id`` and cache it.

        Args:
            auto_cls: An ``Auto*`` class exposing ``from_pretrained``.
            model_path (str): Checkpoint directory tried first.
            fallback_id (str): Model ID used when the checkpoint lacks the component.
            missing_message (str): Message printed when falling back.

        Returns:
            The loaded component.
        """
        try:
            return auto_cls.from_pretrained(model_path)
        except (KeyError, OSError):
            print(missing_message)
        component = auto_cls.from_pretrained(fallback_id)
        try:
            component.save_pretrained(model_path)
        except OSError as err:
            # Saving is only a cache for the next run; a read-only checkpoint
            # directory must not prevent inference.
            print(f"Could not save to {model_path}: {err}")
        return component

    @staticmethod
    def _special_token_ids(tokenizer) -> dict:
        """Map the tokenizer's special tokens onto the decoder token-ID settings.

        BERT-style tokenizers expose CLS/SEP, which are preferred as the start
        and end tokens. Tokenizers that leave those unset fall back to BOS/EOS.
        """

        def first_set(*candidates):
            return next((c for c in candidates if c is not None), None)

        return {
            "decoder_start_token_id": first_set(
                getattr(tokenizer, "cls_token_id", None),
                getattr(tokenizer, "bos_token_id", None),
            ),
            "pad_token_id": getattr(tokenizer, "pad_token_id", None),
            "eos_token_id": first_set(
                getattr(tokenizer, "sep_token_id", None),
                getattr(tokenizer, "eos_token_id", None),
            ),
        }

    @staticmethod
    def _load_image(image_input):
        """Return ``image_input`` (path or NumPy array) as an RGB PIL image.

        Raises:
            FileNotFoundError: If a path is given that is not an existing file.
            TypeError: If ``image_input`` is neither a path nor a NumPy array.
        """
        if isinstance(image_input, (str, os.PathLike)):
            # isfile (not exists) so a directory is reported here rather than
            # failing later inside PIL.
            if not os.path.isfile(image_input):
                raise FileNotFoundError(f"Image file not found at: {image_input}")
            # The context manager closes the file handle; convert() returns a
            # new, fully loaded image that stays valid afterwards.
            with Image.open(image_input) as img:
                return img.convert("RGB")
        if isinstance(image_input, np.ndarray):
            return Image.fromarray(image_input).convert("RGB")
        raise TypeError("image_input must be a file path (str) or a NumPy array.")

    def perform_inference(self, image_input: Union[str, os.PathLike, np.ndarray], **generate_kwargs) -> str:
        """Run OCR on a single image and return the predicted text.

        Args:
            image_input (str | os.PathLike | np.ndarray): Path to an image file,
                or a NumPy array accepted by ``PIL.Image.fromarray``.
            **generate_kwargs: Optional overrides for ``generate()`` (e.g.
                ``num_beams=1``), merged over ``self.generation_kwargs``.

        Returns:
            str: The decoded text with special tokens removed.

        Raises:
            FileNotFoundError: If a path is given that is not an existing file.
            TypeError: If ``image_input`` is neither a path nor a NumPy array.
        """
        image = self._load_image(image_input)
        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
        # Inputs must be on the same device as the model (relevant when
        # ``device`` was passed to the constructor).
        pixel_values = pixel_values.to(self.model.device)
        gen_kwargs = {**self.generation_kwargs, **generate_kwargs}
        with torch.no_grad():
            output_ids = self.model.generate(pixel_values, **gen_kwargs)
        # Only one image is processed, so decode the first (and only) sequence.
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
if __name__ == '__main__':
    import sys

    # Image to run inference on: the first command-line argument if given,
    # otherwise an example image from the dataset.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "../ai_augment_output/20250901_115123_336458_ccd9d646-fc99-4d27-8076-0c17d0dba784.png"
    )

    # --- Initialize the Inference Class ---
    ocr_engine = OCRInference(model_path=MODEL_PATH)

    # --- Perform Inference from a file path ---
    try:
        predicted_text = ocr_engine.perform_inference(image_path)
        print("\n--- Inference from file path ---")
        print(f"Image: {image_path}")
        print(f"Predicted Text: {predicted_text}")
    except FileNotFoundError as e:
        print(e)
        print("Please pass a valid image path as the first argument or update the 'image_path' variable in the script.")

    # --- Perform Inference from a NumPy array (example) ---
    # Exercises the ndarray input path by loading the same image into an array.
    try:
        if os.path.isfile(image_path):
            # Close the file promptly, and convert to RGB first so palette or
            # other non-RGB images yield real pixel values rather than
            # palette indices.
            with Image.open(image_path) as img:
                image_array = np.array(img.convert("RGB"))
            predicted_text_from_array = ocr_engine.perform_inference(image_array)
            print("\n--- Inference from NumPy array ---")
            print(f"Predicted Text: {predicted_text_from_array}")
    except Exception as e:  # top-level demo boundary: report and exit normally
        print(f"An error occurred during inference from NumPy array: {e}")