# mnist-cnn-classifier / inference.py
# Uploaded by Pratik45, commit 21f4ad5:
# "Initial upload: MNIST CNN classifier with 99.60% accuracy"
"""
Inference script for making predictions with trained MNIST models
Usage: python inference.py --model-path checkpoints/best_model.pth --image-path my_digit.png
"""
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import argparse
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
# Model architectures (must match training)
class ConvNet(nn.Module):
    """Convolutional Neural Network for MNIST.

    Two feature stages, each made of two (conv -> batch norm -> ReLU) steps
    followed by 2x2 max-pooling and channel dropout, then a three-layer fully
    connected head with batch norm and dropout.

    The submodule attribute names and the order in which they are created must
    match the training script exactly: ``load_state_dict`` matches parameters
    by these names.

    Args:
        dropout_rate: Dropout probability after ``fc1``. The convolutional
            dropout and the dropout after ``fc2`` both use half of it.
        num_classes: Number of logits produced by ``fc3``.
    """

    def __init__(self, dropout_rate=0.3, num_classes=10):
        super().__init__()
        half_rate = dropout_rate * 0.5

        # Stage 1: 1 -> 32 -> 64 channels (padding=1 keeps spatial size).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        # Stage 2: 64 -> 128 -> 128 channels.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        # Parameter-free pooling and channel dropout, reused by both stages.
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout_conv = nn.Dropout2d(half_rate)

        # Head: fc1 expects 128 x 7 x 7 features, i.e. two 2x2 pools applied
        # to a 28x28 input.
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, 128)
        self.bn6 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(half_rate)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes); no softmax applied."""
        # Stage 1, then downsample.
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = self.dropout_conv(self.pool(x))
        # Stage 2, then downsample.
        x = torch.relu(self.bn3(self.conv3(x)))
        x = torch.relu(self.bn4(self.conv4(x)))
        x = self.dropout_conv(self.pool(x))
        # Flatten each sample and run the classifier head.
        x = x.view(x.size(0), -1)
        x = self.dropout1(torch.relu(self.bn5(self.fc1(x))))
        x = self.dropout2(torch.relu(self.bn6(self.fc2(x))))
        return self.fc3(x)
class ImprovedNN(nn.Module):
    """Enhanced fully connected network.

    Stacks (Linear -> BatchNorm1d -> ReLU -> Dropout) for each hidden width,
    then a final Linear layer producing the logits.

    Args:
        input_size: Number of features per sample after flattening
            (784 = 28 * 28 for MNIST).
        hidden_sizes: Widths of the hidden layers, in order. Any sequence is
            accepted. The default is a tuple so it cannot be mutated and shared
            between instances (a list default would be).
        num_classes: Number of output logits.
        dropout_rate: Dropout after every hidden layer except the last one,
            which uses half of this rate.
    """

    def __init__(self, input_size=784, hidden_sizes=(512, 256, 128),
                 num_classes=10, dropout_rate=0.3):
        super().__init__()
        hidden_sizes = list(hidden_sizes)
        last_index = len(hidden_sizes) - 1
        layers = []
        prev_size = input_size
        for i, hidden_size in enumerate(hidden_sizes):
            # Lighter dropout right before the output layer.
            rate = dropout_rate if i < last_index else dropout_rate * 0.5
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(rate),
            ])
            prev_size = hidden_size
        layers.append(nn.Linear(prev_size, num_classes))
        # The attribute name "network" determines the checkpoint keys
        # ("network.0.weight", ...), so it must match training.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample and return logits of shape (batch, num_classes)."""
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. transposed tensors), while reshape copies only when needed.
        return self.network(x.reshape(x.size(0), -1))
