#!/usr/bin/env python3
"""
Gradio Application for model trained on CIFAR10 dataset
Author: Shilpaj Bhalerao
Date: Aug 06, 2023
"""
# Standard Library Imports
import os
from collections import OrderedDict

# Third-Party Imports
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image

# Local Imports
from resnet import LITResNet
from visualize import FeatureMapVisualizer


# Directory Path
example_directory = 'examples/'
model_path = 'epoch=23-step=2112.ckpt'
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Attribute names of the model's six blocks, in order. The GradCAM slider
# value (1-6) selects one of these; the first sub-module of the chosen block
# is used as the GradCAM target layer.
_LAYER_NAMES = ('prep_layer', 'custom_block1', 'resnet_block1',
                'custom_block2', 'custom_block3', 'resnet_block3')

# Lower bound applied to the per-channel standard deviation. A flat
# (single-colour) channel has std == 0, which torchvision's Normalize rejects
# with a ValueError (division by zero).
_MIN_STD = 1e-6

model = LITResNet.load_from_checkpoint(model_path, map_location=torch.device('cpu'), strict=False,
                                       class_names=classes)
model.eval()

# Create an object of the Class
viz = FeatureMapVisualizer(model)


def _preprocess(input_img):
    """
    Resize an image to 32x32 RGB and normalize it with its own channel statistics
    :param input_img: Image as a numpy array (as delivered by ``gr.Image``)
    :return: Tuple of (resized uint8 32x32x3 numpy array, normalized 3x32x32 tensor)
    :raises gr.Error: If no image was provided
    """
    if input_img is None:
        raise gr.Error("Please provide an input image.")

    # Force 3 channels so grayscale / RGBA inputs work with the 3-value mean/std below
    resized = np.array(Image.fromarray(input_img).convert('RGB').resize((32, 32)))

    # Per-channel mean and standard deviation of this image, scaled to [0, 1].
    # NOTE(review): this normalizes with per-image statistics rather than the
    # training-set statistics; confirm this matches how the model was trained.
    scaled = resized / 255.
    mean = scaled.mean(axis=(0, 1))
    std = np.maximum(scaled.std(axis=(0, 1)), _MIN_STD)

    # Convert img to tensor and normalize it
    _transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(tuple(mean.tolist()), tuple(std.tolist()))
    ])
    return resized, _transform(resized)


def inference(input_img, transparency=0.5, number_of_top_classes=3, target_layer_number=4):
    """
    Function to run inference on the input image
    :param input_img: Image provided by the user
    :param transparency: Percentage of cam overlap over the input image
    :param number_of_top_classes: Number of top predictions for the input image
    :param target_layer_number: Layer (1-6) for which GradCam to be shown
    :return: Tuple of (OrderedDict of top class -> confidence, GradCam overlay image)
    :raises gr.Error: If no image is given or the layer number is out of range
    """
    # Slider values may arrive as floats; they are used as slice bounds / indices
    number_of_top_classes = int(number_of_top_classes)
    target_layer_number = int(target_layer_number)
    if not 1 <= target_layer_number <= len(_LAYER_NAMES):
        raise gr.Error(f"Layer number must be between 1 and {len(_LAYER_NAMES)}.")

    # Keep the resized image for the overlay; the tensor goes to the model
    org_img, input_tensor = _preprocess(input_img)

    # Add batch dimension to perform inference
    input_tensor = input_tensor.unsqueeze(0)

    # Get Model Predictions
    with torch.no_grad():
        outputs = model(input_tensor)
    # exp() maps the outputs to probabilities assuming they are log-probabilities
    # (log_softmax) -- NOTE(review): confirm against resnet.LITResNet
    probabilities = torch.exp(outputs)[0]
    confidences = {name: float(prob) for name, prob in zip(classes, probabilities)}

    # Select the top classes based on user input
    sorted_confidences = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    show_confidences = OrderedDict(sorted_confidences[:number_of_top_classes])

    # Plain attribute lookup (no eval) of the selected block's first sub-module
    target_layers = [getattr(model, _LAYER_NAMES[target_layer_number - 1])[0]]

    # Get the class activations from the selected layer; the context manager
    # releases the hooks GradCAM registers on the model
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    # Overlay input image with Class activations
    visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return show_confidences, visualization


def display_misclassified_images(number: int = 1):
    """
    Display the misclassified images saved during training
    Files are expected to be named ``<correct label>_<misclassified label>.<ext>``;
    anything else in the directory (e.g. hidden files) is skipped.
    :param number: Number of images to display
    :return: Tuple of (list of image paths, list of (correct, misclassified) label tuples)
    """
    directory = 'misclassified/'
    file_paths = []
    data = []

    # Sorted so the same images are shown on every run (os.listdir order is arbitrary)
    for file in sorted(os.listdir(directory)):
        file_name, extension = os.path.splitext(file)
        correct_label, separator, misclassified = file_name.partition('_')
        if not (extension and separator and correct_label and misclassified):
            continue
        # Path for Gradio to access the image, kept in step with its labels
        file_paths.append(os.path.join(directory, file))
        data.append((correct_label, misclassified))

    number = int(number)
    return file_paths[:number], data[:number]


