| """ | |
| Predictor Module | |
| This module handles image classification predictions using Vision Transformer models. | |
| It provides functions for making predictions and creating visualization plots of results. | |
| Author: ViT-XAI-Dashboard Team | |
| License: MIT | |
| """ | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |


def predict_image(image, model, processor, top_k=5):
    """
    Perform inference on an image and return top-k predicted classes with probabilities.

    This function takes a PIL Image, preprocesses it using the model's processor,
    performs a forward pass through the model, and returns the top-k most likely
    class predictions along with their confidence scores.

    Args:
        image (PIL.Image): Input image to classify. Should be in RGB format.
        model (ViTForImageClassification): Pre-trained ViT model for inference.
        processor (ViTImageProcessor): Image processor for preprocessing.
        top_k (int, optional): Number of top predictions to return. Defaults to 5.

    Returns:
        tuple: A tuple containing three elements:
            - top_probs (np.ndarray): Array of shape (top_k,) with confidence scores
            - top_indices (np.ndarray): Array of shape (top_k,) with class indices
            - top_labels (list): List of length top_k with human-readable class names

    Raises:
        Exception: If prediction fails due to an invalid image, model issues, or memory errors.

    Example:
        >>> from PIL import Image
        >>> image = Image.open("cat.jpg")
        >>> probs, indices, labels = predict_image(image, model, processor, top_k=3)
        >>> print(f"Top prediction: {labels[0]} with {probs[0]:.2%} confidence")
        Top prediction: tabby cat with 87.34% confidence

    Note:
        - Inference is performed with torch.no_grad() for efficiency
        - Automatically handles device placement (CPU/GPU)
        - Applies softmax to convert logits to probabilities
    """
    try:
        # Get the device from the model parameters
        # This ensures inputs are moved to the same device as the model (CPU or GPU)
        device = next(model.parameters()).device

        # Preprocess the image using the ViT processor
        # This handles resizing, normalization, and conversion to tensors
        inputs = processor(images=image, return_tensors="pt")

        # Move all input tensors to the same device as the model
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Perform inference without gradient computation (saves memory and speeds up inference)
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits  # Raw model outputs before softmax

        # Apply softmax to convert logits to probabilities
        # dim=-1 applies softmax across the class dimension
        probabilities = F.softmax(logits, dim=-1)[0]  # [0] removes the batch dimension

        # Get the top-k highest-probability predictions
        # Returns both values (probabilities) and indices (class IDs)
        top_probs, top_indices = torch.topk(probabilities, top_k)

        # Convert PyTorch tensors to NumPy arrays for easier handling
        top_probs = top_probs.cpu().numpy()
        top_indices = top_indices.cpu().numpy()

        # Convert class indices to human-readable labels using the model's label mapping when available
        id2label = None
        if hasattr(model, "config") and hasattr(model.config, "id2label"):
            id2label = model.config.id2label
        top_labels = [
            (id2label.get(int(idx), f"class_{int(idx)}") if isinstance(id2label, dict) else f"class_{int(idx)}")
            for idx in top_indices
        ]

        return top_probs, top_indices, top_labels

    except Exception as e:
        print(f"❌ Error during prediction: {str(e)}")
        raise


def create_prediction_plot(probs, labels):
    """
    Create a professional horizontal bar chart visualizing top predictions.

    This function generates a matplotlib figure with a horizontal bar chart showing
    the model's top predictions along with their confidence scores. The chart includes
    percentage labels on each bar and a clean, minimalist design.

    Args:
        probs (np.ndarray or list): Array of probability scores for each class.
            Should be in descending order (highest probability first).
        labels (list): List of human-readable class names corresponding to the probabilities.
            Length must match probs.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure object containing the bar chart.
            Can be displayed with fig.show() or saved with fig.savefig().

    Example:
        >>> probs = np.array([0.87, 0.08, 0.03, 0.01, 0.01])
        >>> labels = ['tabby cat', 'tiger cat', 'Egyptian cat', 'lynx', 'cougar']
        >>> fig = create_prediction_plot(probs, labels)
        >>> fig.savefig('predictions.png')

    Note:
        - Uses horizontal bars for better label readability
        - Automatically adds percentage labels on each bar
        - Includes subtle grid lines for easier value reading
        - X-axis is scaled to provide padding for the percentage labels
    """
    # Create the figure and axis with a fixed size
    fig, ax = plt.subplots(figsize=(8, 4))

    # Create the horizontal bar chart
    # y_pos represents the vertical position of each bar
    y_pos = np.arange(len(labels))
    bars = ax.barh(y_pos, probs, color="skyblue", alpha=0.8)

    # Set y-axis ticks and labels
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels, fontsize=10)

    # Set axis labels and title
    ax.set_xlabel("Confidence", fontsize=12)
    ax.set_title("Top Predictions", fontsize=14, fontweight="bold")

    # Add a probability percentage label on each bar
    for bar, prob in zip(bars, probs):
        width = bar.get_width()  # Bar length (probability value)
        # Place text slightly to the right of the bar end
        ax.text(
            width + 0.01,  # X position (slightly right of the bar)
            bar.get_y() + bar.get_height() / 2,  # Y position (center of the bar)
            f"{prob:.2%}",  # Format as a percentage with 2 decimal places
            va="center",  # Vertical alignment
            fontsize=9,
        )

    # Set x-axis limits with padding for the percentage labels
    # The 1.15 multiplier adds 15% padding to the right
    ax.set_xlim(0, max(probs) * 1.15)

    # Add subtle grid lines for easier value reading
    ax.grid(axis="x", alpha=0.3, linestyle="--")

    # Adjust layout to prevent label cutoff
    plt.tight_layout()

    return fig
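

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the dashboard itself). It shows how the two
# functions above are intended to be combined: load a model and processor,
# classify an image, and plot the result. The checkpoint name
# "google/vit-base-patch16-224" and the image path "cat.jpg" are illustrative
# assumptions, not requirements of this module; substitute your own.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import ViTForImageClassification, ViTImageProcessor

    # Assumed checkpoint and image path; replace with your own.
    checkpoint = "google/vit-base-patch16-224"
    model = ViTForImageClassification.from_pretrained(checkpoint)
    processor = ViTImageProcessor.from_pretrained(checkpoint)
    model.eval()

    image = Image.open("cat.jpg").convert("RGB")

    probs, indices, labels = predict_image(image, model, processor, top_k=5)
    print(f"Top prediction: {labels[0]} with {probs[0]:.2%} confidence")

    fig = create_prediction_plot(probs, labels)
    fig.savefig("predictions.png")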