#!/usr/bin/env python3
"""
Utility Script containing functions to be used for training
Author: Shilpaj Bhalerao
"""
# Standard Library Imports
import io
import math
from typing import NoReturn

# Third-Party Imports
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchsummary import summary
from torchvision import transforms


def get_summary(model, input_size: tuple) -> None:
    """
    Function to get the summary of the model architecture
    :param model: Object of model architecture class
    :param input_size: Input data shape (Channels, Height, Width)
    """
    # Use the GPU when one is available, otherwise fall back to the CPU
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    network = model.to(device)
    summary(network, input_size=input_size)


def get_misclassified_data(model, device, test_loader):
    """
    Function to run the model on test set and return misclassified images
    :param model: Network Architecture
    :param device: CPU/GPU
    :param test_loader: DataLoader for test set
    :return: List of (image, label, prediction) tuples, one per misclassified
             sample. `image` keeps a leading batch dimension of size 1 and
             `prediction` has shape (1, 1).
    """
    # Prepare the model for evaluation i.e. drop the dropout layer
    model.eval()

    # List to store misclassified Images
    misclassified_data = []

    # Gradients are not needed for inference, so disable gradient tracking
    with torch.no_grad():

        # Extract images, labels in a batch
        for data, target in test_loader:

            # Migrate the data to the device
            data, target = data.to(device), target.to(device)

            # One forward pass per batch instead of one per image
            output = model(data)

            # Convert the per-class scores to a predicted class index
            preds = output.argmax(dim=1, keepdim=True)

            # Indices of the samples whose prediction differs from the label
            mismatch = preds.view(-1) != target.view(-1)
            wrong_indices = torch.nonzero(mismatch).flatten().tolist()

            # Slicing with idx:idx + 1 keeps the size-1 batch dimension
            for idx in wrong_indices:
                misclassified_data.append(
                    (data[idx:idx + 1], target[idx], preds[idx:idx + 1])
                )
    return misclassified_data


# -------------------- DATA STATISTICS --------------------
def get_mnist_statistics(data_set, data_set_type='Train'):
    """
    Function to return the statistics of the training data
    :param data_set: Training dataset
    :param data_set_type: Type of dataset [Train/Test/Val]
    """
    # torchvision's MNIST exposes the raw tensor as `data`; `train_data` is the
    # older, deprecated alias and is kept as a fallback for older datasets
    raw_data = data_set.data if hasattr(data_set, 'data') else data_set.train_data

    # The dataset transform is applied to the Numpy version of the raw data
    train_data = data_set.transform(raw_data.numpy())

    print(f'[{data_set_type}]')
    print(' - Numpy Shape:', raw_data.cpu().numpy().shape)
    print(' - Tensor Shape:', raw_data.size())
    print(' - min:', torch.min(train_data))
    print(' - max:', torch.max(train_data))
    print(' - mean:', torch.mean(train_data))
    print(' - std:', torch.std(train_data))
    print(' - var:', torch.var(train_data))

    dataiter = next(iter(data_set))
    images, labels = dataiter[0], dataiter[1]

    print(images.shape)
    print(labels)

    # Let's visualize some of the images
    plt.imshow(images[0].numpy().squeeze(), cmap='gray')


def get_cifar_property(images, operation):
    """
    Get the property on each channel of the CIFAR
    :param images: Images indexed as (N, Channel, Height, Width)
    :param operation: Name of a no-argument method of the image array,
                      e.g. 'mean', 'std', 'var', 'min', 'max'
    :return: Tuple of the property value for channels 0, 1 and 2
    """
    # Look the method up by name instead of building code strings for eval()
    param_r = getattr(images[:, 0, :, :], operation)()
    param_g = getattr(images[:, 1, :, :], operation)()
    param_b = getattr(images[:, 2, :, :], operation)()
    return param_r, param_g, param_b


def get_cifar_statistics(data_set, data_set_type='Train'):
    """
    Function to get the statistical information of the CIFAR dataset
    :param data_set: Training set of CIFAR
    :param data_set_type: Training or Test data
    """
    # Images in the dataset
    images = [item[0] for item in data_set]
    images = torch.stack(images, dim=0).numpy()

    # Calculate mean over each channel
    mean_r, mean_g, mean_b = get_cifar_property(images, 'mean')

    # Calculate Standard deviation over each channel
    std_r, std_g, std_b = get_cifar_property(images, 'std')

    # Calculate min value over each channel
    min_r, min_g, min_b = get_cifar_property(images, 'min')

    # Calculate max value over each channel
    max_r, max_g, max_b = get_cifar_property(images, 'max')

    # Calculate variance value over each channel
    var_r, var_g, var_b = get_cifar_property(images, 'var')

    print(f'[{data_set_type}]')
    print(f' - Total {data_set_type} Images: {len(data_set)}')
    print(f' - Tensor Shape: {images[0].shape}')
    print(f' - min: {min_r, min_g, min_b}')
    print(f' - max: {max_r, max_g, max_b}')
    print(f' - mean: {mean_r, mean_g, mean_b}')
    print(f' - std: {std_r, std_g, std_b}')
    print(f' - var: {var_r, var_g, var_b}')

    # Let's visualize some of the images
    plt.imshow(np.transpose(images[1].squeeze(), (1, 2, 0)))


# -------------------- GradCam --------------------
def display_gradcam_output(data: list,
                           classes,
                           inv_normalize: transforms.Normalize,
                           model,
                           target_layers,
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60):
    """
    Function to visualize GradCam output on the data
    :param data: List[Tuple(image, label, prediction)]; each image carries a
                 leading batch dimension of size 1
    :param classes: Name of classes in the dataset
    :param inv_normalize: Transform which undoes the dataset normalization
    :param model: Model architecture
    :param target_layers: Layers on which GradCam should be executed
    :param targets: Classes to be focused on for GradCam
    :param number_of_samples: Number of images to print
    :param transparency: Weight of Normal image when mixed with activations
    :return: PIL image of the rendered figure
    """
    # Never index past the end of `data` when fewer samples are available
    number_of_samples = min(number_of_samples, len(data))

    # Plot configuration
    fig = plt.figure(figsize=(10, 10))
    x_count = 5
    y_count = math.ceil(number_of_samples / x_count)

    # Create an object for GradCam
    cam = GradCAM(model=model, target_layers=target_layers)

    # Iterate over number of specified images
    for i in range(number_of_samples):
        axis = fig.add_subplot(y_count, x_count, i + 1)
        input_tensor, label, pred = data[i][0], data[i][1], data[i][2]

        # Get the activations of the layer for the images
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]

        # Get back the original image
        img = input_tensor.squeeze(0).to('cpu')
        img = inv_normalize(img)

        # Channels-first -> channels-last for plotting
        rgb_img = img.permute(1, 2, 0).detach().numpy().astype(np.float32)

        # Ensure the image data is within the [0, 1] range
        rgb_img = np.clip(rgb_img, 0, 1)

        # Mix the activations on the original image
        visualization = show_cam_on_image(rgb_img, grayscale_cam,
                                          use_rgb=True,
                                          image_weight=transparency)

        # Display the images on the plot
        axis.imshow(visualization)
        axis.set_title(r"Correct: " + classes[label.item()] + '\n' +
                       'Output: ' + classes[pred.item()])
        axis.set_xticks([])
        axis.set_yticks([])

    fig.tight_layout()

    # Save the entire figure to a BytesIO object; the buffer is left open
    # because PIL reads the image data from it lazily
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    img_var = Image.open(buf)

    return img_var