|
|
""" |
|
|
Inference script for making predictions with trained MNIST models |
|
|
Usage: python inference.py --model-path checkpoints/best_model.pth --image-path my_digit.png |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import argparse |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
class ConvNet(nn.Module):
    """CNN classifier: two conv stages followed by a three-layer dense head.

    Each stage applies conv -> batch norm -> ReLU twice, then a 2x2 max-pool
    and spatial dropout.  ``fc1`` expects ``128 * 7 * 7`` features, which is
    what a single-channel 28x28 input yields after the two poolings.

    The attribute names below become the ``state_dict`` keys, so they must
    stay in sync with whatever produced the checkpoints being loaded.
    """

    def __init__(self, dropout_rate=0.3, num_classes=10):
        super().__init__()

        # Stage 1: 1 -> 32 -> 64 channels (3x3 kernels, size-preserving padding).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Stage 2: 64 -> 128 -> 128 channels.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # Shared by both stages; conv dropout uses half the dense rate.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout_conv = nn.Dropout2d(dropout_rate * 0.5)

        # Dense head: 6272 -> 256 -> 128 -> num_classes.
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(256, 128)
        self.bn6 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(dropout_rate * 0.5)

        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class scores (logits) for a batch of images."""
        # Stage 1
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = self.dropout_conv(self.pool(x))

        # Stage 2
        x = torch.relu(self.bn3(self.conv3(x)))
        x = torch.relu(self.bn4(self.conv4(x)))
        x = self.dropout_conv(self.pool(x))

        # Collapse (channels, height, width) into one feature axis per sample.
        batch_size = x.size(0)
        x = x.view(batch_size, -1)

        # Dense head
        x = self.dropout1(torch.relu(self.bn5(self.fc1(x))))
        x = self.dropout2(torch.relu(self.bn6(self.fc2(x))))
        return self.fc3(x)
|
|
|
|
|
class ImprovedNN(nn.Module):
    """Fully connected classifier built from Linear-BatchNorm-ReLU-Dropout blocks.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_sizes: Iterable of hidden layer widths, one block per entry.
        num_classes: Width of the final linear layer (number of logits).
        dropout_rate: Dropout probability for every hidden block except the
            last one, which uses half of this value.

    The layers live in ``self.network`` (an ``nn.Sequential``), so the
    ``state_dict`` keys are ``network.<index>.<param>``.
    """

    def __init__(self, input_size=784, hidden_sizes=(512, 256, 128),
                 num_classes=10, dropout_rate=0.3):
        super().__init__()

        # Immutable default above avoids the shared-mutable-default pitfall;
        # materialising a list lets callers pass any iterable.
        hidden_sizes = list(hidden_sizes)
        last_index = len(hidden_sizes) - 1

        layers = []
        prev_size = input_size
        for i, hidden_size in enumerate(hidden_sizes):
            # The block right before the output layer gets a lighter dropout.
            drop_p = dropout_rate if i < last_index else dropout_rate * 0.5
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(drop_p),
            ])
            prev_size = hidden_size

        layers.append(nn.Linear(prev_size, num_classes))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample to a vector and return raw class scores."""
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or sliced tensors, copying only when it has to.
        flat = x.reshape(x.size(0), -1)
        return self.network(flat)
|
|
|
|
|
def load_model(model_path, model_type='cnn', device='cpu'):
    """Load a trained model from a checkpoint and switch it to eval mode.

    Args:
        model_path: Checkpoint file handed to ``torch.load``.
        model_type: ``'cnn'`` builds a ``ConvNet``; any other value builds an
            ``ImprovedNN``.  A ``model_type`` stored under the checkpoint's
            ``'args'`` entry takes precedence over this argument.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model with ``checkpoint['model_state_dict']`` loaded, placed on
        ``device`` and set to evaluation mode.

    Both architectures are built with their default constructor arguments,
    so the checkpoint must have been trained with matching layer sizes.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code.  Only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Prefer the architecture recorded at training time.  The saved args may
    # be a plain dict or an argparse.Namespace, so support both.
    saved_args = checkpoint.get('args')
    if isinstance(saved_args, dict):
        saved_type = saved_args.get('model_type')
    else:
        saved_type = getattr(saved_args, 'model_type', None)
    if saved_type:
        model_type = saved_type

    model = ConvNet() if model_type == 'cnn' else ImprovedNN()

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # A missing or non-numeric val_acc must not crash the summary: applying
    # the ':.2f' format spec to the 'unknown' placeholder raises ValueError.
    val_acc = checkpoint.get('val_acc', 'unknown')
    try:
        val_acc_text = f"{val_acc:.2f}%"
    except (TypeError, ValueError):
        val_acc_text = str(val_acc)

    print(f"✓ Loaded {model_type.upper()} model from {model_path}")
    print(f" - Trained for {checkpoint.get('epoch', 'unknown')} epochs")
    print(f" - Validation accuracy: {val_acc_text}")

    return model
