# GradeCam / app.py
# sagar007's picture
# Create app.py
# 3f461cc verified
# raw
# history blame contribute delete
# No virus
# 4.37 kB
import torch
import torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet import ResNet18
# Build the classifier and restore its weights onto the CPU.
model = ResNet18()
# NOTE(review): strict=False makes load_state_dict ignore missing/unexpected
# keys instead of raising, so a mismatched checkpoint loads silently —
# confirm this leniency is intentional.
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
# This app only runs inference: put the module in evaluation mode so layers
# whose behaviour differs between training and evaluation (e.g. BatchNorm,
# Dropout, if the network has any) do not use or update batch statistics.
model.eval()

# Inverse of Normalize(mean=0.50, std=0.23) for every channel:
# (x - (-m/s)) / (1/s) == x * s + m.
inv_normalize = transforms.Normalize(
    mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
    std=[1/0.23, 1/0.23, 1/0.23]
)

# Class names indexed by the model's output position.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def resize_image_pil(image, new_width, new_height):
    """Scale *image* to fit a ``new_width`` x ``new_height`` canvas.

    The image is resized with nearest-neighbour sampling using a single
    scale factor (the smaller of the two axis ratios), so its aspect ratio
    is kept. The result is then cropped to the box
    ``(0, 0, new_width, new_height)``, which anchors the picture at the
    top-left corner and yields an image of exactly the requested size.

    Args:
        image: Anything ``np.array`` can turn into an array that
            ``Image.fromarray`` accepts (e.g. a PIL image or an ndarray).
        new_width: Width of the returned image in pixels.
        new_height: Height of the returned image in pixels.

    Returns:
        A PIL image of size ``(new_width, new_height)``.
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size

    scale = min(new_width / width, new_height / height)

    # Round instead of truncating: float error can make width * scale land
    # just below an integer, and very elongated inputs can scale to less
    # than one pixel. Clamp so each side is at least 1 and never exceeds the
    # requested canvas.
    target_width = max(1, min(new_width, round(width * scale)))
    target_height = max(1, min(new_height, round(height * scale)))

    resized = img.resize((target_width, target_height), Image.NEAREST)
    return resized.crop((0, 0, new_width, new_height))
def inference(input_img, show_gradcam, num_gradcam, target_layer_number, opacity, show_misclassified, num_misclassified, num_top_classes):
    """Classify one image and optionally build a Grad-CAM overlay for it.

    Args:
        input_img: Image supplied by the UI; passed to ``resize_image_pil``.
        show_gradcam: When truthy, compute the Grad-CAM visualization.
        num_gradcam: Currently unused; kept so the UI wiring stays valid.
        target_layer_number: Index into ``model.layer2`` selecting the block
            whose activations Grad-CAM uses.
        opacity: Forwarded to ``show_cam_on_image`` as ``image_weight``.
        show_misclassified: Currently has no effect (feature not implemented).
        num_misclassified: Currently unused; kept so the UI wiring stays valid.
        num_top_classes: Number of top-scoring classes to report, capped at
            the number of known classes.

    Returns:
        A list ``[predicted_class, confidences, visualization]``:
        the predicted class name, a ``{class name: probability}`` dict for
        the top classes, and the Grad-CAM overlay image (``None`` when
        ``show_gradcam`` is falsy). The list always has three items so it
        matches the three output components the click handler is bound to.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        raise gr.Error("Please provide an input image.")

    # UI sliders may deliver floats; indexing and topk need real ints.
    top_k = max(1, min(int(num_top_classes), len(classes)))
    layer_index = int(target_layer_number)

    # Force three channels so grayscale / RGBA inputs do not break the model.
    rgb_img = np.array(resize_image_pil(input_img, 32, 32).convert("RGB"))
    # NOTE(review): the module-level inv_normalize suggests training inputs
    # were normalized (mean 0.50, std 0.23), but only ToTensor is applied
    # here — confirm against the training pipeline.
    input_tensor = transforms.ToTensor()(rgb_img).unsqueeze(0)

    # Make sure the prediction pass runs in evaluation mode, and skip
    # autograd bookkeeping since no gradients are needed for it.
    model.eval()
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs, dim=1)

    top_probs, top_labels = torch.topk(probs, k=top_k)
    confidences = {
        classes[label]: float(prob)
        for label, prob in zip(top_labels[0].tolist(), top_probs[0].tolist())
    }
    predicted_class = classes[int(outputs.argmax(dim=1)[0])]

    visualization = None
    if show_gradcam:
        target_layers = [model.layer2[layer_index]]
        # The context manager releases the hooks GradCAM registers on the
        # target layer, so they do not pile up across requests.
        with GradCAM(model=model, target_layers=target_layers) as cam:
            grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        visualization = show_cam_on_image(
            rgb_img.astype(np.float32) / 255,
            grayscale_cam,
            use_rgb=True,
            image_weight=opacity,
        )

    # show_misclassified: this would require a dataset of images to check
    # for misclassifications and is not implemented. Nothing is returned for
    # it because the UI binds exactly three outputs and has no component
    # that could display such a result.

    return [predicted_class, confidences, visualization]
def launch():
    """Build the Gradio Blocks interface for the classifier and start it.

    The page holds the input image and Grad-CAM output, the prediction and
    confidence displays, the option controls, a submit button, and a set of
    clickable example images. Pressing the button calls ``inference`` with
    the current control values and fills the three output components.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# CIFAR10 ResNet18 Model with GradCAM")

        with gr.Row():
            input_image = gr.Image(width=256, height=256, label="Input Image")
            output_image = gr.Image(width=256, height=256, label="GradCAM Output")

        with gr.Row():
            prediction = gr.Textbox(label="Predicted Class")
            confidences = gr.Label(label="Top Class Confidences")

        with gr.Row():
            show_gradcam = gr.Checkbox(label="Show GradCAM")
            num_gradcam = gr.Slider(1, 5, value=1, step=1, label="Number of GradCAM images")
            target_layer = gr.Slider(-2, -1, value=-2, step=1, label="Target Layer")
            opacity = gr.Slider(0, 1, value=0.5, label="GradCAM Opacity")

        with gr.Row():
            show_misclassified = gr.Checkbox(label="Show Misclassified Images")
            num_misclassified = gr.Slider(1, 10, value=5, step=1, label="Number of Misclassified Images")

        num_top_classes = gr.Slider(1, 10, value=3, step=1, label="Number of Top Classes to Show")
        submit_btn = gr.Button("Submit")

        # gr.Examples binds the samples to the input component, so clicking
        # one loads it into the image box. A bare gr.Dataset with no event
        # handler only displays the samples.
        gr.Examples(
            examples=[
                ["cat.jpg"], ["dog.jpg"], ["bird.jpg"], ["plane.jpg"], ["car.jpg"],
                ["deer.jpg"], ["frog.jpg"], ["horse.jpg"], ["ship.jpg"], ["truck.jpg"],
            ],
            inputs=[input_image],
        )

        submit_btn.click(
            inference,
            inputs=[
                input_image, show_gradcam, num_gradcam, target_layer, opacity,
                show_misclassified, num_misclassified, num_top_classes,
            ],
            outputs=[prediction, confidences, output_image],
        )

    demo.launch()
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    launch()