import torch
import torchvision
import numpy as np
import gradio as gr
from torchvision import transforms
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet import ResNet18
# Load the trained ResNet18 checkpoint on CPU; strict=False tolerates minor
# key mismatches between the checkpoint and the model definition.
model = ResNet18()
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
model.eval()
inv_normalize = transforms.Normalize(
    mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
    std=[1/0.23, 1/0.23, 1/0.23]
)
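# Note: these values imply that training normalized CIFAR-10 images with roughly
# transforms.Normalize(mean=[0.50, 0.50, 0.50], std=[0.23, 0.23, 0.23]) -- an
# assumption inferred from the inverse transform above (mean' = -mean/std,
# std' = 1/std), not taken from the training code.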
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
def inference(input_img, transparency=0.5, target_layer_number=-1):
    """Run a forward pass and overlay a GradCAM heatmap on the input image."""
    transform = transforms.ToTensor()
    org_img = input_img
    input_img = transform(input_img)
    input_img = input_img.unsqueeze(0)  # add batch dimension
    outputs = model(input_img)
    softmax = torch.nn.Softmax(dim=0)
    o = softmax(outputs.flatten())
    confidences = {classes[i]: float(o[i]) for i in range(10)}
    _, prediction = torch.max(outputs, 1)
    # GradCAM over the chosen block of layer2 (-1 = last block, -2 = first block)
    target_layers = [model.layer2[target_layer_number]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    grayscale_cam = cam(input_tensor=input_img, targets=None)
    grayscale_cam = grayscale_cam[0, :]
    # Un-normalize the tensor and bring it back to HWC (not used below; the
    # overlay is drawn on the original image)
    img = inv_normalize(input_img.squeeze(0))
    rgb_img = np.transpose(img.detach().numpy(), (1, 2, 0))
    visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return confidences, visualization
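# A minimal local sanity check, kept commented out so the Space only launches the
# Gradio app below. It assumes one of the bundled example images ("cat.jpg") sits
# next to this file:
#
#   probs, overlay = inference(np.array(Image.open("cat.jpg").convert("RGB").resize((32, 32))))
#   print(sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3])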
def inference_confidences(input_img, transparency=0.5, target_layer_number=-1):
    """Return the class-confidence dictionary for the input image."""
    transform = transforms.ToTensor()
    input_img = transform(input_img).unsqueeze(0)  # add batch dimension
    outputs = model(input_img)
    softmax = torch.nn.Softmax(dim=0)
    o = softmax(outputs.flatten())
    confidences = {classes[i]: float(o[i]) for i in range(10)}
    return confidences
def inference_visualization(input_img, transparency=0.5, target_layer_number=-1):
    """Return the GradCAM heatmap overlaid on the input image."""
    transform = transforms.ToTensor()
    org_img = input_img
    input_img = transform(input_img).unsqueeze(0)  # add batch dimension
    # GradCAM over the chosen block of layer2 (-1 = last block, -2 = first block)
    target_layers = [model.layer2[target_layer_number]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    grayscale_cam = cam(input_tensor=input_img, targets=None)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)
    return visualization
# Callback function for the Gradio interface.
# num_gradcam_images, view_misclassified and num_misclassified_images are accepted
# to match the interface inputs but are not used here.
def gradio_callback(view_grad_cam, num_gradcam_images, view_misclassified, num_misclassified_images,
                    input_img, transparency=0.5, target_layer_number=-1):
    confidence = inference_confidences(input_img, transparency=transparency,
                                       target_layer_number=target_layer_number)
    if view_grad_cam == "Yes":
        visualization = inference_visualization(input_img, transparency=transparency,
                                                target_layer_number=target_layer_number)
        return confidence, visualization
    # No GradCAM requested: return None for the image output so the number of
    # return values still matches the two interface outputs.
    return confidence, None
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
description = "Gradio interface to infer on ResNet18 model, and get GradCAM results"
examples = [["Yes",5,"Yes",5,"cat.jpg", 0.5, -1], ["Yes",5,"Yes",5,"dog.jpg", 0.5, -1]]
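# Each example row supplies the seven inputs in order: view GradCAM?, number of
# GradCAM images, view misclassified?, number of misclassified images, image path,
# GradCAM opacity, and target layer index.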
demo = gr.Interface(
    fn=gradio_callback,
    inputs=[
        gr.Radio(["Yes", "No"], label="GradCAM images", info="View GradCAM images?"),
        gr.Number(label="Number of GradCAM images to view", default=5, max=10),
        gr.Radio(["Yes", "No"], label="View misclassified images?"),
        gr.Number(label="Number of misclassified images to view", default=5, min=1, max=10),
        gr.Image(shape=(32, 32), label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?")
    ],
    outputs=[gr.Label(num_top_classes=3),
             gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()