"""ONNX inference for a CAPTCHA/text-recognition model with greedy token decoding."""

import logging
import os
import string
from typing import List, Tuple, Union

import onnxruntime as ort
import torch
import torchvision.transforms as T
from PIL import Image
from torch import Tensor

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class TokenDecoder:
    def __init__(self):
        self.specials_first = ('[E]',)        # end-of-sequence token
        self.specials_last = ('[B]', '[P]')   # beginning-of-sequence and padding tokens
        self.charset = tuple(string.digits + string.ascii_lowercase +
                             string.ascii_uppercase + string.punctuation)
        self.itos = self.specials_first + self.charset + self.specials_last
        self.stoi = {s: i for i, s in enumerate(self.itos)}
        self.eos_id = self.stoi['[E]']
        self.sos_id = self.stoi['[B]']
        self.pad_id = self.stoi['[P]']
        logger.info(f"Initialized TokenDecoder with {len(self.itos)} tokens, "
                    f"including {len(self.charset)} charset tokens.")

    def ids2tok(self, token_ids: List[int], join: bool = True) -> Union[str, List[str]]:
        tokens = [self.itos[i] for i in token_ids if i < len(self.itos)]  # Skip invalid indices
        return ''.join(tokens) if join else tokens

    def filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        ids = ids.tolist()
        try:
            eos_idx = ids.index(self.eos_id)
        except ValueError:
            eos_idx = len(ids)    # No EOS found, keep the full sequence
        ids = ids[:eos_idx]       # Exclude EOS and everything after it
        probs = probs[:eos_idx]   # Probabilities up to (excluding) EOS
        return probs, ids

    def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
        batch_tokens = []
        batch_probs = []
        for dist in token_dists:
            probs, ids = dist.max(-1)  # Greedy selection per time step
            if not raw:
                probs, ids = self.filter(probs, ids)
            tokens = self.ids2tok(ids)
            batch_tokens.append(tokens)
            batch_probs.append(probs)
        return batch_tokens, batch_probs


def infer_onnx(onnx_path: str, image_path: str) -> str:
    try:
        # Verify the ONNX model exists
        if not os.path.exists(onnx_path):
            raise FileNotFoundError(f"ONNX model not found at {onnx_path}")

        # Initialize the ONNX Runtime session
        logger.info(f"Loading ONNX model from {onnx_path}")
        session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
        input_name = session.get_inputs()[0].name

        # Verify the input image exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")

        # Preprocess the image: resize, convert to tensor, normalize with ImageNet stats
        logger.info(f"Processing image {image_path}")
        img = Image.open(image_path).convert('RGB')
        transform = T.Compose([
            T.Resize((32, 128)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img_tensor = transform(img).unsqueeze(0).numpy()  # (1, 3, 32, 128)

        # Run inference
        logger.info("Running inference")
        outputs = session.run(None, {input_name: img_tensor})[0]  # (1, seq_len, 95)
        logits = torch.from_numpy(outputs)
        probs = logits.softmax(-1)  # Normalize so confidence scores are true probabilities

        # Decode predictions
        decoder = TokenDecoder()
        pred, conf_scores = decoder.decode(probs)
        logger.info(f"Prediction: {pred[0]}")
        logger.info(f"Confidence scores: {conf_scores[0].numpy().tolist()}")
        return pred[0]
    except Exception as e:
        logger.error(f"Error during inference: {str(e)}")
        raise


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Perform inference with an ONNX model.')
    parser.add_argument('--onnx', required=True, help='Path to ONNX model')
    parser.add_argument('--image', required=True, help='Path to input CAPTCHA image')
    args = parser.parse_args()
    infer_onnx(args.onnx, args.image)
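
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The names below all come from this file;
# the expected output assumes the itos layout defined above: [E] at index 0,
# the 94 charset characters next, then [B] and [P]. A synthetic distribution
# that peaks on 'a', 'b', '1', then [E] should decode to the string 'ab1':
#
#   decoder = TokenDecoder()
#   dists = torch.zeros(1, 4, len(decoder.itos))
#   for step, ch in enumerate(('a', 'b', '1')):
#       dists[0, step, decoder.stoi[ch]] = 1.0
#   dists[0, 3, decoder.eos_id] = 1.0
#   tokens, probs = decoder.decode(dists)
#   assert tokens[0] == 'ab1'
#
# From the command line (assuming this script is saved as infer_onnx.py,
# a hypothetical filename):
#
#   python infer_onnx.py --onnx model.onnx --image captcha.png
# ---------------------------------------------------------------------------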