Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| import numpy as np | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| import io | |
| from models import custom_resnet_lightning_s10 | |
| from utils import load_model_from_checkpoint, denormalize, get_data_label_name, get_dataset_labels | |
# All inference runs on the CPU.
device = torch.device('cpu')
# Per-channel mean/std used to normalize model inputs (these are the standard
# CIFAR-10 statistics); also used to undo normalization before displaying images.
dataset_mean, dataset_std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)

model = custom_resnet_lightning_s10.S10LightningModel(64)
checkpoint = load_model_from_checkpoint(device)
# strict=False: keys missing from / unexpected in the checkpoint are ignored
# instead of raising, so a mismatched checkpoint will not fail loudly here.
model.load_state_dict(checkpoint['model'], strict=False)
model.to(device)
# A freshly constructed nn.Module is in training mode. This app only does
# inference, so switch mode-dependent layers (e.g. BatchNorm/Dropout, if the
# model has them) to evaluation behaviour before the first prediction.
model.eval()

# Misclassified test samples stored alongside the weights; read with the keys
# 'images', 'ground_truths' and 'predicted_vals' by the UI handlers.
test_incorrect_pred = checkpoint['test_incorrect_pred']

# (image path, class index) pairs shown in the example gallery. Only the path is
# read by the UI; the index documents the expected class of each sample.
sample_images = [
    ['images/aeroplane.jpeg', 0],
    ['images/bird.jpeg', 2],
    ['images/car.jpeg', 1],
    ['images/cat.jpeg', 3],
    ['images/deer.jpeg', 4],
    ['images/dog.jpeg', 5],
    ['images/frog.jpeg', 6],
    ['images/horse.jpeg', 7],
    ['images/ship.jpeg', 8],
    ['images/truck.jpeg', 9]
]
with gr.Blocks() as app:
    # ------------------------------------------------------------------
    # Feature selector: switches between the three demo sections below.
    # ------------------------------------------------------------------
    with gr.Row() as input_radio_group:
        radio_btn = gr.Radio(
            choices=['Top Prediction Classes', 'Missclassified Images', 'GradCAM Images'],
            type="index",  # the change handler receives the chosen index (0, 1, 2)
            label='Feature options',
            info="Choose which feature you want to explore",
            value='Top Prediction Classes'
        )

    # ------------------------------------------------------------------
    # GradCAM feature: options and output (hidden until selected).
    # ------------------------------------------------------------------
    with gr.Row():
        with gr.Column(visible=False) as grad_cam_col:
            grad_cam_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
                                       info="How many images you want to view?")
            grad_cam_layer = gr.Slider(-4, -1, value=-3, step=1, label="Choose model layer",
                                       info="Which layer you want to view GradCAM on? [-4 => last layer]")
            grad_cam_opacity = gr.Slider(0, 1, value=0.4, step=0.1, label="Choose opacity of the gradient")
            with gr.Column():
                grad_cam_btn = gr.Button("Yes, Go Ahead", variant='primary')
        with gr.Column(visible=False) as grad_cam_output:
            grad_cam_output_gallery = gr.Gallery(value=[], columns=3, label='Output')

    # ------------------------------------------------------------------
    # Misclassified-images feature: options and output (hidden until selected).
    # ------------------------------------------------------------------
    with gr.Row(visible=False) as missclassified_col:
        with gr.Row():
            missclassified_img_count = gr.Slider(1, 20, value=5, step=1, label="Choose image count",
                                                 info="How many missclassified images you want to view?")
            missclassified_btn = gr.Button("Click to Continue", variant='primary')
    with gr.Row(visible=False) as missclassified_img_output:
        missclassified_img_output_gallery = gr.Gallery(value=[], columns=5, label='Output')

    # ------------------------------------------------------------------
    # Top-prediction-classes feature: inputs and output (shown by default).
    # ------------------------------------------------------------------
    with gr.Row(visible=True) as top_pred_cls_col:
        with gr.Column():
            example_images = gr.Gallery(allow_preview=False, label='Select image ', info='',
                                        value=[img[0] for img in sample_images],
                                        columns=3, rows=2, object_fit='scale_down')
        with gr.Column():
            with gr.Row():
                top_pred_image = gr.Image(shape=(32, 32), label='Upload Image or Select from the gallery')
            top_class_count = gr.Slider(1, 10, value=5, step=1, label="Number of classes to predict")
            top_class_btn = gr.Button("Submit", variant='primary')
            tc_clear_btn = gr.ClearButton()
    with gr.Row(visible=True) as top_class_output:
        # No num_top_classes cap on the component: top_class_img_upload returns
        # exactly the requested number of classes for each request, so the
        # shared component never has to be mutated per request.
        top_class_output_labels = gr.Label(label='Output')

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    # Preprocessing for uploaded images: scale to [0, 1], then normalize with the
    # same statistics the stored test images are denormalized with below, so the
    # model sees inputs in the distribution it was evaluated on.
    _top_class_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(dataset_mean, dataset_std),
    ])

    def _to_display_array(img_tensor):
        """Convert a stored, normalized CHW image into an HWC float32 array in [0, 1].

        Args:
            img_tensor: one entry of ``test_incorrect_pred['images']``.

        Returns:
            ``np.ndarray`` of shape (H, W, C), dtype float32, clipped to [0, 1].
        """
        # .cpu() returns the very same tensor when it already lives on the CPU,
        # so clone first in case denormalize modifies its input in place.
        img = denormalize(img_tensor.cpu().clone(), dataset_mean, dataset_std)
        arr = np.asarray(img, dtype=np.float32).transpose(1, 2, 0)
        # Floating-point error can push denormalized values slightly outside
        # [0, 1]; show_cam_on_image rejects images whose max exceeds 1.
        return np.clip(arr, 0.0, 1.0)

    # ------------------------------------------------------------------
    # Top-prediction-classes handlers
    # ------------------------------------------------------------------
    def clear_data():
        """Reset the image input and the prediction label."""
        return {
            top_pred_image: None,
            top_class_output_labels: None
        }

    tc_clear_btn.click(clear_data, None, [top_pred_image, top_class_output_labels])

    def on_select(evt: gr.SelectData):
        """Load the clicked example-gallery image into the image input."""
        return {
            top_pred_image: sample_images[evt.index][0]
        }

    example_images.select(on_select, None, top_pred_image)

    def top_class_img_upload(input_img, top_class_count):
        """Classify one image and return its most likely classes.

        Args:
            input_img: image from ``top_pred_image`` (resized by the component to
                32x32), or ``None`` when nothing has been uploaded.
            top_class_count: how many classes to report.

        Returns:
            Dict of component updates: the output row made visible and the label
            set to ``{class name: probability}`` for the top classes (or cleared
            when there is no input image).
        """
        if input_img is None:
            # Returning bare None would not map onto the two output components.
            return {
                top_class_output: gr.update(visible=True),
                top_class_output_labels: None
            }
        input_tensor = _top_class_transform(input_img).unsqueeze(0).to(device)
        with torch.no_grad():  # inference only; gradients are not needed here
            outputs = model(input_tensor, no_softmax=True)
        probs = torch.softmax(outputs.flatten(), dim=0)
        labels = get_dataset_labels()
        confidences = {labels[i]: float(probs[i]) for i in range(probs.numel())}
        ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
        return {
            top_class_output: gr.update(visible=True),
            top_class_output_labels: dict(ranked[:int(top_class_count)])
        }

    top_class_btn.click(
        top_class_img_upload,
        [top_pred_image, top_class_count],
        [top_class_output, top_class_output_labels]
    )

    # ------------------------------------------------------------------
    # Misclassified-images handler
    # ------------------------------------------------------------------
    def show_missclassified_images(img_count):
        """Return up to ``img_count`` misclassified test images for the gallery.

        Each caption reads '✅ <ground truth> ❌ <prediction>'. The count is
        clamped to the number of stored images.
        """
        images = test_incorrect_pred['images']
        count = min(int(img_count), len(images))
        imgs = []
        for i in range(count):
            # uint8 is the unambiguous pixel format for image display.
            img = np.rint(_to_display_array(images[i]) * 255).astype(np.uint8)
            truth = get_data_label_name(test_incorrect_pred['ground_truths'][i].item())
            pred = get_data_label_name(test_incorrect_pred['predicted_vals'][i].item())
            imgs.append((img, '✅ ' + truth + ' ❌ ' + pred))
        return {
            missclassified_img_output: gr.update(visible=True),
            missclassified_img_output_gallery: imgs
        }

    missclassified_btn.click(
        show_missclassified_images,
        [missclassified_img_count],
        [missclassified_img_output_gallery, missclassified_img_output]
    )

    # ------------------------------------------------------------------
    # GradCAM handler
    # ------------------------------------------------------------------
    def grad_cam_submit(img_count, layer_idx, grad_opacity):
        """Render GradCAM heatmaps for misclassified test images.

        Args:
            img_count: number of images to render (clamped to what is stored).
            layer_idx: slider value in [-4, -1], mapped to ``model.get_layer(-(layer_idx + 1))``.
            grad_opacity: heatmap opacity in [0, 1]; the image weight is ``1 - grad_opacity``.

        Returns:
            Dict of component updates: the output column made visible and the
            gallery filled with (visualization, ground-truth label) pairs. The
            CAM targets the ground-truth class of each image.
        """
        target_layers = [model.get_layer(-1 * (int(layer_idx) + 1))]
        images = test_incorrect_pred['images']
        truths = test_incorrect_pred['ground_truths']
        count = min(int(img_count), len(images))
        visual_arr = []
        # The context manager removes the hooks GradCAM registers on the model
        # as soon as rendering is done.
        with GradCAM(model=model, target_layers=target_layers) as cam:
            for i in range(count):
                truth = truths[i].cpu().item()
                targets = [ClassifierOutputTarget(truth)]
                grayscale_cam = cam(input_tensor=images[i][None, :].cpu(), targets=targets)
                visualization = show_cam_on_image(_to_display_array(images[i]),
                                                  grayscale_cam.transpose(1, 2, 0),
                                                  use_rgb=True,
                                                  image_weight=(1.0 - grad_opacity))
                visual_arr.append((visualization, get_data_label_name(truth)))
        return {
            grad_cam_output: gr.update(visible=True),
            grad_cam_output_gallery: visual_arr
        }

    grad_cam_btn.click(
        grad_cam_submit,
        [grad_cam_count, grad_cam_layer, grad_cam_opacity],
        [grad_cam_output_gallery, grad_cam_output]
    )

    # ------------------------------------------------------------------
    # Feature switching
    # ------------------------------------------------------------------
    # (options container, output container) for each radio index.
    _feature_sections = (
        (top_pred_cls_col, top_class_output),             # 0: Top Prediction Classes
        (missclassified_col, missclassified_img_output),  # 1: Missclassified Images
        (grad_cam_col, grad_cam_output),                  # 2: GradCAM Images
    )

    def select_feature(feature):
        """Show only the containers of the chosen feature and hide the others.

        ``feature`` is the radio index; any value other than 0 or 1 selects GradCAM.
        """
        chosen = 0 if feature == 0 else 1 if feature == 1 else 2
        updates = {}
        for idx, section in enumerate(_feature_sections):
            for container in section:
                updates[container] = gr.update(visible=(idx == chosen))
        return updates

    radio_btn.change(select_feature,
                     [radio_btn],
                     [grad_cam_col, grad_cam_output, missclassified_col, missclassified_img_output,
                      top_pred_cls_col, top_class_output])
# Start the Gradio server for the Blocks UI assembled above.
app.launch()