File size: 3,345 Bytes
9e78939
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5608c10
9e78939
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c7f4df
 
 
 
9e78939
2932f27
9e78939
2932f27
 
b2ce316
bbaaadc
9e78939
3c7f4df
 
 
 
 
 
 
 
 
 
9e78939
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import warnings

import gradio as gr
from torchvision import transforms
import torch

from utils import CustomResnet, main_inference, get_misclassified_images, get_gradcam

# Inverse of Normalize(mean=0.50, std=0.23): Normalize computes (x - mean) / std,
# so with these values it returns 0.23 * x + 0.50, undoing (y - 0.50) / 0.23.
# NOTE(review): not referenced in this file; presumably used by display code elsewhere.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
    std=[1 / 0.23, 1 / 0.23, 1 / 0.23],
)

model = CustomResnet()

# CIFAR-10 class names; tuple position is presumably the model's output index.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
targets = None  # NOTE(review): unused in this file; kept in case other code imports it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the checkpoint onto the CPU first so a GPU-saved file also loads on
# CPU-only hosts; the model is moved to `device` afterwards.
_state_dict = torch.load("best_model.pth", map_location=torch.device("cpu"))
# strict=False tolerates key mismatches, but a silent mismatch would leave the
# affected layers at their random initialisation, so report any mismatched keys.
_load_result = model.load_state_dict(_state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "best_model.pth does not exactly match CustomResnet: "
        f"missing keys={_load_result.missing_keys}, "
        f"unexpected keys={_load_result.unexpected_keys}"
    )
model.to(device)
# This app only runs inference: switch layers such as BatchNorm/Dropout (if the
# architecture has them) to evaluation behaviour. Gradients are still available.
model.eval()

# Input components of the Gradio app, listed in the order the prediction
# function receives them.

# Uploaded image, brought to 32x32 (assumed: the CIFAR-10 input size the model expects).
input_component = gr.Image(shape=(32, 32))

# How many of the highest-scoring classes to show; at least one, at most all 10.
num_of_output_classes = gr.Slider(
    minimum=1, maximum=10, value=5, step=1, label="Top class count"
)

# Checkbox to show/hide misclassified images.
show_misclassified_checkbox = gr.Checkbox(value=False, label="Show Misclassified Images")

# Number of misclassified images to display. The minimum is 5 (not 0) so that
# enabling the checkbox never requests an empty set of images.
num_images_input = gr.Slider(
    minimum=5, maximum=20, value=15, step=5, label="Misclassified Images Count"
)

# Checkbox to show/hide the GradCAM output.
show_gradcam_checkbox = gr.Checkbox(value=False, label="Show GradCAM Output")

# Opacity of the GradCAM overlay.
opacity_slider = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="GradCAM Opacity")

# Layers GradCAM can target. Presumably these must match sub-module names of
# CustomResnet; verify against utils.get_gradcam.
layer_options = ['convblock1', 'resblock1', 'convblock2', "convblock3", "resblock2"]
layer_input = gr.Dropdown(choices=layer_options, value="convblock3", label="Select a Layer")
def predict(image, num_of_output_classes, show_misclassified, num_images,
            show_gradcam, opacity, layer):
    """Produce all three app outputs for one request.

    Args:
        image: the uploaded image, as delivered by the image input component.
        num_of_output_classes: how many top classes ``main_inference`` should report.
        show_misclassified: whether to build the misclassified-images output.
        num_images: how many misclassified images to request.
        show_gradcam: whether to build the GradCAM output.
        opacity: GradCAM overlay opacity.
        layer: name of the layer GradCAM is computed on.

    Returns:
        list: ``[prediction, misclassified, gradcam]``, matching the Interface's
        ``outputs``. A disabled output is ``None``.
    """
    # Same evaluation order as before: prediction, then misclassified, then GradCAM.
    prediction = main_inference(num_of_output_classes, classes, model, image)
    misclassified = (
        get_misclassified_images(show_misclassified, num_images) if show_misclassified else None
    )
    gradcam = get_gradcam(model, image, opacity, layer) if show_gradcam else None
    return [prediction, misclassified, gradcam]


demo = gr.Interface(
    fn=predict,
    inputs=[
        input_component,
        num_of_output_classes,
        show_misclassified_checkbox,
        num_images_input,
        show_gradcam_checkbox,
        opacity_slider,
        layer_input,
    ],
    outputs=[gr.Label(), gr.Image(shape=(500, 500)), gr.Image(shape=(250, 250))],
    title="CIFAR10 Trained on Custom Residual CNN Architecture",
    # Each row: image path, top class count, show misclassified, misclassified
    # count, show GradCAM, GradCAM opacity, GradCAM layer.
    examples=[
        ["example_images/example_1.png", 5, True, 5, True, 0.2, 'convblock3'],
        ["example_images/example_2.png", 5, False, 5, True, 0.3, 'convblock3'],
        ["example_images/example_3.png", 5, True, 15, False, 0.2, 'convblock3'],
        ["example_images/example_4.png", 5, True, 20, True, 0.5, 'convblock3'],
        ["example_images/example_5.png", 5, False, 5, False, 0.2, 'convblock3'],
        ["example_images/example_6.png", 5, True, 10, True, 0.3, 'convblock3'],
        ["example_images/example_7.png", 5, True, 5, True, 0.4, 'convblock3'],
        ["example_images/example_8.png", 5, False, 5, False, 0.6, 'convblock3'],
        ["example_images/example_9.png", 5, True, 20, False, 0.2, 'convblock3'],
        ["example_images/example_10.png", 5, False, 5, True, 0.7, 'convblock3'],
    ],
    layout="horizontal",
)
demo.launch()