s12erav1 / app.py
piyushgrover's picture
Update app.py
51e2303
import gradio as gr
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import io
from models import custom_resnet_lightning_s10
from utils import load_model_from_checkpoint, denormalize, get_data_label_name, get_dataset_labels
# The demo runs all inference on the CPU.
device = torch.device('cpu')

# Per-channel mean / std handed to `denormalize` when stored test images are
# converted back into displayable pixels.
# NOTE(review): these look like the usual CIFAR-10 statistics — confirm they
# match the values used during training.
dataset_mean, dataset_std = (0.4914, 0.4822, 0.4465), \
    (0.2470, 0.2435, 0.2616)

# Build the network and restore the trained weights from the checkpoint.
# NOTE(review): strict=False makes load_state_dict ignore missing/unexpected
# keys, so a checkpoint that does not match the architecture loads silently.
model = custom_resnet_lightning_s10.S10LightningModel(64)
checkpoint = load_model_from_checkpoint(device)
model.load_state_dict(checkpoint['model'], strict=False)

# Freshly constructed modules are in training mode. Switch to inference mode so
# any batch-norm / dropout layers behave deterministically for the single-image
# requests served by this app.
model.eval()

# Misclassified test samples stored alongside the weights; the callbacks read
# its 'images', 'ground_truths' and 'predicted_vals' entries.
test_incorrect_pred = checkpoint['test_incorrect_pred']
# Example pictures offered in the gallery, as [file path, class index] pairs.
# The order of this table is the order in which the gallery shows the images,
# and the select handler indexes into it by gallery position.
sample_images = [
    [f'images/{class_name}.jpeg', class_index]
    for class_name, class_index in (
        ('aeroplane', 0),
        ('bird', 2),
        ('car', 1),
        ('cat', 3),
        ('deer', 4),
        ('dog', 5),
        ('frog', 6),
        ('horse', 7),
        ('ship', 8),
        ('truck', 9),
    )
]
# Declarative layout of the whole UI: one radio selector plus three feature
# panels (GradCAM, misclassified images, top-class prediction).
# NOTE(review): the leading indentation of this section has been lost, so the
# nesting of the `with` containers cannot be read from this text — restore it
# from the original file before running. Each `with gr.<Container>()` opens a
# layout container and the components created inside it become its children.
with gr.Blocks() as app:
'''
Select feature interface
'''
# Feature selector. With type="index" the change handler receives the position
# of the chosen option; select_feature compares that value against 0 and 1.
with gr.Row() as input_radio_group:
radio_btn = gr.Radio(
choices=['Top Prediction Classes', 'Missclassified Images', 'GradCAM Images'],
type="index",
label='Feature options',
info="Choose which feature you want to explore",
value='Top Prediction Classes'
)
'''
Options for GradCAM feature
'''
# GradCAM controls and output; both columns start hidden (visible=False) and are
# toggled by select_feature.
with gr.Row():
with gr.Column(visible=False) as grad_cam_col:
grad_cam_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
info="How many images you want to view?")
# grad_cam_submit converts this value with -1 * (layer_idx + 1), so the slider
# range -4..-1 becomes the get_layer argument 3..0.
grad_cam_layer = gr.Slider(-4, -1, value=-3, step=1, label="Choose model layer",
info="Which layer you want to view GradCAM on? [-4 => last layer]")
# grad_cam_submit passes (1.0 - opacity) as the image weight of the overlay.
grad_cam_opacity = gr.Slider(0, 1, value=0.4, step=0.1, label="Choose opacity of the gradient")
with gr.Column():
grad_cam_btn = gr.Button("Yes, Go Ahead", variant='primary')
with gr.Column(visible=False) as grad_cam_output:
grad_cam_output_gallery = gr.Gallery(value=[], columns=3, label='Output')
# prediction_title = gr.Label(value='')
'''
Options for Missclassfied images feature
'''
# Misclassified-images controls and output gallery; both rows start hidden.
with gr.Row(visible=False) as missclassified_col:
with gr.Row():
missclassified_img_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
info="How many missclassified images you want to view?")
missclassified_btn = gr.Button("Click to Continue", variant='primary')
with gr.Row(visible=False) as missclassified_img_output:
missclassified_img_output_gallery = gr.Gallery(value=[], columns=5, label='Output')
'''
Option for Top prediction classes
'''
# Top-class prediction panel; this is the feature shown on start-up
# (visible=True, matching the radio's initial value).
with gr.Row(visible=True) as top_pred_cls_col:
with gr.Column():
# Gallery pre-filled with the file paths (first element) of every sample_images entry.
example_images = gr.Gallery(allow_preview=False, label='Select image ', info='', value=[img[0] for img in sample_images], columns=3, rows=2, object_fit='scale_down')
with gr.Column():
with gr.Row():
# NOTE(review): shape=(32, 32) presumably resizes the input to the model's
# resolution — confirm against the installed Gradio version.
top_pred_image = gr.Image(shape=(32, 32), label='Upload Image or Select from the gallery')
top_class_count = gr.Slider(1, 10, value=5, step=1, label="Number of classes to predict")
top_class_btn = gr.Button("Submit", variant='primary')
tc_clear_btn = gr.ClearButton()
with gr.Row(visible=True) as top_class_output:
#top_class_output_img = gr.Image().style(width=256, height=256)
# top_class_count.value is read once here, while the layout is built, so the
# label starts with the slider's initial value (5); top_class_img_upload
# reassigns num_top_classes on every request.
top_class_output_labels = gr.Label(num_top_classes=top_class_count.value, label='Output')
def clear_data():
    """Reset the prediction panel.

    Returns:
        dict mapping the input image and the output label component to None,
        which empties both of them.
    """
    reset_targets = (top_pred_image, top_class_output_labels)
    return {component: None for component in reset_targets}