def feature_maps(input_img, kernel_number=32):
    """
    Function to return feature maps for the selected image
    :param input_img: User input image
    :param kernel_number: Number of kernel in all 6 layers
    :return: Figure produced by FeatureMapVisualizer
    """
    _, input_tensor = _preprocess(input_img)

    # Visualize feature maps for the selected kernel number
    fig = viz.visualize_feature_map_of_kernel(image=input_tensor, kernel_number=int(kernel_number))
    return fig


def get_kernels(layer_number):
    """
    Function to get the kernels from the layer
    :param layer_number: Number of layer from which kernels to be visualized
    :return: Figure produced by FeatureMapVisualizer
    """
    # Visualize kernels from layer
    fig = viz.visualize_kernels_from_layer(layer_number=int(layer_number))
    return fig


if __name__ == '__main__':
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # CIFAR10 trained on ResNet18 Model
            A model architecture by [David C](https://github.com/davidcpage) which is trained on CIFAR10 for 24 Epochs to achieve accuracy of 90+%
            The model works for following classes: `plane`, `car`, `bird`, `cat`, `deer`, `dog`, `frog`, `horse`, `ship`, `truck`
            """
        )

        # #############################################################################
        # ################################ GradCam Tab ################################
        # #############################################################################
        with gr.Tab("GradCam"):
            gr.Markdown(
                """
                Visualize Class Activations Maps generated by the model's layer for the predicted class
                This is used to see what the model is actually looking at in the image
                """
            )
            with gr.Row():
                img_input = gr.Image(label="Input Image")
                gradcam_outputs = [gr.Label(), gr.Image(label="Output")]

            with gr.Row():
                gradcam_inputs = [
                    gr.Slider(0, 1, value=0.5,
                              label="How much percentage overlap of the input image on the activation maps?"),
                    gr.Slider(1, 10, value=3, step=1,
                              label="How many top class predictions you want to see?"),
                    gr.Slider(1, 6, value=4, step=1,
                              label="From 6 blocks of the model, which block's first convolutional layer's class activation you want to see?"),
                ]

            gradcam_button = gr.Button("Submit")
            gradcam_button.click(inference, inputs=[img_input] + gradcam_inputs, outputs=gradcam_outputs)

            gr.Markdown("## Examples")
            gr.Examples([example_directory + 'dog.jpg', example_directory + 'cat.jpg', example_directory + 'frog.jpg',
                         example_directory + 'bird.jpg', example_directory + 'shark-plane.jpg',
                         example_directory + 'car.jpg', example_directory + 'truck.jpg',
                         example_directory + 'horse.jpg', example_directory + 'plane.jpg',
                         example_directory + 'ship.png'],
                        inputs=img_input, fn=inference)

        # ###########################################################################################
        # ################################ Misclassified Images Tab #################################
        # ###########################################################################################
        with gr.Tab("Misclassified Images"):
            gr.Markdown(
                """
                10% of test images were misclassified by the model at the end of the training
                You can visualize those images with their correct label and misclassified label
                """
            )
            with gr.Row():
                mis_inputs = gr.Slider(1, 10, value=1, step=1, label="Number of misclassified images to display")
                mis_outputs = [
                    gr.Gallery(label="Misclassified Images", show_label=False, elem_id="gallery"),
                    gr.Dataframe(headers=["Correct Label", "Misclassified Label"], type="array", datatype="str",
                                 row_count=10, col_count=2)
                ]
            mis_button = gr.Button("Display Misclassified Images")
            mis_button.click(display_misclassified_images, inputs=mis_inputs, outputs=mis_outputs)

        # ################################################################################################
        # ################################ Feature Maps Visualization Tab ################################
        # ################################################################################################
        with gr.Tab("Feature Map Visualization"):
            gr.Markdown(
                """
                The model has 6 convolutional blocks. Each block has two or three convolutional layers
                From each block's first convolutional layer, output of specific kernel number is visualized
                In the below images `l1` represents first block and `kx` represents the number of kernel from the first convolutional layer of that block
                """
            )
            with gr.Column():
                feature_map_input = gr.Image(label="Feature Map Input Image")
                feature_map_slider = gr.Slider(1, 32, value=16, step=1,
                                               label="Select a Kernel number whose Features Maps from all 6 block's to be shown")
                feature_map_output = gr.Plot()
                feature_map_button = gr.Button("Visualize FeatureMaps")
                feature_map_button.click(feature_maps, inputs=[feature_map_input, feature_map_slider],
                                         outputs=feature_map_output)

        # ##########################################################################################
        # ################################ Kernel Visualization Tab ################################
        # ##########################################################################################
        with gr.Tab("Kernel Visualization"):
            gr.Markdown(
                """
                The model has 6 convolutional blocks.
                Each block has two or three convolutional layers
                Some of the Kernels from the first convolutional layer of selected block number are visualized below
                """
            )
            with gr.Column():
                kernel_input = gr.Slider(1, 4, value=2, step=1,
                                         label="Select a block number whose first convolutional layer's Kernels to be shown")
                kernel_output = gr.Plot()
                kernel_button = gr.Button("Visualize Kernels")
                kernel_button.click(get_kernels, inputs=kernel_input, outputs=kernel_output)

    gr.close_all()
    demo.launch()