"""Gradio demo for a CIFAR-10 custom ResNet: top-class prediction,
misclassified-image browser and GradCAM visualisation."""
import gradio as gr
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import io
from models import custom_resnet_lightning_s10
from utils import load_model_from_checkpoint, denormalize, get_data_label_name, get_dataset_labels

device = torch.device('cpu')

# Per-channel mean/std of the dataset. The stored misclassified test images
# live in this normalised space (they are denormalised before display below),
# so images uploaded by the user must be normalised the same way.
dataset_mean, dataset_std = (0.4914, 0.4822, 0.4465), \
    (0.2470, 0.2435, 0.2616)

# Number of output classes the prediction feature reads from the model.
NUM_CLASSES = 10

model = custom_resnet_lightning_s10.S10LightningModel(64)
checkpoint = load_model_from_checkpoint(device)
# NOTE(review): strict=False silently ignores missing/unexpected keys; confirm
# the checkpoint's keys actually match the model.
model.load_state_dict(checkpoint['model'], strict=False)
# Inference only: put normalisation/dropout layers (if the model has any) into
# evaluation behaviour so a single-image request neither depends on nor
# updates batch statistics. Done once here so every feature sees the same mode.
model.eval()

test_incorrect_pred = checkpoint['test_incorrect_pred']

# Preprocessing for uploaded images: to a CHW float tensor in [0, 1], then the
# dataset normalisation. (Input is presumably an HxWxC uint8 array already
# resized to 32x32 by gr.Image(shape=...) — confirm for the Gradio version.)
predict_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(dataset_mean, dataset_std),
])

sample_images = [
    ['images/aeroplane.jpeg', 0],
    ['images/bird.jpeg', 2],
    ['images/car.jpeg', 1],
    ['images/cat.jpeg', 3],
    ['images/deer.jpeg', 4],
    ['images/dog.jpeg', 5],
    ['images/frog.jpeg', 6],
    ['images/horse.jpeg', 7],
    ['images/ship.jpeg', 8],
    ['images/truck.jpeg', 9],
]


def _image_count(requested):
    """Clamp a slider value to the number of stored misclassified images."""
    return max(0, min(int(requested), len(test_incorrect_pred['images'])))


def _to_display_image(img):
    """Denormalise a stored CHW image and return an HxWxC float32 array in [0, 1].

    The tensor is cloned first so the stored test images cannot be altered
    even if ``denormalize`` happens to work in place (not visible from here).
    Values are clipped because rounding after denormalisation can land just
    outside [0, 1], the range show_cam_on_image and uint8 conversion need.
    """
    x = denormalize(img.cpu().clone(), dataset_mean, dataset_std)
    return np.clip(np.asarray(x, dtype=np.float32), 0.0, 1.0).transpose(1, 2, 0)