def load_model(model_path, model_type='cnn', device='cpu'):
    """Load a trained model from a checkpoint and prepare it for inference.

    Args:
        model_path: Path to a file written with ``torch.save``. Normally a dict
            holding ``'model_state_dict'`` plus optional ``'args'``, ``'epoch'``
            and ``'val_acc'`` entries. A bare state dict is also accepted.
        model_type: ``'cnn'`` builds ConvNet; any other value builds
            ImprovedNN. A ``model_type`` recorded in ``checkpoint['args']``
            (dict or attribute-style object) takes precedence.
        device: Passed to ``torch.load(map_location=...)`` and ``model.to``.

    Returns:
        The model with weights loaded, moved to ``device``, in eval mode.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)

    # 'args' may have been saved as a plain dict (vars(args)) or as an
    # argparse.Namespace. A membership test on a Namespace raises TypeError,
    # so handle both shapes explicitly.
    saved_args = checkpoint.get('args')
    if isinstance(saved_args, dict):
        saved_type = saved_args.get('model_type')
    else:
        saved_type = getattr(saved_args, 'model_type', None)
    if saved_type is not None:
        model_type = saved_type

    model = ConvNet() if model_type == 'cnn' else ImprovedNN()

    # Fall back to treating the whole file as the state dict when it was
    # saved with torch.save(model.state_dict()).
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # val_acc may be absent. Applying ':.2f' to a placeholder string raises
    # ValueError, so only use the numeric format when it actually works.
    try:
        val_acc_text = f"{checkpoint.get('val_acc'):.2f}%"
    except (TypeError, ValueError):
        val_acc_text = 'unknown'

    print(f"✓ Loaded {model_type.upper()} model from {model_path}")
    print(f" - Trained for {checkpoint.get('epoch', 'unknown')} epochs")
    print(f" - Validation accuracy: {val_acc_text}")
    return model
def preprocess_image(image_path, invert=False):
    """Load an image file and turn it into a normalized MNIST-style tensor.

    The image is converted to 8-bit grayscale, resized to 28x28 with Lanczos
    resampling, and normalized with the same mean/std as training.

    MNIST digits are light strokes on a dark background. Dark-on-light images
    (e.g. ink on white paper) should be passed with ``invert=True``; otherwise
    they are fed to the model unchanged.

    Args:
        image_path: Path to any image format PIL can open.
        invert: If True, map every pixel value v to 255 - v after resizing.

    Returns:
        ``(img_tensor, img_array)``: the normalized tensor for the model and
        the 28x28 grayscale NumPy array (for visualization).
    """
    # The context manager closes the file handle deterministically. convert()
    # returns a new, fully loaded image, so it stays valid after the with-block.
    with Image.open(image_path) as src:
        img = src.convert('L')

    img = img.resize((28, 28), Image.Resampling.LANCZOS)

    if invert:
        img = img.point(lambda value: 255 - value)

    # Same normalization as training (MNIST mean / std).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    img_tensor = transform(img)

    img_array = np.array(img)
    return img_tensor, img_array
def predict(model, image_tensor, device):
    """Classify a single preprocessed sample.

    Args:
        model: Callable returning logits; softmax is taken over dim 1, so a
            (batch, num_classes) output is expected.
        image_tensor: One unbatched sample. A batch axis of size 1 is added here.
        device: Device the sample is moved to before the forward pass.

    Returns:
        ``(predicted_class, confidence, probabilities)``: the first two are
        Python scalars, ``probabilities`` is a NumPy array of softmax scores.
    """
    batch = image_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        top_prob, top_class = probs.max(dim=1)
    return top_class.item(), top_prob.item(), probs.squeeze().cpu().numpy()
def visualize_prediction(image, predicted_digit, confidence, probabilities, save_path=None):
    """Plot the input image beside a bar chart of per-class probabilities.

    Args:
        image: 2-D array shown in grayscale (e.g. the array returned by
            preprocess_image).
        predicted_digit: Class index highlighted in green.
        confidence: Probability of the predicted class as a fraction (shown x100).
        probabilities: Per-class probabilities (any array-like). The number of
            bars follows its length instead of assuming 10 classes.
        save_path: If given, the figure is also written to this path.

    Returns:
        The matplotlib Figure, so callers plotting many images can close it
        (``plt.close(fig)``) instead of accumulating open figures.
    """
    # Convert lists to an array so "* 100" scales values rather than
    # repeating the list.
    probabilities = np.asarray(probabilities)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: the input image.
    ax1.imshow(image, cmap='gray')
    ax1.set_title(f'Input Image\nPredicted: {predicted_digit} ({confidence*100:.1f}%)',
                  fontsize=14, fontweight='bold')
    ax1.axis('off')

    # Right panel: one bar per class, predicted class highlighted.
    classes = np.arange(len(probabilities))
    colors = ['green' if c == predicted_digit else 'gray' for c in classes]
    bars = ax2.bar(classes, probabilities * 100, color=colors, alpha=0.7)
    for bar, prob in zip(bars, probabilities):
        ax2.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                 f'{prob*100:.1f}%',
                 ha='center', va='bottom', fontsize=9)
    ax2.set_xlabel('Digit', fontsize=12)
    ax2.set_ylabel('Confidence (%)', fontsize=12)
    ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
    ax2.set_xticks(classes)
    ax2.set_ylim([0, 105])
    ax2.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✓ Visualization saved to {save_path}")
    plt.show()
    return fig
def predict_batch(model, image_paths, device):
    """Preprocess and classify each image path in turn, printing progress.

    Args:
        model: Model passed through to ``predict``.
        image_paths: Iterable of image file paths.
        device: Device passed through to ``predict``.

    Returns:
        A list with one dict per path, in input order, with keys
        'image_path', 'predicted', 'confidence' and 'probabilities'.
    """
    results = []
    for path in image_paths:
        print(f"\nProcessing: {path}")
        tensor, _ = preprocess_image(path)
        label, score, probs = predict(model, tensor, device)
        results.append(dict(
            image_path=path,
            predicted=label,
            confidence=score,
            probabilities=probs,
        ))
        print(f" Prediction: {label} (Confidence: {score*100:.2f}%)")

        # Three most probable classes, highest first.
        ranked = np.argsort(probs)[::-1][:3]
        top3_text = "".join(f"{idx}({probs[idx]*100:.1f}%) " for idx in ranked)
        print(" Top 3: " + top3_text)
    return results
# File extensions picked up by --image-dir (compared case-insensitively).
_IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')


def _predict_single_image(model, image_path, save_viz, device):
    """Classify one image, print all class probabilities and save a plot."""
    print(f"\nProcessing single image: {image_path}")
    img_tensor, img_array = preprocess_image(image_path)
    predicted, confidence, probabilities = predict(model, img_tensor, device)

    print(f"\n{'='*50}")
    print(f"Prediction: {predicted}")
    print(f"Confidence: {confidence*100:.2f}%")
    print(f"{'='*50}")

    print("\nAll class probabilities:")
    for digit, prob in enumerate(probabilities):
        print(f" {digit}: {prob*100:.2f}%")

    # A visualization is always written; --save-viz only overrides the path.
    save_path = save_viz or 'prediction_visualization.png'
    visualize_prediction(img_array, predicted, confidence, probabilities, save_path)


def _predict_image_directory(model, image_dir, device):
    """Classify every image file directly inside image_dir and print a summary."""
    print(f"\nProcessing directory: {image_dir}")
    directory = Path(image_dir)
    # Match extensions case-insensitively (DIGIT.PNG counts), skip directories
    # whose names end in an image suffix, and sort so the output order is
    # reproducible. A missing directory is reported the same way as an empty one.
    candidates = directory.iterdir() if directory.is_dir() else ()
    image_paths = sorted(
        p for p in candidates
        if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
    )

    if not image_paths:
        print("No images found in directory!")
        return

    print(f"Found {len(image_paths)} images")
    results = predict_batch(model, [str(p) for p in image_paths], device)

    print(f"\n{'='*50}")
    print("Summary:")
    print(f"{'='*50}")
    for result in results:
        print(f"{Path(result['image_path']).name}: {result['predicted']} ({result['confidence']*100:.1f}%)")


def main():
    """Command-line entry point: classify one image or every image in a directory."""
    parser = argparse.ArgumentParser(description='MNIST Digit Recognition Inference')
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to trained model checkpoint')
    parser.add_argument('--image-path', type=str,
                        help='Path to input image (28x28 recommended, grayscale)')
    parser.add_argument('--image-dir', type=str,
                        help='Directory containing multiple images to predict')
    parser.add_argument('--model-type', type=str, default='cnn', choices=['cnn', 'fc'],
                        help='Model architecture type (auto-detected from checkpoint if available)')
    parser.add_argument('--save-viz', type=str,
                        help='Path to save visualization')
    parser.add_argument('--use-gpu', action='store_true',
                        help='Use GPU if available')
    args = parser.parse_args()

    # Validate the input selection before the (possibly slow) checkpoint load.
    if not args.image_path and not args.image_dir:
        print("Please provide either --image-path or --image-dir")
        return

    device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu')
    print(f"Using device: {device}")

    model = load_model(args.model_path, args.model_type, device)

    # --image-path wins when both options are given.
    if args.image_path:
        _predict_single_image(model, args.image_path, args.save_viz, device)
    else:
        _predict_image_directory(model, args.image_dir, device)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()