import os
from typing import Union

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel

# --- Configuration ---
MODEL_PATH = "./ocr_model_output/checkpoint-441"


class OCRInference:
    """A class to perform OCR inference using a trained VisionEncoderDecoder model."""

    def __init__(
        self,
        model_path: str,
        encoder_id: str = "google/vit-base-patch16-224-in21k",
        decoder_id: str = "prajjwal1/bert-tiny",
        max_length: int = 64,
        num_beams: int = 4,
    ):
        """
        Initializes the OCRInference class by loading the model, image processor, and tokenizer.

        If the checkpoint directory has no image processor or tokenizer, they are
        loaded from ``encoder_id`` / ``decoder_id`` instead and saved into
        ``model_path`` so later runs find them locally.

        Args:
            model_path (str): The path to the trained model checkpoint.
            encoder_id (str): The encoder ID to load the image processor from.
            decoder_id (str): The decoder ID to load the tokenizer from.
            max_length (int): Maximum generated sequence length (default 64).
            num_beams (int): Beam-search width used during generation (default 4).
        """
        print(f"Loading model from: {model_path}")
        self.model = VisionEncoderDecoderModel.from_pretrained(model_path)
        self.max_length = max_length
        self.num_beams = num_beams

        # Load image processor and save it if not present
        try:
            self.image_processor = AutoImageProcessor.from_pretrained(model_path)
        except OSError:
            print("Image processor not found locally. Loading from encoder ID and saving.")
            self.image_processor = AutoImageProcessor.from_pretrained(encoder_id)
            self.image_processor.save_pretrained(model_path)

        # Load tokenizer and save it if not present
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        except (KeyError, OSError):
            print("Tokenizer not found locally. Loading from decoder ID and saving.")
            self.tokenizer = AutoTokenizer.from_pretrained(decoder_id)
            self.tokenizer.save_pretrained(model_path)

        # --- Set special tokens and generation parameters ---
        # NOTE(review): tokenizer.vocab_size excludes tokens added after
        # pretraining; if tokens were added during training, len(self.tokenizer)
        # may be the intended value — confirm.
        self.model.config.vocab_size = self.tokenizer.vocab_size
        generation_settings = {
            "decoder_start_token_id": self.tokenizer.cls_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.sep_token_id,
            "max_length": self.max_length,
            "early_stopping": True,
            "no_repeat_ngram_size": 3,
            "length_penalty": 2.0,
            "num_beams": self.num_beams,
        }
        # Written to model.config (as before, for anything that inspects it) and
        # mirrored onto model.generation_config, which is what generate() reads
        # by default — otherwise a checkpoint-supplied generation config could
        # silently override these settings.
        generation_config = getattr(self.model, "generation_config", None)
        for key, value in generation_settings.items():
            setattr(self.model.config, key, value)
            if generation_config is not None:
                setattr(generation_config, key, value)

        print("Model, image processor, and tokenizer loaded.")

    def perform_inference(
        self, image_input: Union[str, os.PathLike, np.ndarray, Image.Image]
    ) -> str:
        """
        Performs inference on a single image.

        Args:
            image_input: Path to the input image (str or os.PathLike), a NumPy
                array representing the image, or a PIL Image.

        Returns:
            str: The predicted text, with special tokens stripped.

        Raises:
            FileNotFoundError: If a path is given and does not exist.
            TypeError: If image_input is none of the supported types.
        """
        if isinstance(image_input, (str, os.PathLike)):
            if not os.path.exists(image_input):
                raise FileNotFoundError(f"Image file not found at: {image_input}")
            # The context manager closes the file handle; convert() returns a
            # fully loaded copy, so the image stays usable after closing.
            with Image.open(image_input) as img:
                image = img.convert("RGB")
        elif isinstance(image_input, np.ndarray):
            # NOTE(review): Image.fromarray interprets channels as RGB; arrays
            # coming from OpenCV are BGR — confirm the caller's channel order.
            image = Image.fromarray(image_input).convert("RGB")
        elif isinstance(image_input, Image.Image):
            image = image_input.convert("RGB")
        else:
            raise TypeError(
                "image_input must be a file path (str or os.PathLike), a NumPy array, or a PIL Image."
            )

        # Process the image and keep the tensor on the model's device so
        # inference still works after e.g. `ocr_engine.model.to("cuda")`.
        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.model.device)

        # Generate text
        with torch.no_grad():
            output_ids = self.model.generate(
                pixel_values,
                max_length=self.max_length,
                num_beams=self.num_beams,
                early_stopping=True,
            )

        # Decode the generated ids to text
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)


if __name__ == '__main__':
    # Provide a path to an image for inference
    # Using an example image from the dataset
    image_path = "../ai_augment_output/20250901_115123_336458_ccd9d646-fc99-4d27-8076-0c17d0dba784.png"

    # --- Initialize the Inference Class ---
    ocr_engine = OCRInference(model_path=MODEL_PATH)

    # --- Perform Inference from a file path ---
    try:
        predicted_text = ocr_engine.perform_inference(image_path)
        print("\n--- Inference from file path ---")
        print(f"Image: {image_path}")
        print(f"Predicted Text: {predicted_text}")
    except FileNotFoundError as e:
        print(e)
        print("Please update the 'image_path' variable in the script with a valid image path.")

    # --- Perform Inference from a NumPy array (example) ---
    try:
        # Create a dummy numpy array for demonstration
        if os.path.exists(image_path):
            # Close the file once the pixel data has been copied into the array.
            with Image.open(image_path) as img:
                dummy_image_array = np.array(img)
            predicted_text_from_array = ocr_engine.perform_inference(dummy_image_array)
            print("\n--- Inference from NumPy array ---")
            print(f"Predicted Text: {predicted_text_from_array}")
    except Exception as e:  # top-level demo boundary: report and carry on
        print(f"An error occurred during inference from NumPy array: {e}")