Spaces:
Runtime error
Runtime error
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
def convert_back_image(
    image,
    mean=(0.4914, 0.4822, 0.4471),
    std=(0.2469, 0.2433, 0.2615),
):
    """Undo per-channel normalisation and return a channels-last image.

    Args:
        image: Channels-first ``(C, H, W)`` image. May be a tensor-like
            object exposing ``detach()``/``cpu()``/``numpy()`` or anything
            ``np.asarray`` accepts.
        mean: Per-channel means that were subtracted during normalisation.
            Defaults to the constants this module has always used.
        std: Per-channel standard deviations that the image was divided by.
            Defaults to the constants this module has always used.

    Returns:
        A new ``float32`` array of shape ``(H, W, C)`` with values clipped
        to ``[0, 1]``. The input is never modified.

    Raises:
        ValueError: If ``mean`` or ``std`` has fewer entries than the image
            has channels.
    """
    if hasattr(image, "detach"):
        # Tensor path: drop autograd tracking and copy to host memory first,
        # otherwise .numpy() raises for grad-requiring or non-CPU tensors.
        image = image.detach().cpu().numpy()

    # astype() always copies, so the caller's buffer is left untouched.
    pixels = np.asarray(image).astype(np.float32)

    channels = pixels.shape[0]
    if len(mean) < channels or len(std) < channels:
        raise ValueError(
            f"image has {channels} channels but only {len(mean)} mean and "
            f"{len(std)} std values were provided"
        )

    # Use only the first `channels` statistics (a single-channel image keeps
    # using entry 0) and shape them to broadcast over H and W.
    shift = np.asarray(mean[:channels], dtype=np.float32).reshape(-1, 1, 1)
    scale = np.asarray(std[:channels], dtype=np.float32).reshape(-1, 1, 1)
    pixels = pixels * scale + shift

    # Clip so that imshow does not warn about float pixels outside [0, 1].
    pixels = pixels.clip(0, 1)
    return np.transpose(pixels, (1, 2, 0))
def plot_sample_training_images(batch_data, batch_label, class_label, num_images=30):
    """Plot a grid of sample training images titled with their class names.

    Args:
        batch_data: Indexable batch of images supporting ``len()``; each item
            is passed to ``convert_back_image`` before display.
        batch_label: Indexable batch of labels; each item must expose
            ``.item()`` returning a key/index into ``class_label``.
        class_label: Lookup from label value to the title shown above the
            image.
        num_images: Upper bound on the number of images drawn; capped at the
            batch size.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``; ``axs`` keeps
        the shape ``plt.subplots`` gave it.
    """
    # Never plot more images than the batch holds (or fewer than zero).
    num_images = max(0, min(num_images, len(batch_data)))

    num_cols = 5
    # Keep at least one row so an empty batch yields a blank figure instead
    # of matplotlib rejecting a zero-row grid.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10))

    # plt.subplots squeezes a single row to 1-D; flatten a view for uniform
    # indexing while returning `axs` exactly as matplotlib produced it.
    flat_axes = np.atleast_1d(axs).ravel()

    for position, ax in enumerate(flat_axes):
        # Hide frame and ticks on every cell, including the unused ones at
        # the end of the grid.
        ax.axis("off")
        if position >= num_images:
            continue
        ax.imshow(convert_back_image(batch_data[position]))
        ax.set_title(class_label[batch_label[position].item()])

    # Lay out once, after every title exists, so all titles are accounted for.
    fig.tight_layout()
    return fig, axs
def plot_train_test_metrics(results):
    """Draw train/test loss and accuracy curves side by side.

    Args:
        results: Mapping providing the series under the keys
            ``"train_loss"``, ``"train_acc"``, ``"test_loss"`` and
            ``"test_acc"``.

    Returns:
        The ``(fig, axs)`` pair from ``plt.subplots(1, 2)``; the first axes
        holds the loss curves and the second the accuracy curves.
    """
    # Read every series up front so a missing key fails before any figure
    # is created.
    loss_train, acc_train, loss_test, acc_test = (
        results[key]
        for key in ("train_loss", "train_acc", "test_loss", "test_acc")
    )

    # One row, two panels: loss on the left, accuracy on the right.
    fig, axs = plt.subplots(1, 2, figsize=(16, 8))
    panels = (
        ("Loss", loss_train, loss_test),
        ("Accuracy", acc_train, acc_test),
    )
    for ax, (title, train_series, test_series) in zip(axs, panels):
        ax.plot(train_series, label="Train")
        ax.plot(test_series, label="Test")
        ax.set_title(title)
        ax.legend(loc="upper right")

    return fig, axs
