'''
This module contains utility functions for plotting
'''

# Handling files
import os
# Random
import random

# Handling images and visualization
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
# Confusion matrix
from sklearn.metrics import confusion_matrix


def plot_samples(
    data_path,
    sample_classes=("Tyrannosaurus Rex", "Pteranodon", "Triceratops"),
    img_per_class=3,
    show=False,
    save_path=None
):
    '''
    Plot random samples of dinosaur species in a grid with one column per
    class and one row per sample.

    Args:
        data_path (str)                 : path to dataset; each class is read
                                          from the sub-directory
                                          `data_path/<class name>`
        sample_classes (sequence of str): classes (i.e. dinosaur species) we
                                          want to display. Default:
                                          "Tyrannosaurus Rex", "Pteranodon",
                                          "Triceratops"
        img_per_class (int)             : number of samples to show from each
                                          class, default value is 3. A class
                                          with fewer images shows all of them.
        show (bool)                     : decide to show the plot or not,
                                          default: False
        save_path (str)                 : path to save the plot if provided,
                                          default: None

    Returns:
        None
    '''
    # Set up
    plt.figure(figsize=(12, 6))
    num_sample_classes = len(sample_classes)

    # Hide grid and spines for the image subplots only
    with plt.rc_context(
        rc={
            "axes.grid": False,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "axes.spines.left": False,
            "axes.spines.bottom": False
        }
    ):
        for col, cls in enumerate(sample_classes):
            class_dir = os.path.join(data_path, cls)
            # Sorted so that a seeded `random` picks the same samples on every
            # filesystem (os.listdir order is arbitrary).
            imgs = sorted(os.listdir(class_dir))
            # random.sample raises ValueError when asked for more items than
            # exist, so cap at the number of available images.
            samples = random.sample(imgs, min(img_per_class, len(imgs)))

            # Plotting
            for row, img in enumerate(samples):
                # Row-major subplot index: sample `row` of class column `col`
                plt_idx = row * num_sample_classes + col + 1
                plt.subplot(img_per_class, num_sample_classes, plt_idx)
                # Close the file handle once imshow has copied the pixel data
                with Image.open(os.path.join(class_dir, img)) as image:
                    plt.imshow(image)
                plt.axis('off')
                # Label each column once, above its first image
                if row == 0:
                    plt.title(cls)

    plt.tight_layout()
    plt.suptitle("Sample Images", fontsize=16)
    # Leave room at the top for the suptitle
    plt.subplots_adjust(top=0.88)

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()


def plot_class_balance(data_path, show=False, save_path=None):
    '''
    Plot class balance from a given path as a bar chart of the percentage of
    images per class, sorted in descending order.

    Args:
        data_path (str): path to dataset; every sub-directory is treated as a
                         class and every entry inside it as one image
        show (bool)    : decide to show the plot or not, default: False
        save_path (str): path to save the plot if provided, default: None

    Returns:
        None
    '''
    # Set up
    plt.figure(figsize=(12, 6))

    # Only sub-directories are classes; a stray file (e.g. `.DS_Store`) would
    # otherwise make the os.listdir call below raise NotADirectoryError.
    classes = [
        cls for cls in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, cls))
    ]
    img_count = [
        len(os.listdir(os.path.join(data_path, cls))) for cls in classes
    ]
    # Computed once instead of once per class
    total = sum(img_count)
    # Guard against ZeroDivisionError when the dataset contains no images
    class_frequency = [
        (cnt / total) * 100 if total else 0.0 for cnt in img_count
    ]

    # Sort descending by frequency
    sorted_data = sorted(
        zip(classes, class_frequency), key=lambda x: x[1], reverse=True
    )
    # zip(*[]) cannot be unpacked into two names, so handle "no classes"
    sorted_classes, sorted_frequency = (
        zip(*sorted_data) if sorted_data else ((), ())
    )

    # Plotting
    plt.bar(x=sorted_classes, height=sorted_frequency)
    plt.title("Class Frequency", fontsize=16)
    plt.xlabel("Class")
    plt.ylabel("Frequency (%)")
    plt.xticks(rotation=45, ha="right")

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()


