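# NOTE(review): the text "Spaces: Sleeping / Sleeping" preceded this script; it appears
# to be hosting-page status text captured along with the file, not part of the program.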
import logging

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from sklearn.metrics import classification_report, f1_score, recall_score, roc_auc_score, roc_curve
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Define the model architecture (same as in training.py)
class ResNet18(nn.Module):
    """ResNet-18 backbone whose final ``fc`` layer is replaced by a custom classifier.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer.
        weights: Forwarded unchanged to ``torchvision.models.resnet18``. The
            default requests the pretrained weights, exactly as before. Pass
            ``None`` to skip the pretrained initialisation, e.g. when a saved
            state dict is loaded into the model right after construction.
    """

    def __init__(self, num_classes, weights='ResNet18_Weights.DEFAULT'):
        super().__init__()
        self.resnet18 = models.resnet18(weights=weights)
        # Replace the final fully connected layer with one sized for
        # ``num_classes`` while keeping its input width.
        self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, num_classes)

    def forward(self, x):
        """Return the wrapped network's output for the input batch ``x``."""
        return self.resnet18(x)
# Instantiate the model
num_classes = 2  # Adjust based on your training setup
model = ResNet18(num_classes=num_classes)

# Choose the device before loading so the checkpoint can be mapped onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the state dictionary. ``map_location`` remaps tensors that were saved
# from a CUDA device, so loading also works on a CPU-only machine.
state_dict_path = 'resnet_state_dict.pth'
model.load_state_dict(torch.load(state_dict_path, map_location=device))

# Move model to the selected device and switch to evaluation mode
model = model.to(device)
model.eval()
# Location of the test images
test_image_folder = 'data/test'

# Preprocessing applied to every test image (the original notes this matches
# training): resize to 256x256, convert to a tensor, then normalise each of the
# three channels with mean 0.5 and std 0.5.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Dataset over the image folder, served in fixed order in batches of 16
test_dataset = datasets.ImageFolder(test_image_folder, transform=test_transforms)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)
# Run the model over the whole test set and collect per-sample results
class_names = test_dataset.classes  # Class labels reported by the dataset
all_labels, all_preds, all_probs = [], [], []

with torch.no_grad():  # Inference only; no gradients are needed
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_images)
        # Column 1 of the softmax output is used as the positive-class score
        positive_probs = torch.softmax(logits, dim=1)[:, 1]
        # Index of the largest logit per sample is the predicted class
        predicted_classes = logits.max(dim=1).indices

        all_labels.extend(batch_labels.cpu().numpy())
        all_preds.extend(predicted_classes.cpu().numpy())
        all_probs.extend(positive_probs.cpu().numpy())
# Summary metrics over the collected labels and predictions
f1 = f1_score(all_labels, all_preds)
recall = recall_score(all_labels, all_preds)

print(f"F1 Score: {f1:.2f}")
print(f"Recall: {recall:.2f}")

# Per-class breakdown, labelled with the dataset's class names
report = classification_report(all_labels, all_preds, target_names=class_names)
print("\nClassification Report:\n", report)
# Calculate and plot AUC-ROC curve

# The root logger discards INFO records unless it is configured, so the
# messages below would otherwise never appear. basicConfig is a no-op when the
# root logger already has handlers, so an existing setup is left untouched.
logging.basicConfig(level=logging.INFO)

auc_score = roc_auc_score(all_labels, all_probs)
fpr, tpr, _ = roc_curve(all_labels, all_probs)

logging.info('AUC Score: %s', auc_score)
logging.info('FPR: %s', fpr)
logging.info('TPR: %s', tpr)
# Previously the F1 value was passed as a set to a message without a
# placeholder, so it could never be formatted into the log line.
logging.info('F1 Score: %s', f1)
logging.info('Recall: %s', recall)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', label=f"AUC = {auc_score:.2f}")
# Diagonal reference line from (0, 0) to (1, 1)
plt.plot([0, 1], [0, 1], color='red', linestyle='--')
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc="lower right")
plt.grid()
plt.savefig(r'AUC.png', dpi=300)
plt.close()  # Release the figure once it has been written to disk
# Overall accuracy: percentage of samples whose prediction equals the label
labels_array = np.array(all_labels)
preds_array = np.array(all_preds)
accuracy = 100 * np.mean(labels_array == preds_array)
print(f"Accuracy on test set: {accuracy:.2f}%")