def plot_misclassified_images(data, class_label, num_images=10):
    """Plot misclassified test images with their actual and predicted labels.

    Args:
        data: Mapping with the parallel, indexable entries ``"images"``,
            ``"ground_truths"`` and ``"predicted_vals"``. Each item must
            expose ``.cpu()``; label items must additionally expose
            ``.item()``.
        class_label: Lookup from label value to the name shown in the title.
        num_images: Upper bound on the number of images drawn; capped at the
            number of ground-truth entries.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``; ``axs`` keeps
        the shape ``plt.subplots`` gave it.
    """
    # Never plot more images than are available (or fewer than zero).
    num_images = max(0, min(num_images, len(data["ground_truths"])))

    num_cols = 5
    # Keep at least one row so "nothing was misclassified" yields a blank
    # figure instead of matplotlib rejecting a zero-row grid.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))

    # plt.subplots squeezes a single row to 1-D; flatten a view for uniform
    # indexing while returning `axs` exactly as matplotlib produced it.
    flat_axes = np.atleast_1d(axs).ravel()

    for position, ax in enumerate(flat_axes):
        # Hide frame and ticks on every cell, including the unused ones at
        # the end of the grid.
        ax.axis("off")
        if position >= num_images:
            continue

        # Ground-truth and predicted labels as plain Python values.
        label = data["ground_truths"][position].cpu().item()
        pred = data["predicted_vals"][position].cpu().item()
        image = data["images"][position].cpu()

        ax.imshow(convert_back_image(image))
        ax.set_title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")

    # Lay out once, after every title exists, so all titles are accounted for.
    fig.tight_layout()
    return fig, axs
# Function to plot gradcam for misclassified images using pytorch_grad_cam
def plot_gradcam_images(
    model,
    data,
    class_label,
    target_layers,
    targets=None,
    num_images=10,
    image_weight=0.25,
):
    """Show Grad-CAM overlays for misclassified images.

    Args:
        model: Model handed to ``GradCAM``.
        data: Mapping with the parallel, indexable entries ``"images"``,
            ``"ground_truths"`` and ``"predicted_vals"``. Each item must
            expose ``.cpu()``; label items must additionally expose
            ``.item()``.
        class_label: Lookup from label value to the name shown in the title.
        target_layers: Layers handed to ``GradCAM`` as ``target_layers``.
        targets: Forwarded unchanged to the CAM call for every image.
        num_images: Upper bound on the number of images drawn; capped at the
            number of ground-truth entries.
        image_weight: Forwarded to ``show_cam_on_image`` as the blend weight
            of the original image in the overlay.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``; ``axs`` keeps
        the shape ``plt.subplots`` gave it.
    """
    # Never plot more images than are available (or fewer than zero).
    num_images = max(0, min(num_images, len(data["ground_truths"])))

    num_cols = 5
    # Keep at least one row so an empty input yields a blank figure instead
    # of matplotlib rejecting a zero-row grid.
    num_rows = max(1, int(np.ceil(num_images / num_cols)))

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2))

    # plt.subplots squeezes a single row to 1-D; flatten a view for uniform
    # indexing while returning `axs` exactly as matplotlib produced it.
    flat_axes = np.atleast_1d(axs).ravel()

    # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/grad_cam.py
    # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/base_cam.py
    # `use_cuda` is deliberately not passed: recent pytorch_grad_cam releases
    # no longer accept the keyword (passing it raises TypeError), and on older
    # releases omitting it is equivalent to the former `use_cuda=False`.
    # NOTE(review): newer releases appear to take the device from the model,
    # so keep the model on CPU for the gradio app - confirm against the
    # installed version.
    cam = GradCAM(model=model, target_layers=target_layers)

    for position, ax in enumerate(flat_axes):
        # Hide frame and ticks on every cell, including the unused ones at
        # the end of the grid.
        ax.axis("off")
        if position >= num_images:
            continue

        # Ground-truth and predicted labels as plain Python values.
        label = data["ground_truths"][position].cpu().item()
        pred = data["predicted_vals"][position].cpu().item()
        image = data["images"][position].cpu()

        # https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/model_targets.py
        # GradCAM works on batches, so wrap the single image in a batch of
        # one and take the first (only) heat map back out.
        grad_cam_output = cam(
            input_tensor=image.unsqueeze(0),
            targets=targets,
            aug_smooth=True,
            eigen_smooth=True,
        )
        grad_cam_output = grad_cam_output[0, :]

        # Blend the heat map over the de-normalised image.
        overlayed_image = show_cam_on_image(
            convert_back_image(image),
            grad_cam_output,
            use_rgb=True,
            image_weight=image_weight,
        )

        ax.imshow(overlayed_image)
        ax.set_title(f"""ACT: {class_label[label]} \nPRED: {class_label[pred]}""")

    # Lay out once, after every title exists, so all titles are accounted for.
    fig.tight_layout()
    return fig, axs