"""Run top-5 single-image inference with a trained ResNet-50 checkpoint."""

import os

import torch
import torchvision.transforms as transforms
from PIL import Image

# Project-local module providing the ResNet50 architecture.
from model import ResNet50

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)


class Params:
    """Plain container for hyper-parameters.

    NOTE(review): nothing in this script reads these values; presumably the
    class is kept so checkpoints that pickled a ``Params`` instance can be
    unpickled -- confirm before removing.
    """

    def __init__(self):
        self.batch_size = 512
        self.name = "resnet_50"
        self.workers = 16
        self.lr = 0.1
        self.momentum = 0.9
        self.weight_decay = 1e-4
        self.lr_step_size = 30
        self.lr_gamma = 0.1

    def __repr__(self):
        return str(self.__dict__)

    def __eq__(self, other):
        # Comparing against an unrelated object (e.g. an int) used to raise
        # AttributeError on ``other.__dict__``; defer to Python instead.
        if not isinstance(other, Params):
            return NotImplemented
        return self.__dict__ == other.__dict__

    # Defining __eq__ already disables hashing; make that explicit.
    __hash__ = None


params = Params()

# Path to the saved model checkpoint
checkpoint_path = "checkpoints/resnet_50/checkpoint.pth"

num_classes = 1000  # Adjust this to match your dataset
model = ResNet50(num_classes=num_classes).to(device)

# Load the trained weights. Mapping to CPU lets a checkpoint saved on a GPU
# machine load on any host; load_state_dict then copies the tensors onto
# whatever device the model's parameters live on.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()

# Preprocessing applied to every image before it is fed to the model.
inference_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=256),
    transforms.CenterCrop(224),
    # Standard ImageNet channel statistics (the green mean is 0.456).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def load_class_names(file_path):
    """Return the lines of *file_path* as a list, stripped of whitespace.

    Blank lines are kept (as empty strings) so that list positions continue
    to match line numbers in the file.
    """
    with open(file_path, 'r') as f:
        return [line.strip() for line in f]


def predict(image_path, model, transforms, class_names=None):
    """Classify one image and print its highest-probability classes.

    Args:
        image_path: Path of the image file to classify.
        model: Callable applied to the preprocessed batch; its first output
            row is treated as the class logits (assumes a
            ``(batch, classes)`` output -- confirm against the model).
        transforms: Callable that turns a PIL image into a tensor.
        class_names: Optional sequence indexed by class id. When missing or
            empty, classes are printed as ``"Class <id>"``.

    Returns:
        A ``(top_prob, top_class)`` pair of tensors holding up to five
        probabilities in descending order and their class indices.
    """
    # Close the file handle deterministically; convert() returns a detached
    # copy that remains valid after the context exits.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_tensor = transforms(image).unsqueeze(0).to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.size(0))
    top_prob, top_class = probabilities.topk(k, largest=True, sorted=True)

    # Display the top predictions
    print("Predictions:")
    for prob, idx in zip(top_prob.tolist(), top_class.tolist()):
        class_name = class_names[idx] if class_names else f"Class {idx}"
        print(f"{class_name}: {prob * 100:.2f}%")

    return top_prob, top_class


# Path to the ImageNet classes text file
imagenet_classes_file = "imagenet-classes.txt"  # Replace with the actual path to your text file
class_names = load_class_names(imagenet_classes_file)

# Path to the image for inference
image_path = "dog.png"  # Replace with the actual path to your test image

# Make a prediction
predict(image_path, model, inference_transforms, class_names=class_names)