Spaces:
Sleeping
Sleeping
Inference Code for trained model
Browse files- gradio_app.py +54 -0
- utils.py +151 -0
gradio_app.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from torchvision import transforms
|
3 |
+
import torch
|
4 |
+
from utils import CustomResnet, main_inference, get_misclassified_images, get_gradcam
|
5 |
+
|
6 |
+
inv_normalize = transforms.Normalize(
|
7 |
+
mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
|
8 |
+
std=[1/0.23, 1/0.23, 1/0.23]
|
9 |
+
)
|
10 |
+
|
11 |
+
model = CustomResnet()
|
12 |
+
|
13 |
+
classes = ('plane', 'car', 'bird', 'cat', 'deer',
|
14 |
+
'dog', 'frog', 'horse', 'ship', 'truck')
|
15 |
+
targets = None
|
16 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
17 |
+
model.to(device)
|
18 |
+
|
19 |
+
# Define the input and output components of the Gradio app
|
20 |
+
input_component = gr.inputs.Image(shape=(32, 32))
|
21 |
+
num_of_output_classes = gr.inputs.Slider(minimum=0, maximum=10, default=5, step=1,label="Top class count")
|
22 |
+
# Adding a checkbox to the interface to show/hide misclassified images
|
23 |
+
show_misclassified_checkbox = gr.inputs.Checkbox(default=False, label="Show Misclassified Images")
|
24 |
+
|
25 |
+
# Input field to specify the number of misclassified images to display
|
26 |
+
num_images_input = gr.inputs.Slider(minimum=0, maximum=20, default=15, step=5,label="Missclassified Images Count")
|
27 |
+
|
28 |
+
# Adding a checkbox to the interface to show/hide GradCAM output
|
29 |
+
show_gradcam_checkbox = gr.inputs.Checkbox(default=False, label="Show GradCAM Output")
|
30 |
+
|
31 |
+
# Slider for adjusting the opacity of the GradCAM overlay
|
32 |
+
opacity_slider = gr.inputs.Slider(minimum=0, maximum=1, default=0.7,step=0.1, label="GradCAM Opacity")
|
33 |
+
|
34 |
+
gr.Interface(
|
35 |
+
fn=lambda image, num_of_output_classes,show_misclassified, num_images, show_gradcam, opacity: [main_inference(num_of_output_classes,classes,model,image),
|
36 |
+
get_misclassified_images(show_misclassified, num_images) if show_misclassified else None,
|
37 |
+
get_gradcam(model,image, opacity) if show_gradcam else None],
|
38 |
+
inputs=[input_component, num_of_output_classes,show_misclassified_checkbox, num_images_input, show_gradcam_checkbox, opacity_slider],
|
39 |
+
outputs=[gr.outputs.Label(), gr.Image(shape=(500, 500)), gr.Image(shape=(500, 500))],
|
40 |
+
examples=[
|
41 |
+
["example_images/example_1.png",5,True,5,True,0.2], # You can provide your own example input values here
|
42 |
+
["example_images/example_2.png",5,False,5,True,0.3],
|
43 |
+
["example_images/example_3.png",5,True,15,False,0.2] ,
|
44 |
+
["example_images/example_4.png",5,True,20,True,0.5] ,
|
45 |
+
["example_images/example_5.png",5,False,5,False,0.2] ,
|
46 |
+
["example_images/example_6.png",5,True,10,True,0.3] ,
|
47 |
+
["example_images/example_7.png",5,True,5,True,0.4] ,
|
48 |
+
["example_images/example_8.png",5,False,5,False,0.6] ,
|
49 |
+
["example_images/example_9.png",5,True,20,False,0.2] ,
|
50 |
+
["example_images/example_10.png",5,False,5,True,0.7]
|
51 |
+
|
52 |
+
],
|
53 |
+
layout="horizontal"
|
54 |
+
).launch()
|
utils.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from pytorch_grad_cam import GradCAM
|
8 |
+
from pytorch_grad_cam import GradCAM
|
9 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
10 |
+
|
11 |
+
def apply_normalization(chennels):
|
12 |
+
return nn.BatchNorm2d(chennels)
|
13 |
+
|
14 |
+
class CustomResnet(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super(CustomResnet, self).__init__()
|
17 |
+
# Input Block
|
18 |
+
drop = 0.0
|
19 |
+
# PrepLayer - Conv 3x3 s1, p1) >> BN >> RELU [64k]
|
20 |
+
self.preplayer = nn.Sequential(
|
21 |
+
nn.Conv2d(3, 64, (3, 3), padding=1, stride=1, bias=False), # 3
|
22 |
+
apply_normalization(64),
|
23 |
+
nn.ReLU(),
|
24 |
+
)
|
25 |
+
# Layer1 -
|
26 |
+
# X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> RELU [128k]
|
27 |
+
self.convlayer1 = nn.Sequential(
|
28 |
+
nn.Conv2d(64, 128, (3, 3), padding=1, stride=1, bias=False), # 3
|
29 |
+
nn.MaxPool2d(2, 2),
|
30 |
+
apply_normalization(128),
|
31 |
+
nn.ReLU(),
|
32 |
+
)
|
33 |
+
# R1 = ResBlock( (Conv-BN-ReLU-Conv-BN-ReLU))(X) [128k]
|
34 |
+
self.reslayer1 = nn.Sequential(
|
35 |
+
nn.Conv2d(128, 128, (3, 3), padding=1, stride=1, bias=False), # 3
|
36 |
+
apply_normalization(128),
|
37 |
+
nn.ReLU(),
|
38 |
+
nn.Conv2d(128, 128, (3, 3), padding=1, stride=1, bias=False), # 3
|
39 |
+
apply_normalization(128),
|
40 |
+
nn.ReLU(),
|
41 |
+
)
|
42 |
+
# Conv 3x3 [256k]
|
43 |
+
self.convlayer2 = nn.Sequential(
|
44 |
+
nn.Conv2d(128, 256, (3, 3), padding=1, stride=1, bias=False), # 3
|
45 |
+
nn.MaxPool2d(2, 2),
|
46 |
+
apply_normalization(256),
|
47 |
+
nn.ReLU(),
|
48 |
+
)
|
49 |
+
# X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> RELU [512k]
|
50 |
+
self.convlayer3 = nn.Sequential(
|
51 |
+
nn.Conv2d(256, 512, (3, 3), padding=1, stride=1, bias=False), # 3
|
52 |
+
nn.MaxPool2d(2, 2),
|
53 |
+
apply_normalization(512),
|
54 |
+
nn.ReLU(),
|
55 |
+
)
|
56 |
+
# R1 = ResBlock( (Conv-BN-ReLU-Conv-BN-ReLU))(X) [128k]
|
57 |
+
self.reslayer2 = nn.Sequential(
|
58 |
+
nn.Conv2d(512, 512, (3, 3), padding=1, stride=1, bias=False), # 3
|
59 |
+
apply_normalization(512),
|
60 |
+
nn.ReLU(),
|
61 |
+
nn.Conv2d(512, 512, (3, 3), padding=1, stride=1, bias=False), # 3
|
62 |
+
apply_normalization(512),
|
63 |
+
nn.ReLU(),
|
64 |
+
)
|
65 |
+
self.maxpool3 = nn.MaxPool2d(4, 2)
|
66 |
+
self.linear1 = nn.Linear(512,10)
|
67 |
+
|
68 |
+
def forward(self,x):
|
69 |
+
x = self.preplayer(x)
|
70 |
+
x1 = self.convlayer1(x)
|
71 |
+
x2 = self.reslayer1(x1)
|
72 |
+
x = x1+x2
|
73 |
+
x = self.convlayer2(x)
|
74 |
+
x = self.convlayer3(x)
|
75 |
+
x1 = self.reslayer2(x)
|
76 |
+
x = x+x1
|
77 |
+
x = self.maxpool3(x)
|
78 |
+
x = x.view(-1, 512)
|
79 |
+
x = self.linear1(x)
|
80 |
+
return F.log_softmax(x, dim=-1)
|
81 |
+
|
82 |
+
# Function to run inference and return top classes
|
83 |
+
def get_gradcam(model,input_img, opacity):
|
84 |
+
targets = None
|
85 |
+
inv_normalize = transforms.Normalize(
|
86 |
+
mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
|
87 |
+
std=[1/0.23, 1/0.23, 1/0.23]
|
88 |
+
)
|
89 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
90 |
+
transform = transforms.ToTensor()
|
91 |
+
input_img = transform(input_img)
|
92 |
+
input_img = input_img.to(device)
|
93 |
+
input_img = input_img.unsqueeze(0)
|
94 |
+
outputs = model(input_img)
|
95 |
+
_, prediction = torch.max(outputs, 1)
|
96 |
+
target_layers = [model.convlayer3[-2]]
|
97 |
+
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
|
98 |
+
grayscale_cam = cam(input_tensor=input_img, targets=targets)
|
99 |
+
grayscale_cam = grayscale_cam[0, :]
|
100 |
+
img = input_img.squeeze(0).to('cpu')
|
101 |
+
img = inv_normalize(img)
|
102 |
+
rgb_img = np.transpose(img, (1, 2, 0))
|
103 |
+
rgb_img = rgb_img.numpy()
|
104 |
+
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=opacity)
|
105 |
+
return visualization
|
106 |
+
|
107 |
+
|
108 |
+
def get_misclassified_images(show_misclassified,num):
|
109 |
+
if show_misclassified:
|
110 |
+
return cv2.imread(f"missclassified_images_examples/{int(num)}_missclassified.png")
|
111 |
+
else:
|
112 |
+
return None
|
113 |
+
|
114 |
+
|
115 |
+
def main_inference(num_of_output_classes,classes,model,input_img):
|
116 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
117 |
+
transform = transforms.ToTensor()
|
118 |
+
input_img = transform(input_img)
|
119 |
+
input_img = input_img.to(device)
|
120 |
+
input_img = input_img.unsqueeze(0)
|
121 |
+
softmax = torch.nn.Softmax(dim=0)
|
122 |
+
outputs = model(input_img)
|
123 |
+
out = softmax(outputs.flatten())
|
124 |
+
_, prediction = torch.max(outputs, 1)
|
125 |
+
confidences = {classes[i]:float(out[i]) for i in range(num_of_output_classes)}
|
126 |
+
outputs = model(input_img)
|
127 |
+
_, prediction = torch.max(outputs, 1)
|
128 |
+
return confidences
|
129 |
+
# def run_inference(input_img, num_of_output_classes,transparency):
|
130 |
+
# transform = transforms.ToTensor()
|
131 |
+
# input_img = transform(input_img)
|
132 |
+
# input_img = input_img.to(device)
|
133 |
+
# input_img = input_img.unsqueeze(0)
|
134 |
+
# softmax = torch.nn.Softmax(dim=0)
|
135 |
+
# outputs = model(input_img)
|
136 |
+
# out = softmax(outputs.flatten())
|
137 |
+
# _, prediction = torch.max(outputs, 1)
|
138 |
+
# confidences = {classes[i]:float(out[i]) for i in range(num_of_output_classes)}
|
139 |
+
# target_layers = [model.convlayer3[-2]]
|
140 |
+
|
141 |
+
# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
|
142 |
+
# grayscale_cam = cam(input_tensor=input_img, targets=targets)
|
143 |
+
# grayscale_cam = grayscale_cam[0, :]
|
144 |
+
# img = input_img.squeeze(0).to('cpu')
|
145 |
+
# img = inv_normalize(img)
|
146 |
+
# rgb_img = np.transpose(img, (1, 2, 0))
|
147 |
+
# rgb_img = rgb_img.numpy()
|
148 |
+
# visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
|
149 |
+
# return confidences, rgb_img, transparency,grayscale_cam
|
150 |
+
|
151 |
+
|