import gradio as gr
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import io

from models import custom_resnet_lightning_s10
from utils import load_model_from_checkpoint, denormalize, get_data_label_name, get_dataset_labels
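
# Inference runs on CPU. The mean/std below are the standard CIFAR-10 per-channel
# statistics used to normalize the training data; they are reused here to denormalize
# tensors back into displayable images.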
device = torch.device('cpu')
dataset_mean, dataset_std = (0.4914, 0.4822, 0.4465), \
                            (0.2470, 0.2435, 0.2616)

model = custom_resnet_lightning_s10.S10LightningModel(64)
checkpoint = load_model_from_checkpoint(device)
model.load_state_dict(checkpoint['model'], strict=False)
test_incorrect_pred = checkpoint['test_incorrect_pred']
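
# Example images bundled with the Space, paired with their CIFAR-10 class indices
# (0 = aeroplane ... 9 = truck); they populate the selection gallery below.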
sample_images = [
    ['images/aeroplane.jpeg', 0],
    ['images/bird.jpeg', 2],
    ['images/car.jpeg', 1],
    ['images/cat.jpeg', 3],
    ['images/deer.jpeg', 4],
    ['images/dog.jpeg', 5],
    ['images/frog.jpeg', 6],
    ['images/horse.jpeg', 7],
    ['images/ship.jpeg', 8],
    ['images/truck.jpeg', 9]
]
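
# UI layout: a radio button switches between three features (top-class prediction,
# misclassified images, GradCAM); each feature's controls live in rows/columns whose
# visibility is toggled by select_feature() further down.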
with gr.Blocks() as app:
    '''
    Select feature interface
    '''
    with gr.Row() as input_radio_group:
        radio_btn = gr.Radio(
            choices=['Top Prediction Classes', 'Misclassified Images', 'GradCAM Images'],
            type="index",
            label='Feature options',
            info="Choose which feature you want to explore",
            value='Top Prediction Classes'
        )

    '''
    Options for GradCAM feature
    '''
    with gr.Row():
        with gr.Column(visible=False) as grad_cam_col:
            grad_cam_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
                                       info="How many images do you want to view?")
            grad_cam_layer = gr.Slider(-4, -1, value=-3, step=1, label="Choose model layer",
                                       info="Which layer do you want to view GradCAM on? [-4 => last layer]")
            grad_cam_opacity = gr.Slider(0, 1, value=0.4, step=0.1, label="Choose opacity of the gradient")
            with gr.Column():
                grad_cam_btn = gr.Button("Yes, Go Ahead", variant='primary')
        with gr.Column(visible=False) as grad_cam_output:
            grad_cam_output_gallery = gr.Gallery(value=[], columns=3, label='Output')
            # prediction_title = gr.Label(value='')

    '''
    Options for Misclassified images feature
    '''
    with gr.Row(visible=False) as missclassified_col:
        with gr.Row():
            missclassified_img_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
                                                 info="How many misclassified images do you want to view?")
            missclassified_btn = gr.Button("Click to Continue", variant='primary')
    with gr.Row(visible=False) as missclassified_img_output:
        missclassified_img_output_gallery = gr.Gallery(value=[], columns=5, label='Output')

    '''
    Option for Top prediction classes
    '''
    with gr.Row(visible=True) as top_pred_cls_col:
        with gr.Column():
            example_images = gr.Gallery(allow_preview=False, label='Select image', info='',
                                        value=[img[0] for img in sample_images],
                                        columns=3, rows=2, object_fit='scale_down')
        with gr.Column():
            with gr.Row():
                top_pred_image = gr.Image(shape=(32, 32), label='Upload Image or Select from the gallery')
                top_class_count = gr.Slider(1, 10, value=5, step=1, label="Number of classes to predict")
            top_class_btn = gr.Button("Submit", variant='primary')
            tc_clear_btn = gr.ClearButton()
    with gr.Row(visible=True) as top_class_output:
        # top_class_output_img = gr.Image().style(width=256, height=256)
        top_class_output_labels = gr.Label(num_top_classes=top_class_count.value, label='Output')
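
    # Event handlers for the top-prediction feature: clear_data() resets the image and
    # label widgets, and on_select() copies the chosen gallery example into the upload box.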
    def clear_data():
        return {
            top_pred_image: None,
            top_class_output_labels: None
        }

    tc_clear_btn.click(clear_data, None, [top_pred_image, top_class_output_labels])

    def on_select(evt: gr.SelectData):
        return {
            top_pred_image: sample_images[evt.index][0]
        }

    example_images.select(on_select, None, top_pred_image)
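
    # Single-image inference: convert the uploaded image to a CHW tensor, add a batch
    # dimension, take raw logits from the model (no_softmax=True), then softmax them into
    # a {class name: probability} dict, which is the format gr.Label expects.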
    def top_class_img_upload(input_img, top_class_count):
        if input_img is not None:
            transform = transforms.ToTensor()
            org_img = input_img
            input_img = transform(input_img)
            input_img = input_img.to(device)
            input_img = input_img.unsqueeze(0)
            outputs = model(input_img, no_softmax=True)
            softmax = torch.nn.Softmax(dim=0)
            o = softmax(outputs.flatten())
            confidences = {get_dataset_labels()[i]: float(o[i]) for i in range(10)}
            top_class_output_labels.num_top_classes = top_class_count
            # tc_clear_btn.add([top_pred_image, top_class_output_labels])
            return {
                top_class_output: gr.update(visible=True),
                # top_class_output_img: org_img,
                top_class_output_labels: confidences
            }
        # No image provided: leave both outputs unchanged instead of returning None.
        return {
            top_class_output: gr.update(),
            top_class_output_labels: gr.update()
        }

    top_class_btn.click(
        top_class_img_upload,
        [top_pred_image, top_class_count],
        [top_class_output, top_class_output_labels]
    )

    '''
    Misclassified Images feature
    '''
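    # Pull the first `img_count` wrongly predicted test images saved in the checkpoint,
    # denormalize them, convert CHW tensors to HWC arrays, and caption each one with its
    # ground-truth and predicted class names.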
    def show_missclassified_images(img_count):
        imgs = []
        for i in range(img_count):
            img = test_incorrect_pred['images'][i].cpu()
            img = denormalize(img, dataset_mean, dataset_std)
            img = np.array(255 * img, np.int16).transpose(1, 2, 0)
            label = 'Truth: ' + get_data_label_name(
                test_incorrect_pred['ground_truths'][i].item()) + ' | Predicted: ' + get_data_label_name(
                test_incorrect_pred['predicted_vals'][i].item())
            imgs.append((img, label))
        return {
            missclassified_img_output: gr.update(visible=True),
            missclassified_img_output_gallery: imgs
        }

    missclassified_btn.click(
        show_missclassified_images,
        [missclassified_img_count],
        [missclassified_img_output_gallery, missclassified_img_output]
    )

    '''
    GradCAM Feature
    '''
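    # Resolve the slider's negative layer index to a concrete model layer, run GradCAM
    # against the ground-truth class of each misclassified image, and blend the heatmap
    # over the denormalized image (image_weight = 1 - chosen opacity).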
    def grad_cam_submit(img_count, layer_idx, grad_opacity):
        target_layers = [model.get_layer(-1 * (layer_idx + 1))]
        cam = GradCAM(model=model, target_layers=target_layers)
        visual_arr = []
        for i in range(img_count):
            pred_dict = test_incorrect_pred
            targets = [ClassifierOutputTarget(pred_dict['ground_truths'][i].cpu().item())]
            grayscale_cam = cam(input_tensor=pred_dict['images'][i][None, :].cpu(), targets=targets)
            x = denormalize(pred_dict['images'][i].cpu(), dataset_mean, dataset_std)
            img_tensor = np.array(x, np.float16).transpose(1, 2, 0)
            visualization = show_cam_on_image(img_tensor, grayscale_cam.transpose(1, 2, 0), use_rgb=True,
                                              image_weight=(1.0 - grad_opacity))
            visual_arr.append(
                (visualization, get_data_label_name(pred_dict['ground_truths'][i].item()))
            )
        return {
            grad_cam_output: gr.update(visible=True),
            grad_cam_output_gallery: visual_arr
        }

    grad_cam_btn.click(
        grad_cam_submit,
        [grad_cam_count, grad_cam_layer, grad_cam_opacity],
        [grad_cam_output_gallery, grad_cam_output]
    )

    '''
    Select Feature to showcase
    '''
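    # The radio button reports the selected index (type="index"): 0 = Top Prediction
    # Classes, 1 = Misclassified Images, 2 = GradCAM Images. Show only the rows that
    # belong to the chosen feature and hide the rest.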
    def select_feature(feature):
        if feature == 0:
            return {
                grad_cam_col: gr.update(visible=False),
                grad_cam_output: gr.update(visible=False),
                missclassified_col: gr.update(visible=False),
                missclassified_img_output: gr.update(visible=False),
                top_pred_cls_col: gr.update(visible=True),
                top_class_output: gr.update(visible=True)
            }
        elif feature == 1:
            return {
                grad_cam_col: gr.update(visible=False),
                grad_cam_output: gr.update(visible=False),
                missclassified_col: gr.update(visible=True),
                missclassified_img_output: gr.update(visible=True),
                top_pred_cls_col: gr.update(visible=False),
                top_class_output: gr.update(visible=False)
            }
        else:
            return {
                grad_cam_col: gr.update(visible=True),
                grad_cam_output: gr.update(visible=True),
                missclassified_col: gr.update(visible=False),
                missclassified_img_output: gr.update(visible=False),
                top_pred_cls_col: gr.update(visible=False),
                top_class_output: gr.update(visible=False)
            }

    radio_btn.change(select_feature,
                     [radio_btn],
                     [grad_cam_col, grad_cam_output, missclassified_col, missclassified_img_output, top_pred_cls_col, top_class_output])

'''
Launch the app
'''
app.launch()