import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, f1_score, recall_score, roc_auc_score, roc_curve
import matplotlib.pyplot as plt
import numpy as np
import logging

# Configure logging so the logging.info() calls below are actually emitted
logging.basicConfig(level=logging.INFO)


# Define the model architecture (same as in training.py)
class ResNet18(nn.Module):
    def __init__(self, num_classes):
        super(ResNet18, self).__init__()
        self.resnet18 = models.resnet18(weights='ResNet18_Weights.DEFAULT')  # Pretrained weights (overwritten by the loaded state dict below)
        self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, num_classes)  # Custom classifier

    def forward(self, x):
        return self.resnet18(x)


# Instantiate the model
num_classes = 2  # Adjust based on your training setup
model = ResNet18(num_classes=num_classes)

# Select the device first so the checkpoint can be mapped onto it
# (map_location avoids errors when loading GPU-saved weights on a CPU-only machine)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the state dictionary and move the model to the device
state_dict_path = 'resnet_state_dict.pth'
model.load_state_dict(torch.load(state_dict_path, map_location=device))
model = model.to(device)
model.eval()  # Set to evaluation mode

# Define test transformations (same as training)
test_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load test dataset
test_image_folder = 'data/test'  # Path to the test image directory
test_dataset = datasets.ImageFolder(root=test_image_folder, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Evaluate the model on the test set
all_labels = []
all_preds = []
all_probs = []
class_names = test_dataset.classes  # Get class labels

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        probabilities = torch.softmax(outputs, dim=1)[:, 1]  # Probabilities for the positive class (index 1)
        _, predicted = torch.max(outputs, 1)
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())
        all_probs.extend(probabilities.cpu().numpy())

# Calculate F1 score, recall, and generate a classification report
# (f1_score and recall_score default to average='binary', scoring the class with label 1)
f1 = f1_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)
print(f"F1 Score: {f1:.2f}")
print(f"Recall: {recall:.2f}")
print("\nClassification Report:\n", classification_report(all_labels, all_preds, target_names=class_names))

# Calculate and plot the AUC-ROC curve
auc_score = roc_auc_score(all_labels, all_probs)
fpr, tpr, _ = roc_curve(all_labels, all_probs)
logging.info(f'AUC Score: {auc_score}')
logging.info(f'FPR: {fpr}')
logging.info(f'TPR: {tpr}')
logging.info(f'F1 Score: {f1}')
logging.info(f'Recall: {recall}')

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', label=f"AUC = {auc_score:.2f}")
plt.plot([0, 1], [0, 1], color='red', linestyle='--')  # Chance-level baseline
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc="lower right")
plt.grid()
plt.savefig('AUC.png', dpi=300)

# Print overall accuracy
accuracy = 100 * np.mean(np.array(all_labels) == np.array(all_preds))
print(f"Accuracy on test set: {accuracy:.2f}%")
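
# Notes on assumptions (sketches, not confirmed against training.py):
#
# 1) The checkpoint is assumed to be a plain state dict, e.g. saved in
#    training.py along these lines (file name carried over from this script):
#        torch.save(model.state_dict(), 'resnet_state_dict.pth')
#
# 2) ImageFolder assigns class indices alphabetically by folder name, and the
#    ROC/AUC code above treats index 1 as the positive class. Printing the
#    mapping makes that assumption easy to verify against your data/test layout:
print("Class-to-index mapping:", test_dataset.class_to_idx)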