|
|
|
|
|
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized 28x28 model input.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A tuple ``(img_tensor, img_array)``:
        ``img_tensor`` is the resized grayscale image after ``ToTensor`` and
        ``Normalize``; ``img_array`` is the same resized image as a NumPy
        array, before normalization, for display purposes.
    """
    # Open inside a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory image.
    with Image.open(image_path) as source:
        img = source.convert('L')

    img = img.resize((28, 28), Image.Resampling.LANCZOS)

    # NOTE(review): no color inversion is applied here. If the model was
    # trained on light-digit-on-dark-background data, inputs are presumably
    # expected to look the same; confirm against the training pipeline.
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Presumably the MNIST training mean/std; must match the values
        # used when the model was trained.
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    img_tensor = transform(img)
    img_array = np.array(img)

    return img_tensor, img_array
|
|
|
|
|
def predict(model, image_tensor, device):
    """Classify one unbatched image tensor.

    Args:
        model: Callable module mapping a batch to per-class scores.
        image_tensor: A single sample; a leading batch axis is added here.
        device: Device the input is moved to before the forward pass.

    Returns:
        ``(predicted, confidence, probabilities)``: the winning class index
        as an ``int``, its softmax probability as a ``float``, and the full
        softmax output as a NumPy array with size-1 axes squeezed out.
    """
    batch = image_tensor.unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)
        probs = logits.softmax(dim=1)
        top_prob, top_class = probs.max(dim=1)

    prob_vector = probs.squeeze().cpu().numpy()
    return top_class.item(), top_prob.item(), prob_vector
|
|
|
|
|
def visualize_prediction(image, predicted_digit, confidence, probabilities, save_path=None):
    """Plot the input image next to a bar chart of the class probabilities.

    Args:
        image: Array-like image shown with a gray colormap.
        predicted_digit: Index of the predicted class; its bar is drawn green.
        confidence: Probability of the predicted class, in ``[0, 1]``.
        probabilities: Per-class probabilities in ``[0, 1]`` (array or list).
        save_path: If truthy, the figure is also written to this path.

    The figure is always displayed via ``plt.show()`` at the end.
    """
    # Accept plain lists too: `list * 100` would repeat the list rather
    # than scale its values.
    probabilities = np.asarray(probabilities)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: the image that was classified.
    ax1.imshow(image, cmap='gray')
    ax1.set_title(f'Input Image\nPredicted: {predicted_digit} ({confidence*100:.1f}%)',
                  fontsize=14, fontweight='bold')
    ax1.axis('off')

    # Right panel: one bar per class, sized from the data rather than
    # assuming exactly ten classes.
    digits = np.arange(len(probabilities))
    colors = ['green' if i == predicted_digit else 'gray' for i in digits]
    bars = ax2.bar(digits, probabilities * 100, color=colors, alpha=0.7)

    # Label each bar with its percentage just above its top edge.
    for bar, prob in zip(bars, probabilities):
        ax2.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                 f'{prob*100:.1f}%',
                 ha='center', va='bottom', fontsize=9)

    ax2.set_xlabel('Digit', fontsize=12)
    ax2.set_ylabel('Confidence (%)', fontsize=12)
    ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
    ax2.set_xticks(digits)
    # Headroom above 100% so the tallest bar's label is not clipped.
    ax2.set_ylim([0, 105])
    ax2.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✓ Visualization saved to {save_path}")

    plt.show()
|
|
|
|
|
def predict_batch(model, image_paths, device):
    """Run inference on several image files, printing a report for each.

    Args:
        model: Model passed through to ``predict``.
        image_paths: Iterable of image paths passed to ``preprocess_image``.
        device: Device passed through to ``predict``.

    Returns:
        A list with one dict per successfully processed image, holding the
        keys ``'image_path'``, ``'predicted'``, ``'confidence'`` and
        ``'probabilities'``.  Images that cannot be read are reported and
        left out of the list instead of aborting the whole batch.
    """
    results = []

    for image_path in image_paths:
        print(f"\nProcessing: {image_path}")

        # One missing or corrupt file should not discard the rest of the
        # batch.  PIL's UnidentifiedImageError is an OSError subclass.
        try:
            img_tensor, _ = preprocess_image(image_path)
        except OSError as err:
            print(f" Skipped (could not read image): {err}")
            continue

        predicted, confidence, probabilities = predict(model, img_tensor, device)

        results.append({
            'image_path': image_path,
            'predicted': predicted,
            'confidence': confidence,
            'probabilities': probabilities,
        })

        print(f" Prediction: {predicted} (Confidence: {confidence*100:.2f}%)")

        # argsort is ascending: take the last three and reverse them to get
        # the most likely classes first.
        top3_idx = np.argsort(probabilities)[-3:][::-1]
        top3_text = "".join(
            f"{idx}({probabilities[idx]*100:.1f}%) " for idx in top3_idx
        )
        print(f" Top 3: {top3_text}")

    return results
|
|
|
|
|
def main():
    """Command-line entry point for single-image or directory inference.

    Parses the arguments, loads the checkpoint, then either classifies one
    image (``--image-path``; prints all class probabilities and saves/shows
    a visualization) or every PNG/JPEG in a directory (``--image-dir``;
    prints a per-file summary).  When both options are given,
    ``--image-path`` wins.
    """
    parser = argparse.ArgumentParser(description='MNIST Digit Recognition Inference')
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to trained model checkpoint')
    parser.add_argument('--image-path', type=str,
                        help='Path to input image (28x28 recommended, grayscale)')
    parser.add_argument('--image-dir', type=str,
                        help='Directory containing multiple images to predict')
    parser.add_argument('--model-type', type=str, default='cnn', choices=['cnn', 'fc'],
                        help='Model architecture type (auto-detected from checkpoint if available)')
    parser.add_argument('--save-viz', type=str,
                        help='Path to save visualization')
    parser.add_argument('--use-gpu', action='store_true',
                        help='Use GPU if available')

    args = parser.parse_args()

    # Check for an input before loading the checkpoint, so a forgotten
    # option is reported immediately instead of after the model load.
    if not args.image_path and not args.image_dir:
        print("Please provide either --image-path or --image-dir")
        return

    device = torch.device('cuda' if args.use_gpu and torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = load_model(args.model_path, args.model_type, device)

    if args.image_path:
        print(f"\nProcessing single image: {args.image_path}")

        img_tensor, img_array = preprocess_image(args.image_path)
        predicted, confidence, probabilities = predict(model, img_tensor, device)

        print(f"\n{'='*50}")
        print(f"Prediction: {predicted}")
        print(f"Confidence: {confidence*100:.2f}%")
        print(f"{'='*50}")

        print("\nAll class probabilities:")
        for digit, prob in enumerate(probabilities):
            print(f" {digit}: {prob*100:.2f}%")

        # The visualization is always written; --save-viz only picks the name.
        save_path = args.save_viz if args.save_viz else 'prediction_visualization.png'
        visualize_prediction(img_array, predicted, confidence, probabilities, save_path)
        return

    # Directory mode.
    print(f"\nProcessing directory: {args.image_dir}")

    image_dir = Path(args.image_dir)
    extensions = {'.png', '.jpg', '.jpeg'}
    if image_dir.is_dir():
        # Case-insensitive suffix match so files such as IMG.PNG are found
        # on case-sensitive filesystems; sorted for a reproducible order.
        image_paths = sorted(
            p for p in image_dir.iterdir()
            if p.is_file() and p.suffix.lower() in extensions
        )
    else:
        image_paths = []

    if not image_paths:
        print("No images found in directory!")
        return

    print(f"Found {len(image_paths)} images")

    results = predict_batch(model, [str(p) for p in image_paths], device)

    print(f"\n{'='*50}")
    print("Summary:")
    print(f"{'='*50}")
    for result in results:
        print(f"{Path(result['image_path']).name}: {result['predicted']} ({result['confidence']*100:.1f}%)")
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()