import os
from typing import Union

import numpy as np
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer

# --- Configuration ---
MODEL_PATH = "./ocr_model_output/checkpoint-441"


class OCRInference:
    """A class to perform OCR inference using a trained model."""

    def __init__(self, model_path: str, encoder_id: str = "google/vit-base-patch16-224-in21k", decoder_id: str = "prajjwal1/bert-tiny"):
        """
        Initializes the OCRInference class by loading the model, image processor, and tokenizer.

        Args:
            model_path (str): The path to the trained model checkpoint.
            encoder_id (str): The encoder ID to load the image processor from.
            decoder_id (str): The decoder ID to load the tokenizer from.
        """
        print(f"Loading model from: {model_path}")
        self.model = VisionEncoderDecoderModel.from_pretrained(model_path)

        # Load image processor and save it if not present
        try:
            self.image_processor = AutoImageProcessor.from_pretrained(model_path)
        except OSError:
            print("Image processor not found locally. Loading from encoder ID and saving.")
            self.image_processor = AutoImageProcessor.from_pretrained(encoder_id)
            self.image_processor.save_pretrained(model_path)

        # Load tokenizer and save it if not present
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        except (KeyError, OSError):
            print("Tokenizer not found locally. Loading from decoder ID and saving.")
            self.tokenizer = AutoTokenizer.from_pretrained(decoder_id)
            self.tokenizer.save_pretrained(model_path)

        # --- Set special tokens and generation parameters ---
        self.model.config.decoder_start_token_id = self.tokenizer.cls_token_id
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.vocab_size = self.tokenizer.vocab_size
        self.model.config.eos_token_id = self.tokenizer.sep_token_id
        self.model.config.max_length = 64
        self.model.config.early_stopping = True
        self.model.config.no_repeat_ngram_size = 3
        self.model.config.length_penalty = 2.0
        self.model.config.num_beams = 4

        print("Model, image processor, and tokenizer loaded.")

    def perform_inference(self, image_input: Union[str, np.ndarray]) -> str:
        """
        Performs inference on a single image, which can be a file path or a NumPy array.

        Args:
            image_input (Union[str, np.ndarray]): Path to the input image or a NumPy array representing the image.

        Returns:
            str: The predicted text.
        """
        if isinstance(image_input, str):
            if not os.path.exists(image_input):
                raise FileNotFoundError(f"Image file not found at: {image_input}")
            image = Image.open(image_input).convert("RGB")
        elif isinstance(image_input, np.ndarray):
            image = Image.fromarray(image_input).convert("RGB")
        else:
            raise TypeError("image_input must be a file path (str) or a NumPy array.")

        # Process the image
        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values

        # Generate text
        with torch.no_grad():
            output_ids = self.model.generate(pixel_values, max_length=64, num_beams=4, early_stopping=True)

        # Decode the generated ids to text
        preds = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return preds
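
    # Hedged sketch, not part of the original script: a batched variant of
    # perform_inference. The method name `perform_batch_inference` is an
    # assumption; it simply feeds several images through one generate() call
    # and decodes all outputs at once with batch_decode.
    def perform_batch_inference(self, image_paths: list) -> list:
        """Runs OCR over several image files at once and returns one string per image."""
        images = [Image.open(p).convert("RGB") for p in image_paths]
        pixel_values = self.image_processor(images=images, return_tensors="pt").pixel_values
        with torch.no_grad():
            output_ids = self.model.generate(pixel_values, max_length=64, num_beams=4, early_stopping=True)
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)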


if __name__ == '__main__':
    # Provide a path to an image for inference
    # Using an example image from the dataset
    image_path = "../ai_augment_output/20250901_115123_336458_ccd9d646-fc99-4d27-8076-0c17d0dba784.png"

    # --- Initialize the Inference Class ---
    ocr_engine = OCRInference(model_path=MODEL_PATH)

    # --- Perform Inference from a file path ---
    try:
        predicted_text = ocr_engine.perform_inference(image_path)
        print("\n--- Inference from file path ---")
        print(f"Image: {image_path}")
        print(f"Predicted Text: {predicted_text}")
    except FileNotFoundError as e:
        print(e)
        print("Please update the 'image_path' variable in the script with a valid image path.")

    # --- Perform Inference from a NumPy array (example) ---
    try:
        # Load the same image as a NumPy array for demonstration
        if os.path.exists(image_path):
            dummy_image_array = np.array(Image.open(image_path))
            predicted_text_from_array = ocr_engine.perform_inference(dummy_image_array)
            print("\n--- Inference from NumPy array ---")
            print(f"Predicted Text: {predicted_text_from_array}")
    except Exception as e:
        print(f"An error occurred during inference from NumPy array: {e}")