|
import torch |
|
import numpy as np |
|
import pathlib |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
from matplotlib import rcParams |
|
from sklearn.metrics import ( |
|
classification_report, |
|
precision_recall_curve, |
|
accuracy_score, |
|
f1_score, |
|
confusion_matrix, |
|
matthews_corrcoef, |
|
ConfusionMatrixDisplay, |
|
roc_curve, |
|
auc, |
|
average_precision_score, |
|
cohen_kappa_score, |
|
|
|
) |
|
from sklearn.preprocessing import label_binarize |
|
from configs import * |
|
|
|
|
|
# Render every figure produced by this script in a serif font.
rcParams.update({"font.family": "Times New Roman"})

# Place the configured network on the target device, restore the trained
# weights from MODEL_SAVE_PATH (mapped onto that same device), and put the
# model into evaluation mode before any inference happens.
# NOTE(review): torch.load unpickles the file; if the installed torch
# supports it, consider weights_only=True since only a state dict is loaded.
model = MODEL.to(DEVICE)
model.load_state_dict(
    torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
)
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _collect_predictions(images, model, transform):
    """Run ``model`` on each image file and return labels and class probabilities.

    Returns a tuple ``(true_classes, predicted_labels, scores)`` where the
    first two are lists of class indices and ``scores`` is a 2-D array with
    one row of softmax probabilities per image.
    """
    true_classes = []
    predicted_labels = []
    score_rows = []

    with torch.no_grad():
        for image_file in images:
            print("---------------------------")

            # The ground-truth class is the name of the image's parent directory.
            true_class = CLASSES.index(image_file.parts[-2])
            print("Image path:", image_file)
            print("True class:", true_class)

            # Use a context manager so the file handle is released; convert()
            # returns a new image that stays valid after the file is closed.
            with Image.open(image_file) as img:
                image = img.convert("RGB")

            # Add a batch dimension of 1 before feeding the model.
            batch = transform(image).unsqueeze(0).to(DEVICE)
            output = model(batch)
            predicted_class = torch.argmax(output, dim=1).item()

            print("Predicted class:", predicted_class)

            true_classes.append(true_class)
            predicted_labels.append(predicted_class)
            # Keep only the single row of probabilities for this one-image batch.
            score_rows.append(output.softmax(dim=1)[0].cpu().numpy())

    return true_classes, predicted_labels, np.stack(score_rows)


def _plot_confusion_matrix(true_classes, predicted_labels, out_dir):
    """Draw, save (``confusion_matrix.png`` in ``out_dir``) and show the confusion matrix."""
    # Passing every class index keeps the matrix aligned with CLASSES even when
    # some class is absent from both the ground truth and the predictions.
    conf_matrix = confusion_matrix(
        true_classes,
        predicted_labels,
        labels=list(range(len(CLASSES))),
    )

    ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=CLASSES).plot(
        cmap=plt.cm.Blues, xticks_rotation=25
    )

    plt.subplots_adjust(
        top=0.935,
        bottom=0.155,
        left=0.125,
        right=0.905,
        hspace=0.2,
        wspace=0.2,
    )
    plt.title("Confusion Matrix")
    plt.get_current_fig_manager().full_screen_toggle()
    plt.savefig(out_dir / "confusion_matrix.png")
    plt.show()


def _plot_curve(x, y, title, xlabel, ylabel, annotation, out_file):
    """Plot ``y`` against ``x`` with an annotation box, then save to ``out_file`` and show."""
    plt.figure(figsize=(10, 6))
    plt.plot(x, y)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    plt.text(
        0.6,
        0.2,
        annotation,
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
    plt.savefig(out_file)
    plt.show()


def predict_image(image_path, model, transform):
    """Evaluate ``model`` on every ``*.png`` under ``image_path`` and report metrics.

    Each image's ground-truth label is the index in ``CLASSES`` of its
    immediate parent directory name. After inference the function:

    * prints accuracy and weighted F1;
    * saves and shows a confusion matrix;
    * prints a classification report;
    * prints micro-averaged AUC-ROC and average precision (AUC PRC);
    * saves and shows the corresponding PR and ROC curves;
    * prints Matthews correlation and Cohen's kappa.

    All figures are written into ``docs/evaluation/``, which is created if
    it does not exist.

    Args:
        image_path: Root directory searched recursively for ``*.png`` files.
        model: Classifier whose output is indexed per class along dim 1
            (assumed ``(batch, num_classes)`` logits — TODO confirm).
        transform: Callable mapping a PIL RGB image to an input tensor
            (presumably ``(C, H, W)``, since a batch dim is added here).

    Raises:
        ValueError: If no ``*.png`` files exist under ``image_path``, or if
            an image's parent directory name is not in ``CLASSES``.
    """
    model.eval()

    # Sorting makes the processing order and the console log reproducible.
    images = sorted(pathlib.Path(image_path).rglob("*.png"))
    if not images:
        raise ValueError(f"No .png images found under {image_path!r}")

    true_classes, predicted_labels, scores = _collect_predictions(
        images, model, transform
    )

    accuracy = accuracy_score(true_classes, predicted_labels)
    print("Accuracy:", accuracy)
    f1 = f1_score(true_classes, predicted_labels, average="weighted")
    print("Weighted F1 Score:", f1)

    # savefig does not create missing directories.
    out_dir = pathlib.Path("docs/evaluation")
    out_dir.mkdir(parents=True, exist_ok=True)

    _plot_confusion_matrix(true_classes, predicted_labels, out_dir)

    # Explicit labels keep target_names aligned when a class is missing.
    report = classification_report(
        true_classes,
        predicted_labels,
        labels=list(range(len(CLASSES))),
        target_names=CLASSES,
    )
    print("Classification Report:\n", report)

    # One-hot ground truth with one column per class. label_binarize returns a
    # single column when there are exactly two classes, which would not line
    # up with the two softmax columns, so build the one-hot matrix directly.
    y_true_flat = np.eye(NUM_CLASSES, dtype=int)[true_classes].ravel()
    y_score_flat = scores.ravel()

    # Micro-averaged curves: every (image, class) pair is one binary decision.
    fpr, tpr, _ = roc_curve(y_true_flat, y_score_flat)
    auc_roc = auc(fpr, tpr)
    print("AUC-ROC:", auc_roc)

    precision, recall, _ = precision_recall_curve(y_true_flat, y_score_flat)
    auc_prc = average_precision_score(y_true_flat, y_score_flat)
    print("AUC PRC:", auc_prc)

    _plot_curve(
        recall,
        precision,
        "Precision-Recall Curve",
        "Recall",
        "Precision",
        "AUC-PRC = {:.3f}".format(auc_prc),
        out_dir / "prc.png",
    )
    _plot_curve(
        fpr,
        tpr,
        "ROC Curve",
        "False Positive Rate",
        "True Positive Rate",
        "AUC-ROC = {:.3f}".format(auc_roc),
        out_dir / "roc.png",
    )

    print("Matthew's correlation coefficient:", matthews_corrcoef(true_classes, predicted_labels))
    print("Cohen's kappa:", cohen_kappa_score(true_classes, predicted_labels))
|
|
|
|
|
# Evaluate the restored model on the Task 1 test split using the
# preprocessing pipeline provided by the configuration module.
predict_image(
    image_path="data/test/Task 1/",
    model=model,
    transform=preprocess,
)
|
|
|
|
|
|
|
|
|
|
|
|