def plot_confusion_matrix(
    y_true,
    y_pred,
    display_labels,
    top_k=10,
    figsize=(18, 24),
    normalize="true",
    show=False,
    save_path=None
):
    '''
    Plot confusion matrix and a table of top_k misclassified pairs

    Args:
        y_true (lst)        : true labels
        y_pred (lst)        : predictions from model
        display_labels (lst): labels to display; must be ordered like the rows
                              of the confusion matrix (sklearn orders them by
                              the sorted unique values of y_true and y_pred)
        top_k (int)         : number of off-diagonal (true, predicted) pairs
                              with the highest confusion to list in the table
                              below the matrix, default: 10
        figsize (tuple)     : figure size, default: (18, 24)
        normalize (str)     : option to normalize confusion matrix (same in
                              sklearn.metrics.confusion_matrix), but only
                              accepts 2 value: "true" (normalize by row) and
                              None (no normalization), default: "true"
        show (bool)         : decide to show the plot or not, default: False
        save_path (str)     : path to save the plot if provided,
                              default: None

    Returns:
        None
    '''
    # Full confusion matrix
    cm = confusion_matrix(y_true, y_pred, normalize=normalize)

    # Every off-diagonal (i != j) non-zero cell: true class i predicted as j
    confusions = [
        (i, j, value)
        for i, row in enumerate(cm)
        for j, value in enumerate(row)
        if i != j and value > 0
    ]
    # Sorting and find top-k confused pairs
    top_confusions = sorted(
        confusions, key=lambda x: x[2], reverse=True
    )[:top_k]

    # Set up plots: large heatmap on top, smaller table below
    _, axes = plt.subplots(
        nrows=2, figsize=figsize, gridspec_kw={"height_ratios": [4, 1]}
    )

    # Plot confusion matrix
    sns.heatmap(
        cm,
        cmap="Blues",
        linewidths=0.5,
        linecolor="gray",
        xticklabels=display_labels,
        yticklabels=display_labels,
        cbar_kws={"label": "Proportion" if normalize else "Count"},
        ax=axes[0]
    )
    axes[0].set_title(
        "Confusion Matrix", fontsize=16, fontweight="bold", pad=20
    )
    axes[0].set_xlabel("Predicted Label", fontsize=14)
    axes[0].set_ylabel("True Label", fontsize=14)
    axes[0].tick_params(axis="x", labelsize=14)
    axes[0].tick_params(axis="y", labelsize=14)
    plt.setp(axes[0].get_xticklabels(), rotation=90, ha="right")

    # Table of top_k misclassified pairs
    columns = ["Ground Truth", "Predicted", "Proportion" if normalize else "Count"]
    data = [
        [display_labels[i], display_labels[j], f"{v:.2f}" if normalize else int(v)]
        for i, j, v in top_confusions
    ]

    axes[1].axis("off")
    if data:
        table = axes[1].table(
            cellText=data,
            colLabels=columns,
            loc="center",
            cellLoc="center",
            colColours=["#d3d3d3"] * len(columns),
            bbox=[0, 0, 1, 1]
        )
        table.auto_set_font_size(False)
        table.set_fontsize(14)
        table.scale(1, 2)
    else:
        # Axes.table raises IndexError on an empty cellText (it reads
        # cellText[0]), which happens for perfect predictions or top_k=0.
        axes[1].text(
            0.5, 0.5, "No misclassified pairs",
            ha="center", va="center", fontsize=14,
            transform=axes[1].transAxes
        )
    axes[1].set_title(
        f"Top-{top_k} Misclassified Pairs",
        fontsize=14,
        fontweight="bold",
        pad=10
    )

    plt.tight_layout(h_pad=5)

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()


def plot_training_progress(
    avg_training_losses,
    avg_val_losses,
    accuracy_scores,
    f1_scores,
    lr_changes,
    show=False,
    save_path=None
):
    '''
    Plot training process over epochs, specifically, 3 subplots are created:
        - One plot for average train and validation loss
        - One plot for accuracy and weighted F1 score on validation data
        - One plot for learning rates

    All input lists are expected to have one entry per epoch, in order;
    epochs are numbered from 1 on the x-axis.

    Args:
        avg_training_losses (lst): average training loss
        avg_val_losses (lst)     : average validation loss
        accuracy_scores (lst)    : accuracy on validation data
        f1_scores (lst)          : weighted F1 score on validation data
        lr_changes (lst)         : learning rates
        show (bool)              : decide to show the plot or not,
                                   default: False
        save_path (str)          : path to save the plot if provided,
                                   default: None

    Returns:
        None
    '''
    _, axes = plt.subplots(nrows=3, ncols=1, figsize=(12, 15))
    # 1-based epoch numbers for the x-axis
    n_epochs = list(range(1, len(avg_training_losses) + 1))

    # Avg Training vs Validation loss
    axes[0].plot(n_epochs, avg_training_losses, label="Train", color="blue")
    axes[0].plot(n_epochs, avg_val_losses, label="Validation", color="red")
    axes[0].set(
        xlabel="Epoch",
        ylabel="Average Loss",
        title="Average Training vs Validation Loss"
    )
    axes[0].legend(loc="upper right")
    axes[0].grid(True)

    # Accuracy vs Weighted F1 score on validation data
    axes[1].plot(n_epochs, accuracy_scores, label="Accuracy", color="blue")
    axes[1].plot(n_epochs, f1_scores, label="Weighted F1 Score", color="red")
    axes[1].set(
        xlabel="Epoch",
        ylabel="",
        title="Validation Accuracy vs Weighted F1 Score"
    )
    axes[1].legend(loc="lower right")
    axes[1].grid(True)

    # Learning rate
    axes[2].plot(n_epochs, lr_changes)
    axes[2].set(
        xlabel="Epoch",
        ylabel="Learning Rate",
        title="Learning Rate Changes"
    )
    axes[2].grid(True)

    plt.suptitle("Training Process", fontsize=16)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()