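# Gradio app: CIFAR10 classification with a custom ResNet, plus GradCAM
# visualisations and a gallery of misclassified test images.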
import torch, torchvision
from torchvision import transforms
from torchvision import datasets
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
import itertools
import matplotlib.pyplot as plt
from utils import LitCIFAR10
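# LitCIFAR10 is the LightningModule defined in utils.py; load the trained
# checkpoint and put the model in eval mode for inference.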
model = LitCIFAR10.load_from_checkpoint("model/model.ckpt")
model.eval()
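# CIFAR10 class names and per-channel normalisation statistics.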
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
means = [0.4914, 0.4822, 0.4465]
stds = [0.2470, 0.2435, 0.2616]
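# The test set is loaded without a transform; the raw PIL images are used for
# visualisation and normalisation is applied manually via `transform` below.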
cifar_testset = datasets.CIFAR10(root='.', train=False, download=True)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds)
])
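# Selects the score of the requested class from the model output
# (mirrors pytorch_grad_cam's ClassifierOutputTarget).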
class ClassifierOutputTarget:
    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        if len(model_output.shape) == 1:
            return model_output[self.category]
        return model_output[:, self.category]
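# Gradio handler: optionally builds GradCAM overlays on test images, collects
# misclassified test images, and classifies an uploaded image (returning the
# top-class confidences).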
def inference(wants_gradcam, n_gradcam, target_layer_number, transparency,
              wants_misclassified, n_misclassified, input_img=None, n_top_classes=10):
    if wants_gradcam:
        outputs_inference_gc = []
        count_gradcam = 1
        for data, target in cifar_testset:
            input_tensor = preprocess_image(data, mean=means, std=stds)
            target_layers = [model.model.layer3[target_layer_number]]
            targets = [ClassifierOutputTarget(target)]
            cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
            grayscale_cam = grayscale_cam[0, :]
            rgb_img = np.float32(data) / 255
            visualization = np.array(show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True,
                                                       image_weight=transparency))
            outputs_inference_gc.append(visualization)
            count_gradcam += 1
            if count_gradcam > n_gradcam:
                break
    else:
        outputs_inference_gc = None
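    # Collect test images where the model's prediction disagrees with the label.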
    if wants_misclassified:
        outputs_inference_mis = []
        count_mis = 1
        for data_, target in cifar_testset:
            data = transform(data_).unsqueeze(0)
            with torch.no_grad():
                output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            if pred.item() != target:
                count_mis += 1
                fig = plt.figure()
                fig.add_subplot(111)
                plt.imshow(data_)
                plt.title(f'Target: {classes[target]}\nPred: {classes[pred.item()]}')
                plt.axis('off')
                fig.canvas.draw()
                fig_img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
                fig_img = fig_img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
                plt.close(fig)
                outputs_inference_mis.append(fig_img)
                if count_mis > n_misclassified:
                    break
    else:
        outputs_inference_mis = None
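    # Classify the uploaded image (if any) and report per-class confidences.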
    if input_img is not None:
        data = transform(input_img).unsqueeze(0)
        with torch.no_grad():
            output = model(data)
        softmax = torch.nn.Softmax(dim=0)
        o = softmax(output.flatten())
        confidences = {classes[i]: float(o[i]) for i in range(10)}
        confidences = {k: v for k, v in sorted(confidences.items(), key=lambda item: item[1], reverse=True)}
        confidences = dict(itertools.islice(confidences.items(), n_top_classes))
    else:
        confidences = None
    return outputs_inference_gc, outputs_inference_mis, confidences
title = "CIFAR10 trained on Custom ResNet Model with GradCAM"
description = "A simple Gradio interface to run inference with a custom ResNet trained on CIFAR10, visualise GradCAM outputs, and browse misclassified test images."
examples = [[None, None, None, None, None, None, 'Images/test_'+str(i)+'.jpg', None] for i in range(10)]
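# Wire everything into a Gradio Interface; the input order matches inference()'s arguments.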
demo = gr.Interface(
    inference,
    inputs=[
        gr.Checkbox(False, label='Do you want to see GradCAM outputs?'),
        gr.Slider(0, 10, value=0, step=1, label="How many GradCAM images?"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which target layer?"),
        gr.Slider(0, 1, value=0, label="Opacity of GradCAM"),
        gr.Checkbox(False, label='Do you want to see misclassified images?'),
        gr.Slider(0, 10, value=0, step=1, label="How many misclassified images?"),
        gr.Image(shape=(32, 32), label="Input image"),
        gr.Slider(0, 10, value=0, step=1, label="How many top classes do you want to see?")
    ],
    outputs=[
        gr.Gallery(label="GradCAM Outputs", show_label=True, elem_id="gradcam_gallery").style(
            columns=[2], rows=[2], object_fit="contain", height="auto"),
        gr.Gallery(label="Misclassified Images", show_label=True, elem_id="misclassified_gallery").style(
            columns=[2], rows=[2], object_fit="contain", height="auto"),
        gr.Label(num_top_classes=None)
    ],
    title=title,
    description=description,
    examples=examples
)
demo.launch()