ERABB / app.py
909ahmed's picture
check
de8d663
import gradio as gr
import torch, torchvision
from torchvision import transforms
from resnet import ResNet18
from resnet import ResBlocks
from PIL import Image
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
# Build the network and restore the trained weights on the CPU.
# NOTE(review): 0.00333 is passed as ResNet18's only constructor argument;
# its meaning is defined in resnet.py - confirm there.
model = ResNet18(0.00333)

# SECURITY: the checkpoint is a pickled model object (we call .state_dict()
# on it below), so torch.load() executes whatever the pickle contains.
# Only ship / load a "final_model.pkl" that comes from a trusted source.
state_model = torch.load("final_model.pkl", map_location=torch.device('cpu'))
state_dict = state_model.state_dict()

# strict=False tolerates missing or unexpected keys between the checkpoint
# and the freshly constructed model instead of raising.
model.load_state_dict(state_dict, strict=False)

# This app only performs inference: switch the module to evaluation mode so
# that any training-only layer behaviour (e.g. dropout, batch-norm statistics
# updates, if the architecture has them) is disabled from the first request.
model.eval()
# Display label for each model output index, in index order.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)
# Transform that maps a normalized tensor back to its pre-normalization
# range. Normalize computes (x - mean) / std, so with mean = -m/s and
# std = 1/s it yields x * s + m, i.e. it undoes a Normalize(mean=m, std=s)
# with m = 0.50 and s = 0.23 on each of the three channels.
# NOTE(review): confirm that 0.50 / 0.23 really match the statistics applied
# by cifar10_normalization() elsewhere in this file.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23] * 3,
    std=[1 / 0.23] * 3,
)
def resize_image_pil(image, new_width, new_height):
    """Fit ``image`` onto a ``new_width`` x ``new_height`` canvas.

    The image is scaled uniformly (nearest-neighbour resampling) by the
    largest factor at which it still fits inside the target size, and the
    result is then cropped to exactly ``(new_width, new_height)`` with the
    crop box anchored at the top-left corner. When the scaled image is
    smaller than the target along one axis, the crop box extends past the
    image on the right/bottom and PIL fills that area (black for RGB data).

    Args:
        image: Anything ``np.array`` can turn into an image array that
            ``PIL.Image.fromarray`` accepts.
        new_width: Target width in pixels.
        new_height: Target height in pixels.

    Returns:
        A NumPy array holding the ``new_height`` x ``new_width`` result.
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size

    # One common factor for both axes keeps the aspect ratio intact.
    scale = min(new_width / width, new_height / height)

    # int() truncates, so the short side of a very elongated image could
    # become 0 pixels, which Image.resize rejects. Keep at least one pixel
    # per axis so such inputs still produce an output.
    scaled_size = (
        max(1, int(width * scale)),
        max(1, int(height * scale)),
    )
    resized = img.resize(scaled_size, Image.NEAREST)

    # Crop (or pad, see docstring) to the exact requested size.
    resized = resized.crop((0, 0, new_width, new_height))
    return np.array(resized)
def inference(input_img, transparency = 0.5, target_layer_number = -1):
    """Classify an image and render a GradCAM heat-map overlay for it.

    Args:
        input_img: The image supplied by the Gradio ``Image`` component.
        transparency: Forwarded to ``show_cam_on_image`` as ``image_weight``
            (weight of the original image in the blended overlay).
        target_layer_number: Index into ``model.res_layers[2]`` selecting the
            layer GradCAM is computed on.

    Returns:
        A 3-tuple ``(label, visualization, confidences)``: the predicted
        class name, the image returned by ``show_cam_on_image``, and a dict
        mapping every class name to its softmax probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        # Gradio passes None when the user submits without an image; fail
        # with a readable message instead of an opaque error inside PIL.
        raise gr.Error("Please provide an input image.")

    # Guarantee evaluation mode for the classification pass below, regardless
    # of the state the module was left in after loading.
    model.eval()

    # The resized image is kept un-normalized for the overlay at the end.
    org_img = resize_image_pil(input_img, 32, 32)

    # HWC array -> CHW tensor -> add batch dimension -> normalize.
    input_tensor = transforms.ToTensor()(org_img.reshape((32, 32, 3)))
    input_tensor = cifar10_normalization()(input_tensor.unsqueeze(0))

    # Class scores are only read here, so no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(p) for name, p in zip(classes, probabilities)}
    _, prediction = torch.max(outputs, 1)

    # The slider may deliver the index as a float; indexing needs an int.
    target_layers = [model.res_layers[2][int(target_layer_number)]]
    cam = GradCAM(model=model, target_layers=target_layers)
    # targets=None leaves the choice of target class to GradCAM; it runs its
    # own forward/backward pass with gradients enabled.
    grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    # The overlay is drawn on the resized image scaled to the [0, 1] range.
    visualization = show_cam_on_image(
        org_img / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=transparency,
    )
    return classes[prediction[0].item()], visualization, confidences
# Text shown at the top of the Gradio page.
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"

# The three inputs are passed positionally to
# inference(input_img, transparency, target_layer_number), and the three
# outputs receive its (label, visualization, confidences) return values.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Overall Opacity of Image"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
    ],
    outputs=[
        "text",
        gr.Image(width=256, height=256, label="Output"),
        gr.Label(num_top_classes=3),
    ],
    title=title,
    description=description,
)

# Start the web server for the app.
iface.launch()