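# Spaces: Sleeping
# (Hugging Face Spaces page-status text captured together with this source;
# kept as a comment so the file parses as Python.)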
import gradio as gr
from torchvision import transforms
import torch
from utils import CustomResnet, main_inference, get_misclassified_images, get_gradcam

# Inverts a per-channel Normalize(mean=0.50, std=0.23):
# Normalize computes (x - mean) / std = (x + 0.50/0.23) * 0.23 = 0.23 * x + 0.50.
# NOTE(review): presumably the normalization used when training the model, and
# not referenced in the visible part of this file — confirm it is still needed.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
    std=[1 / 0.23, 1 / 0.23, 1 / 0.23],
)

# NOTE(review): no checkpoint is loaded here; presumably CustomResnet() (from
# utils) restores the trained weights itself — confirm, otherwise predictions
# come from an untrained network.
model = CustomResnet()

# Label names in class-index order (this matches the CIFAR-10 label order).
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Not used in the visible part of this file.
targets = None

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# nn.Module instances start in training mode; switch to evaluation mode so
# BatchNorm/Dropout layers (if any) behave deterministically at inference time.
model.eval()
# Define the input components of the Gradio app. Top-level component classes
# (gr.Image / gr.Slider / gr.Checkbox, configured with `value=`) replace the
# deprecated gr.inputs.* namespace, matching the gr.Image outputs of the interface.

# Uploaded image, resized to the 32x32 resolution the model consumes.
input_component = gr.Image(shape=(32, 32))
# How many top-scoring classes to report: at least one, at most all ten.
num_of_output_classes = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top class count")
# Adding a checkbox to the interface to show/hide misclassified images
show_misclassified_checkbox = gr.Checkbox(value=False, label="Show Misclassified Images")
# Input field to specify the number of misclassified images to display
num_images_input = gr.Slider(minimum=0, maximum=20, value=15, step=5, label="Misclassified Images Count")
# Adding a checkbox to the interface to show/hide GradCAM output
show_gradcam_checkbox = gr.Checkbox(value=False, label="Show GradCAM Output")
# Slider for adjusting the opacity of the GradCAM overlay
opacity_slider = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="GradCAM Opacity")
def predict(image, top_k, show_misclassified, num_images, show_gradcam, opacity):
    """Run every enabled analysis for one uploaded image.

    Args:
        image: The uploaded image, as delivered by ``input_component``.
        top_k: How many top-scoring classes ``main_inference`` should report.
        show_misclassified: Whether to produce the misclassified-images output.
        num_images: How many misclassified images to request.
        show_gradcam: Whether to produce the GradCAM output.
        opacity: Overlay opacity forwarded to ``get_gradcam``.

    Returns:
        A list of three values matching the interface's ``outputs``: the
        ``main_inference`` result, the misclassified-images output (``None``
        when disabled), and the GradCAM output (``None`` when disabled).
    """
    prediction = main_inference(top_k, classes, model, image)
    misclassified = (
        get_misclassified_images(show_misclassified, num_images)
        if show_misclassified
        else None
    )
    gradcam = get_gradcam(model, image, opacity) if show_gradcam else None
    return [prediction, misclassified, gradcam]


demo = gr.Interface(
    fn=predict,
    inputs=[
        input_component,
        num_of_output_classes,
        show_misclassified_checkbox,
        num_images_input,
        show_gradcam_checkbox,
        opacity_slider,
    ],
    outputs=[gr.Label(), gr.Image(shape=(500, 500)), gr.Image(shape=(500, 500))],
    # Each row follows the `inputs` order:
    # [image, top classes, show misclassified, misclassified count, show GradCAM, opacity]
    examples=[
        ["example_images/example_1.png", 5, True, 5, True, 0.2],
        ["example_images/example_2.png", 5, False, 5, True, 0.3],
        ["example_images/example_3.png", 5, True, 15, False, 0.2],
        ["example_images/example_4.png", 5, True, 20, True, 0.5],
        ["example_images/example_5.png", 5, False, 5, False, 0.2],
        ["example_images/example_6.png", 5, True, 10, True, 0.3],
        ["example_images/example_7.png", 5, True, 5, True, 0.4],
        ["example_images/example_8.png", 5, False, 5, False, 0.6],
        ["example_images/example_9.png", 5, True, 20, False, 0.2],
        ["example_images/example_10.png", 5, False, 5, True, 0.7],
    ],
    # NOTE(review): `layout` is deprecated in newer Gradio releases; confirm it
    # is still accepted/effective with the pinned Gradio version.
    layout="horizontal",
)

# Launch only when run as a script, so importing this module (or Gradio's
# reload mode, which looks up `demo`) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()