# Run the reset whenever the clear button is pressed.
tc_clear_btn.click(
    fn=clear_data,
    inputs=None,
    outputs=[top_pred_image, top_class_output_labels],
)
def on_select(evt: gr.SelectData):
    """Load the gallery sample the user clicked into the prediction input.

    Args:
        evt: Selection event; `evt.index` is used as the position of the
            clicked item within `sample_images`.

    Returns:
        dict mapping the input image component to the sample's file path.
    """
    chosen_entry = sample_images[evt.index]
    return {top_pred_image: chosen_entry[0]}


# Selecting a thumbnail copies that sample into the image input.
example_images.select(fn=on_select, inputs=None, outputs=top_pred_image)
def top_class_img_upload(input_img, top_class_count):
    """Classify the supplied image and publish per-class confidences.

    Args:
        input_img: Image delivered by the `top_pred_image` component, or None
            when nothing has been uploaded or selected.
        top_class_count: Number of classes the label widget should display,
            taken from the slider.

    Returns:
        dict mapping `top_class_output` to a visibility update and
        `top_class_output_labels` to a {label: confidence} dict, or to None
        when there is no image to classify.
    """
    if input_img is None:
        # Both declared outputs still need a value; falling through and
        # returning nothing makes the event handler fail instead of simply
        # showing an empty result.
        return {
            top_class_output: gr.update(visible=True),
            top_class_output_labels: None,
        }

    # NOTE(review): only ToTensor is applied here, while the stored test images
    # are denormalized with dataset_mean/dataset_std elsewhere in this file. If
    # the model was trained on normalized inputs, a matching Normalize step is
    # missing — confirm against the training pipeline.
    batch = transforms.ToTensor()(input_img).to(device).unsqueeze(0)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch, no_softmax=True)
    probabilities = torch.softmax(logits.flatten(), dim=0)

    # Look the label table up once instead of once per class.
    labels = get_dataset_labels()
    class_total = min(len(labels), len(probabilities))
    confidences = {labels[i]: float(probabilities[i]) for i in range(class_total)}

    # The label component is told how many classes to display for this request.
    # NOTE(review): this mutates a component shared by every session, so
    # concurrent users can affect each other's class count.
    top_class_output_labels.num_top_classes = int(top_class_count)

    return {
        top_class_output: gr.update(visible=True),
        top_class_output_labels: confidences,
    }


top_class_btn.click(
    top_class_img_upload,
    [top_pred_image, top_class_count],
    [top_class_output, top_class_output_labels],
)
'''
Misclassified Images feature
'''
def show_missclassified_images(img_count):
    """Build gallery entries for the first `img_count` misclassified test images.

    Args:
        img_count: Requested number of images, from the slider. Values larger
            than the number of stored samples are clamped.

    Returns:
        dict mapping `missclassified_img_output` to a visibility update and
        `missclassified_img_output_gallery` to a list of (image, caption)
        tuples; each caption shows the true label followed by the predicted one.
    """
    stored_images = test_incorrect_pred['images']
    # Never index past the stored samples; int() also tolerates a float value.
    count = min(int(img_count), len(stored_images))

    imgs = []
    for i in range(count):
        img = denormalize(stored_images[i].cpu(), dataset_mean, dataset_std)
        # Emit plain 8-bit pixels: clip first so round-off outside [0, 1] cannot
        # wrap around, then reorder from channel-first to height/width/channel.
        pixels = np.clip(np.asarray(img, dtype=np.float32) * 255, 0, 255).astype(np.uint8)
        pixels = pixels.transpose(1, 2, 0)

        truth = get_data_label_name(test_incorrect_pred['ground_truths'][i].item())
        predicted = get_data_label_name(test_incorrect_pred['predicted_vals'][i].item())
        # The ground-truth marker had been mis-encoded; restore the check mark.
        imgs.append((pixels, '✅ ' + truth + ' ❌ ' + predicted))

    return {
        missclassified_img_output: gr.update(visible=True),
        missclassified_img_output_gallery: imgs,
    }


