# hotdog/src/loadTrained.py
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, f1_score, recall_score, roc_auc_score, roc_curve
import matplotlib.pyplot as plt
import numpy as np
import logging

# Configure logging so the metric lines below are actually emitted
# (logging.info is a no-op at the default WARNING level)
logging.basicConfig(level=logging.INFO)
# Define the model architecture (same as in training.py)
class ResNet18(nn.Module):
    def __init__(self, num_classes):
        super(ResNet18, self).__init__()
        self.resnet18 = models.resnet18(weights='ResNet18_Weights.DEFAULT')  # Pretrained ImageNet weights
        self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, num_classes)  # Custom classifier head

    def forward(self, x):
        return self.resnet18(x)
# Instantiate the model
num_classes = 2 # Adjust based on your training setup
model = ResNet18(num_classes=num_classes)
# Load the state dictionary (map to CPU first so the load also works on
# machines without the GPU the weights were saved on)
state_dict_path = 'resnet_state_dict.pth'
model.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
model.eval() # Set to evaluation mode
# Move model to appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Define test transformations (same as training)
test_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Load test dataset
test_image_folder = 'data/test' # Path to the test image directory
test_dataset = datasets.ImageFolder(root=test_image_folder, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
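# Optional sanity check: confirm the class-to-index mapping matches the one
# used at training time; a silent mismatch would flip which column of the
# softmax is treated as the positive class in the ROC computation below.
print(f"Class mapping: {test_dataset.class_to_idx}")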
# Evaluate the model on the test set
all_labels = []
all_preds = []
all_probs = []
class_names = test_dataset.classes # Get class labels
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        probabilities = torch.softmax(outputs, dim=1)[:, 1]  # Probability of the positive class
        _, predicted = torch.max(outputs, 1)  # Predicted class index per image
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())
        all_probs.extend(probabilities.cpu().numpy())
# Calculate F1 score, recall, and generate a classification report
f1 = f1_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
print(f"F1 Score: {f1:.2f}")
print(f"Recall: {recall:.2f}")
print("\nClassification Report:\n", classification_report(all_labels, all_preds, target_names=class_names))
# Calculate and plot AUC-ROC curve
auc_score = roc_auc_score(all_labels, all_probs)
fpr, tpr, _ = roc_curve(all_labels, all_probs)
logging.info(f'AUC Score: {auc_score}')
logging.info(f'FPR: {fpr}')
logging.info(f'TPR: {tpr}')
logging.info(f'F1 Score: {f1}')
logging.info(f'Recall: {recall}')
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', label=f"AUC = {auc_score:.2f}")
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc="lower right")
plt.grid()
plt.savefig('AUC.png', dpi=300)
# Print overall accuracy
accuracy = 100 * np.mean(np.array(all_labels) == np.array(all_preds))
print(f"Accuracy on test set: {accuracy:.2f}%")