File size: 2,741 Bytes
c5bd7aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import torch
import numpy as np
import matplotlib.pyplot as plt
from mlxtend.evaluate import confusion_matrix
from mlxtend.plotting import plot_confusion_matrix
import model_builder
def _collect_predictions(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and gather labels.

    Returns:
        tuple[np.ndarray, np.ndarray]: flat arrays of (true labels, predicted
        labels), in dataloader order. Both are empty when the dataloader
        yields no batches.
    """
    true_batches = []
    pred_batches = []
    with torch.no_grad():
        for X, y in dataloader:
            logits = model(X.to(device))
            # argmax over dim=1 picks the top-scoring class per sample;
            # assumes logits are (batch, num_classes) — TODO confirm against model_builder.
            pred_batches.append(torch.argmax(logits, dim=1).cpu().numpy().ravel())
            true_batches.append(y.cpu().numpy().ravel())
    if not true_batches:
        empty = np.array([], dtype=np.int64)
        return empty, empty.copy()
    return np.concatenate(true_batches), np.concatenate(pred_batches)


def _confusion_counts(y_true, y_pred, num_classes):
    """Return a (num_classes, num_classes) count matrix.

    Entry [i, j] counts samples whose true label is i and predicted label is j
    (rows = truth, columns = prediction). Unlike a matrix built only from the
    labels that happen to occur, the shape always matches ``num_classes``, so
    it lines up with ``class_names`` even when some class is absent.

    Raises:
        ValueError: if any label falls outside ``[0, num_classes)``.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    for name, labels in (("y_true", y_true), ("y_pred", y_pred)):
        if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
            raise ValueError(
                f"{name} contains labels outside [0, {num_classes}); "
                "check that class_names matches the dataset"
            )
    # Encode each (true, pred) pair as one index so a single bincount tallies them all.
    flat_index = y_true * num_classes + y_pred
    counts = np.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def plot_confusion_Matrix(model_path, dataloader, class_names, device, figsize=(12, 12), hidden_units=15):
    """
    Load a saved TrashClassificationCNNModel, evaluate it on a DataLoader and
    plot the resulting confusion matrix with mlxtend.

    Args:
        model_path: Path to the saved state_dict (e.g. ".pth" or ".pt").
        dataloader: DataLoader yielding (inputs, labels) batches.
        class_names (list): Class names; its length sets the model's output
            size and the confusion-matrix dimensions.
        device: Target device to compute on (e.g. "cuda" or "cpu").
        figsize (tuple): Figure size passed to mlxtend's plot_confusion_matrix.
        hidden_units (int): Hidden units used to rebuild the model; must match
            the architecture the checkpoint was trained with (default 15).

    Returns:
        None

    Raises:
        ValueError: if the dataloader yields labels outside
            [0, len(class_names)).
    """
    num_classes = len(class_names)

    model = model_builder.TrashClassificationCNNModel(input_shape=3,
                                                      hidden_units=hidden_units,
                                                      output_shape=num_classes)
    # map_location lets a checkpoint saved on one device (e.g. GPU) load on another (e.g. CPU-only host).
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Inputs are moved to `device` during inference, so the weights must live there too.
    model.to(device)
    model.eval()

    y_true, y_pred = _collect_predictions(model, dataloader, device)
    confmat = _confusion_counts(y_true, y_pred, num_classes)

    fig, ax = plot_confusion_matrix(conf_mat=confmat,
                                    class_names=class_names,
                                    figsize=figsize)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    plt.show()
def plot_metrics(metrics):
    """
    Draw side-by-side learning curves: loss on the left, accuracy on the right.

    Args:
        metrics (dict): Per-epoch histories stored under the keys 'train_loss',
            'test_loss', 'train_acc' and 'test_acc'. The x-axis runs from 1 to
            the length of the 'train_loss' history.

    Returns:
        None
    """
    epoch_axis = range(1, len(metrics['train_loss']) + 1)

    # One row per subplot:
    # (train key, test key, train label, test label, subplot title, y-axis label)
    panels = [
        ('train_loss', 'test_loss', 'Training loss', 'Testing loss',
         'Training and Testing Loss', 'Loss'),
        ('train_acc', 'test_acc', 'Training accuracy', 'Testing accuracy',
         'Training and Testing Accuracy', 'Accuracy'),
    ]

    plt.figure(figsize=(12, 5))
    for slot, (train_key, test_key, train_label, test_label, heading, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        # Training curve in blue, testing curve in red.
        plt.plot(epoch_axis, metrics[train_key], 'b', label=train_label)
        plt.plot(epoch_axis, metrics[test_key], 'r', label=test_label)
        plt.title(heading)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()
|