sagar007 commited on
Commit
3f461cc
1 Parent(s): 2ab51e7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision import transforms
4
+ import numpy as np
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from pytorch_grad_cam import GradCAM
8
+ from pytorch_grad_cam.utils.image import show_cam_on_image
9
+ from resnet import ResNet18
10
+
11
+ model = ResNet18()
12
+ model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
13
+
14
+ inv_normalize = transforms.Normalize(
15
+ mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
16
+ std=[1/0.23, 1/0.23, 1/0.23]
17
+ )
18
+
19
+ classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
20
+
21
+ def resize_image_pil(image, new_width, new_height):
22
+ img = Image.fromarray(np.array(image))
23
+ width, height = img.size
24
+ width_scale = new_width / width
25
+ height_scale = new_height / height
26
+ scale = min(width_scale, height_scale)
27
+ resized = img.resize((int(width*scale), int(height*scale)), Image.NEAREST)
28
+ resized = resized.crop((0, 0, new_width, new_height))
29
+ return resized
30
+
31
+ def inference(input_img, show_gradcam, num_gradcam, target_layer_number, opacity, show_misclassified, num_misclassified, num_top_classes):
32
+ input_img = resize_image_pil(input_img, 32, 32)
33
+ input_img = np.array(input_img)
34
+ org_img = input_img
35
+ input_img = input_img.reshape((32, 32, 3))
36
+ transform = transforms.ToTensor()
37
+ input_img = transform(input_img)
38
+ input_img = input_img.unsqueeze(0)
39
+
40
+ outputs = model(input_img)
41
+ softmax = torch.nn.Softmax(dim=1)
42
+ probs = softmax(outputs)
43
+
44
+ top_probs, top_labels = torch.topk(probs, k=min(num_top_classes, 10))
45
+ top_classes = [classes[idx] for idx in top_labels[0]]
46
+ confidences = {cls: float(prob) for cls, prob in zip(top_classes, top_probs[0])}
47
+
48
+ _, prediction = torch.max(outputs, 1)
49
+ predicted_class = classes[prediction[0].item()]
50
+
51
+ results = [predicted_class, confidences]
52
+
53
+ if show_gradcam:
54
+ target_layers = [model.layer2[target_layer_number]]
55
+ cam = GradCAM(model=model, target_layers=target_layers)
56
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
57
+ grayscale_cam = grayscale_cam[0, :]
58
+ visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=opacity)
59
+ results.append(visualization)
60
+
61
+ if show_misclassified:
62
+ # This part would require a dataset of images to check for misclassifications
63
+ # For demonstration, we'll just return a placeholder message
64
+ results.append("Misclassified images feature would be implemented here")
65
+
66
+ return results
67
+
68
+ def launch():
69
+ with gr.Blocks() as demo:
70
+ gr.Markdown("# CIFAR10 ResNet18 Model with GradCAM")
71
+
72
+ with gr.Row():
73
+ input_image = gr.Image(width=256, height=256, label="Input Image")
74
+ output_image = gr.Image(width=256, height=256, label="GradCAM Output")
75
+
76
+ with gr.Row():
77
+ prediction = gr.Textbox(label="Predicted Class")
78
+ confidences = gr.Label(label="Top Class Confidences")
79
+
80
+ with gr.Row():
81
+ show_gradcam = gr.Checkbox(label="Show GradCAM")
82
+ num_gradcam = gr.Slider(1, 5, value=1, step=1, label="Number of GradCAM images")
83
+ target_layer = gr.Slider(-2, -1, value=-2, step=1, label="Target Layer")
84
+ opacity = gr.Slider(0, 1, value=0.5, label="GradCAM Opacity")
85
+
86
+ with gr.Row():
87
+ show_misclassified = gr.Checkbox(label="Show Misclassified Images")
88
+ num_misclassified = gr.Slider(1, 10, value=5, step=1, label="Number of Misclassified Images")
89
+
90
+ num_top_classes = gr.Slider(1, 10, value=3, step=1, label="Number of Top Classes to Show")
91
+
92
+ submit_btn = gr.Button("Submit")
93
+
94
+ example_images = gr.Dataset(
95
+ components=[input_image],
96
+ samples=[["cat.jpg"], ["dog.jpg"], ["bird.jpg"], ["plane.jpg"], ["car.jpg"],
97
+ ["deer.jpg"], ["frog.jpg"], ["horse.jpg"], ["ship.jpg"], ["truck.jpg"]]
98
+ )
99
+
100
+ submit_btn.click(
101
+ inference,
102
+ inputs=[input_image, show_gradcam, num_gradcam, target_layer, opacity,
103
+ show_misclassified, num_misclassified, num_top_classes],
104
+ outputs=[prediction, confidences, output_image]
105
+ )
106
+
107
+ demo.launch()
108
+
109
+ if __name__ == "__main__":
110
+ launch()