missclassified_btn.click(
    show_missclassified_images,
    [missclassified_img_count],
    [missclassified_img_output_gallery, missclassified_img_output],
)
'''
GradCAM Feature
'''
def grad_cam_submit(img_count, layer_idx, grad_opacity):
    """Render GradCAM overlays for the first `img_count` misclassified images.

    Args:
        img_count: Requested number of images, from the slider. Values larger
            than the number of stored samples are clamped.
        layer_idx: Layer slider value; it is converted to
            `-1 * (layer_idx + 1)` before being passed to `model.get_layer`.
        grad_opacity: Opacity of the heat map; the underlying image gets the
            weight `1.0 - grad_opacity`.

    Returns:
        dict mapping `grad_cam_output` to a visibility update and
        `grad_cam_output_gallery` to a list of (overlay, true-label) tuples.
    """
    pred_dict = test_incorrect_pred
    # Never index past the stored samples; int() also tolerates float values.
    count = min(int(img_count), len(pred_dict['images']))
    target_layers = [model.get_layer(-1 * (int(layer_idx) + 1))]

    visual_arr = []
    # Using GradCAM as a context manager releases the hooks it attaches to the
    # model as soon as the block exits, rather than whenever the object happens
    # to be garbage-collected.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        for i in range(count):
            sample = pred_dict['images'][i].cpu()
            truth_idx = pred_dict['ground_truths'][i].cpu().item()

            # The activation map is computed for the ground-truth class.
            grayscale_cam = cam(
                input_tensor=sample[None, :],
                targets=[ClassifierOutputTarget(truth_idx)],
            )

            # show_cam_on_image expects a float image in [0, 1]; clip so that
            # denormalization round-off cannot push a pixel out of range.
            rgb = np.asarray(denormalize(sample, dataset_mean, dataset_std), dtype=np.float32)
            rgb = np.clip(rgb, 0.0, 1.0).transpose(1, 2, 0)

            visualization = show_cam_on_image(
                rgb,
                grayscale_cam[0, :],  # single-image batch -> its one map
                use_rgb=True,
                image_weight=(1.0 - grad_opacity),
            )
            visual_arr.append((visualization, get_data_label_name(truth_idx)))

    return {
        grad_cam_output: gr.update(visible=True),
        grad_cam_output_gallery: visual_arr,
    }


grad_cam_btn.click(
    grad_cam_submit,
    [grad_cam_count, grad_cam_layer, grad_cam_opacity],
    [grad_cam_output_gallery, grad_cam_output],
)
'''
Select Feature to showcase
'''
def select_feature(feature):
    """Show the panel belonging to the chosen feature and hide the other two.

    Args:
        feature: Value delivered by the feature radio; 0 selects the top-class
            prediction panel, 1 the misclassified-images panel and any other
            value the GradCAM panel.

    Returns:
        dict mapping each of the six panel containers to a visibility update.
    """
    show_top_pred = feature == 0
    show_missclassified = feature == 1
    # Anything that is neither 0 nor 1 falls through to GradCAM.
    show_grad_cam = not (show_top_pred or show_missclassified)

    return {
        grad_cam_col: gr.update(visible=show_grad_cam),
        grad_cam_output: gr.update(visible=show_grad_cam),
        missclassified_col: gr.update(visible=show_missclassified),
        missclassified_img_output: gr.update(visible=show_missclassified),
        top_pred_cls_col: gr.update(visible=show_top_pred),
        top_class_output: gr.update(visible=show_top_pred),
    }


# Re-evaluate panel visibility every time the radio selection changes.
radio_btn.change(
    fn=select_feature,
    inputs=[radio_btn],
    outputs=[
        grad_cam_col,
        grad_cam_output,
        missclassified_col,
        missclassified_img_output,
        top_pred_cls_col,
        top_class_output,
    ],
)
'''
Launch the app
'''
# Start serving the Blocks app; this runs as soon as the module is executed.
app.launch()