with gr.Blocks() as app:
    # ---- Select feature interface -------------------------------------
    with gr.Row() as input_radio_group:
        radio_btn = gr.Radio(
            choices=['Top Prediction Classes', 'Missclassified Images', 'GradCAM Images'],
            type="index",
            label='Feature options',
            info="Choose which feature you want to explore",
            value='Top Prediction Classes'
        )

    # ---- Options for GradCAM feature ----------------------------------
    with gr.Row():
        with gr.Column(visible=False) as grad_cam_col:
            grad_cam_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
                                       info="How many images you want to view?")
            grad_cam_layer = gr.Slider(-4, -1, value=-3, step=1, label="Choose model layer",
                                       info="Which layer you want to view GradCAM on? [-4 => last layer]")
            grad_cam_opacity = gr.Slider(0, 1, value=0.4, step=0.1,
                                         label="Choose opacity of the gradient")
            with gr.Column():
                grad_cam_btn = gr.Button("Yes, Go Ahead", variant='primary')
    with gr.Column(visible=False) as grad_cam_output:
        grad_cam_output_gallery = gr.Gallery(value=[], columns=3, label='Output')

    # ---- Options for Missclassified images feature ---------------------
    with gr.Row(visible=False) as missclassified_col:
        with gr.Row():
            missclassified_img_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
                                                 info="How many missclassified images you want to view?")
            missclassified_btn = gr.Button("Click to Continue", variant='primary')
    with gr.Row(visible=False) as missclassified_img_output:
        missclassified_img_output_gallery = gr.Gallery(value=[], columns=5, label='Output')

    # ---- Option for Top prediction classes ----------------------------
    with gr.Row(visible=True) as top_pred_cls_col:
        with gr.Column():
            example_images = gr.Gallery(allow_preview=False, label='Select image ', info='',
                                        value=[img[0] for img in sample_images],
                                        columns=3, rows=2, object_fit='scale_down')
        with gr.Column():
            with gr.Row():
                top_pred_image = gr.Image(shape=(32, 32), label='Upload Image or Select from the gallery')
            top_class_count = gr.Slider(1, 10, value=5, step=1, label="Number of classes to predict")
            top_class_btn = gr.Button("Submit", variant='primary')
            tc_clear_btn = gr.ClearButton()
    with gr.Row(visible=True) as top_class_output:
        # No fixed num_top_classes: the handler trims to the requested count
        # per request, so the label simply shows whatever it receives.
        top_class_output_labels = gr.Label(label='Output')

    def clear_data():
        """Reset the input image and the prediction labels."""
        return {
            top_pred_image: None,
            top_class_output_labels: None
        }

    tc_clear_btn.click(clear_data, None, [top_pred_image, top_class_output_labels])

    def on_select(evt: gr.SelectData):
        """Load the clicked gallery example into the input image."""
        return {
            top_pred_image: sample_images[evt.index][0]
        }

    example_images.select(on_select, None, top_pred_image)

    def top_class_img_upload(input_img, top_class_count):
        """Predict class confidences for the uploaded/selected image.

        Args:
            input_img: value of ``top_pred_image`` (None when empty).
            top_class_count: how many of the most likely classes to show.

        Returns:
            Dict of component updates: the output row is shown and the label
            receives the ``top_class_count`` highest class confidences.
        """
        if input_img is None:
            # Nothing to classify; clear stale output instead of returning None
            # for an event that has two outputs.
            return {top_class_output_labels: None}

        input_tensor = predict_transform(input_img).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_tensor, no_softmax=True)
        probs = torch.softmax(outputs.flatten(), dim=0)

        class_labels = get_dataset_labels()
        confidences = {class_labels[i]: float(probs[i]) for i in range(NUM_CLASSES)}
        # Trim per request rather than mutating the shared Label component,
        # whose attributes are shared by every concurrent session.
        ranked = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)
        return {
            top_class_output: gr.update(visible=True),
            top_class_output_labels: dict(ranked[:int(top_class_count)])
        }

    top_class_btn.click(
        top_class_img_upload,
        [top_pred_image, top_class_count],
        [top_class_output, top_class_output_labels]
    )

    # ---- Missclassified Images feature --------------------------------
    def show_missclassified_images(img_count):
        """Return up to ``img_count`` misclassified test images with captions."""
        imgs = []
        for i in range(_image_count(img_count)):
            # Gallery images as uint8: a scaled 0..255 int16 array is not a
            # standard image dtype.
            img = np.round(_to_display_image(test_incorrect_pred['images'][i]) * 255).astype(np.uint8)
            label = '✅ ' + get_data_label_name(
                test_incorrect_pred['ground_truths'][i].item()) + ' ❌ ' + get_data_label_name(
                test_incorrect_pred['predicted_vals'][i].item())
            imgs.append((img, label))
        return {
            missclassified_img_output: gr.update(visible=True),
            missclassified_img_output_gallery: imgs
        }

    missclassified_btn.click(
        show_missclassified_images,
        [missclassified_img_count],
        [missclassified_img_output_gallery, missclassified_img_output]
    )

    # ---- GradCAM Feature ----------------------------------------------
    def grad_cam_submit(img_count, layer_idx, grad_opacity):
        """Overlay GradCAM heatmaps (for the ground-truth class) on test images.

        Args:
            img_count: number of misclassified test images to visualise.
            layer_idx: slider value in -4..-1, mapped to ``model.get_layer(3..0)``;
                what that index selects is defined by the model (not visible here).
            grad_opacity: heatmap opacity in [0, 1].
        """
        target_layers = [model.get_layer(-1 * (int(layer_idx) + 1))]
        cam = GradCAM(model=model, target_layers=target_layers)
        image_weight = 1.0 - float(grad_opacity)

        visual_arr = []
        for i in range(_image_count(img_count)):
            image = test_incorrect_pred['images'][i]
            ground_truth = test_incorrect_pred['ground_truths'][i].cpu().item()
            targets = [ClassifierOutputTarget(ground_truth)]
            grayscale_cam = cam(input_tensor=image[None, :].cpu(), targets=targets)
            visualization = show_cam_on_image(_to_display_image(image), grayscale_cam[0, :],
                                              use_rgb=True, image_weight=image_weight)
            visual_arr.append((visualization, get_data_label_name(ground_truth)))
        return {
            grad_cam_output: gr.update(visible=True),
            grad_cam_output_gallery: visual_arr
        }

    grad_cam_btn.click(
        grad_cam_submit,
        [grad_cam_count, grad_cam_layer, grad_cam_opacity],
        [grad_cam_output_gallery, grad_cam_output]
    )

    # ---- Select Feature to showcase -----------------------------------
    def select_feature(feature):
        """Show only the controls/outputs of the selected feature (radio index)."""
        show_top = feature == 0
        show_missclassified = feature == 1
        show_grad_cam = not (show_top or show_missclassified)
        return {
            grad_cam_col: gr.update(visible=show_grad_cam),
            grad_cam_output: gr.update(visible=show_grad_cam),
            missclassified_col: gr.update(visible=show_missclassified),
            missclassified_img_output: gr.update(visible=show_missclassified),
            top_pred_cls_col: gr.update(visible=show_top),
            top_class_output: gr.update(visible=show_top)
        }

    radio_btn.change(select_feature, [radio_btn],
                     [grad_cam_col, grad_cam_output, missclassified_col,
                      missclassified_img_output, top_pred_cls_col, top_class_output])

# ---- Launch the app ---------------------------------------------------
app.launch()