nviraj's picture
Fixes GradioUnusedKwargWarning: You have unused kwarg parameters
1134c36
raw
history blame contribute delete
No virus
7.41 kB
# Outline
# Import packages
# Import modules
# Constants
# Load model
# Function to process user uploaded image/ examples
# Inference function
# Gradio examples
# Gradio App
# Import packages required for the app
import gradio as gr
# Import custom modules
import modules.config as config
import numpy as np
import torch
# import torchvision
from modules.custom_resnet import CustomResNet
from modules.visualize import plot_gradcam_images, plot_misclassified_images
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import transforms
# Load and initialize the model
model = CustomResNet()
# Define device
cpu = torch.device("cpu")
# Using the checkpoint path present in config, load the trained model
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=cpu), strict=False)
# Send model to CPU
model.to(cpu)
# Make the model in evaluation mode
model.eval()
# print(f"Model Device: {next(model.parameters()).device}")
# Load the misclassified images data
misclassified_image_data = torch.load(config.MISCLASSIFIED_PATH, map_location=cpu)
# Class Names
classes = list(config.CIFAR_CLASSES)
# Allowed model names
model_layer_names = ["prep", "layer1_x", "layer1_r1", "layer2", "layer3_x", "layer3_r2"]
def get_target_layer(layer_name):
"""Get target layer for visualization"""
if layer_name == "prep":
return [model.prep[-1]]
elif layer_name == "layer1_x":
return [model.layer1_x[-1]]
elif layer_name == "layer1_r1":
return [model.layer1_r1[-1]]
elif layer_name == "layer2":
return [model.layer2[-1]]
elif layer_name == "layer3_x":
return [model.layer3_x[-1]]
elif layer_name == "layer3_r2":
return [model.layer3_r2[-1]]
else:
return None
def generate_prediction(input_image, num_classes=3, show_gradcam=True, transparency=0.6, layer_name="layer3_x"):
""" "Given an input image, generate the prediction, confidence and display_image"""
mean = list(config.CIFAR_MEAN)
std = list(config.CIFAR_STD)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
with torch.no_grad():
orginal_img = input_image
input_image = transform(input_image).unsqueeze(0).to(cpu)
# print(f"Input Device: {input_image.device}")
model_output = model(input_image).to(cpu)
# print(f"Output Device: {outputs.device}")
output_exp = torch.exp(model_output).to(cpu)
# print(f"Output Exp Device: {o.device}")
output_numpy = np.squeeze(np.asarray(output_exp.numpy()))
# get indexes of probabilties in descending order
sorted_indexes = np.argsort(output_numpy)[::-1]
# sort the probabilities in descending order
# final_class = classes[o_np.argmax()]
confidences = {}
for _ in range(int(num_classes)):
# set the confidence of highest class with highest probability
confidences[classes[sorted_indexes[_]]] = float(output_numpy[sorted_indexes[_]])
# Show Grad Cam
if show_gradcam:
# Get the target layer
target_layers = get_target_layer(layer_name)
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
cam_generated = cam(input_tensor=input_image, targets=None)
cam_generated = cam_generated[0, :]
display_image = show_cam_on_image(orginal_img / 255, cam_generated, use_rgb=True, image_weight=transparency)
else:
display_image = orginal_img
return confidences, display_image
def app_interface(
input_image,
num_classes,
show_gradcam,
layer_name,
transparency,
show_misclassified,
num_misclassified,
show_gradcam_misclassified,
num_gradcam_misclassified,
):
"""Function which provides the Gradio interface"""
# Get the prediction for the input image along with confidence and display_image
confidences, display_image = generate_prediction(input_image, num_classes, show_gradcam, transparency, layer_name)
if show_misclassified:
misclassified_fig, misclassified_axs = plot_misclassified_images(
data=misclassified_image_data, class_label=classes, num_images=num_misclassified
)
else:
misclassified_fig = None
if show_gradcam_misclassified:
gradcam_fig, gradcam_axs = plot_gradcam_images(
model=model,
data=misclassified_image_data,
class_label=classes,
# Use penultimate block of resnet18 layer 3 as the target layer for gradcam
# Decided using model summary so that dimensions > 7x7
target_layers=get_target_layer(layer_name),
targets=None,
num_images=num_gradcam_misclassified,
image_weight=transparency,
)
else:
gradcam_fig = None
# # delete ununsed axises
# del misclassified_axs
# del gradcam_axs
return confidences, display_image, misclassified_fig, gradcam_fig
TITLE = "CIFAR10 Image classification using a Custom ResNet Model"
DESCRIPTION = "Gradio App to infer using a Custom ResNet model and get GradCAM results"
examples = [
["assets/images/airplane.jpg", 3, True, "layer3_x", 0.6, True, 5, True, 5],
["assets/images/bird.jpeg", 4, True, "layer3_x", 0.7, True, 10, True, 20],
["assets/images/car.jpg", 5, True, "layer3_x", 0.5, True, 15, True, 5],
["assets/images/cat.jpeg", 6, True, "layer3_x", 0.65, True, 20, True, 10],
["assets/images/deer.jpg", 7, False, "layer2", 0.75, True, 5, True, 5],
["assets/images/dog.jpg", 8, True, "layer2", 0.55, True, 10, True, 5],
["assets/images/frog.jpeg", 9, True, "layer2", 0.8, True, 15, True, 15],
["assets/images/horse.jpg", 10, False, "layer1_r1", 0.85, True, 20, True, 5],
["assets/images/ship.jpg", 3, True, "layer1_r1", 0.4, True, 5, True, 15],
["assets/images/truck.jpg", 4, True, "layer1_r1", 0.3, True, 5, True, 10],
]
inference_app = gr.Interface(
app_interface,
inputs=[
# This accepts the image after resizing it to 32x32 which is what our model expects
gr.Image(shape=(32, 32)),
gr.Number(value=3, maximum=10, minimum=1, precision=0, label="#Classes to show"),
gr.Checkbox(True, label="Show GradCAM Image"),
gr.Dropdown(model_layer_names, value="layer3_x", label="Visulalization Layer from Model"),
# How much should the image be overlayed on the original image
gr.Slider(0, 1, 0.6, label="Image Overlay Factor"),
gr.Checkbox(True, label="Show Misclassified Images?"),
gr.Slider(value=10, maximum=25, minimum=5, step=5.0, label="#Misclassified images to show"),
gr.Checkbox(True, label="Visulize GradCAM for Misclassified images?"),
gr.Slider(value=10, maximum=25, minimum=5, step=5.0, label="#GradCAM images to show"),
],
outputs=[
gr.Label(label="Confidences", container=True, show_label=True),
gr.Image(shape=(32, 32), label="Grad CAM/ Input Image", container=True, show_label=True).style(
width=256, height=256
),
gr.Plot(label="Misclassified images", container=True, show_label=True),
gr.Plot(label="Grad CAM of Misclassified images", container=True, show_label=True),
],
title=TITLE,
description=DESCRIPTION,
examples=examples,
)
inference_app.launch()