bala1802 committed on
Commit
2812d44
1 Parent(s): 22fa207

Upload 5 files

Browse files
Files changed (3) hide show
  1. app.py +170 -4
  2. mini_resnet.py +89 -0
  3. model_weights/weights.pt +3 -0
app.py CHANGED
@@ -1,7 +1,173 @@
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+ from pathlib import Path
4
+ from random import shuffle
5
+
6
+ import cv2
7
  import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import torch
11
+ from mini_resnet import CustomResNet
12
+ from PIL import Image
13
+ from pytorch_grad_cam import GradCAM
14
+ from pytorch_grad_cam.utils.image import show_cam_on_image
15
+ from torchvision import transforms as T
16
+
17
# Per-channel normalisation constants (presumably the CIFAR-10 training-set
# statistics -- confirm against the training pipeline).
mean = (0.49139968, 0.48215841, 0.44653091)
std = (0.24703223, 0.24348513, 0.26158784)

# Preprocessing applied to every image before it is handed to the model.
transforms = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
)

# Human-readable name for each output index of the model.
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

# NOTE: normalises along dim 0, so it is only meaningful on a 1-D logits vector.
softmax = torch.nn.Softmax(dim=0)

# Build the network, restore the trained weights on the CPU and switch to
# evaluation mode.
model = CustomResNet()
_cpu = torch.device("cpu")
model.load_state_dict(torch.load("model_weights/weights.pt", map_location=_cpu))
model.eval()

# Every entry of this directory is treated as a sample image the model got wrong.
misclf_path = "images/miss_classified"
mis_classified_imgs = list(Path(misclf_path).glob("*"))
29
+
30
+
31
def get_traget_layer(block: str, layer: int):
    """Return the sub-module of ``model`` that Grad-CAM should hook.

    Args:
        block: ``"block1"``, ``"block2"`` or ``"block3"``; selects
            ``model.layer1``, ``model.layer2`` or ``model.layer3``.
        layer: ``0`` selects the first module of the chosen block; any other
            value selects the last one.

    Returns:
        The selected module.

    Raises:
        ValueError: If ``block`` is not one of the known block names.
    """
    attr_by_block = {"block1": "layer1", "block2": "layer2", "block3": "layer3"}
    if block not in attr_by_block:
        # Fail loudly here: silently returning None would only surface later,
        # inside GradCAM, as an unrelated-looking error.
        raise ValueError(
            f"Unknown block {block!r}; expected one of {sorted(attr_by_block)}"
        )

    layer_num = 0 if layer == 0 else -1
    return getattr(model, attr_by_block[block])[layer_num]


# CAM used for the user-supplied image: hooks the last module of block3.
default_cam = GradCAM(model=model, target_layers=[get_traget_layer("block3", -1)])
42
+
43
+
44
def make_image(p: Path | str, pred: str, label: str):
    """Render an image file with a ``"<pred> / <label>"`` title.

    The image is read with OpenCV, resized to 64x64, drawn with matplotlib
    (axes hidden), saved to an in-memory PNG and decoded back into an array.

    Args:
        p: Path of the image file to render.
        pred: Predicted class name shown before the slash.
        label: Ground-truth class name shown after the slash.

    Returns:
        The rendered figure as an array decoded by ``cv2.imdecode``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``p``.
    """
    im = cv2.imread(str(p))
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {p}")
    im = cv2.resize(im, (64, 64))

    # NOTE: cv2.imread yields BGR while imshow assumes RGB, and cv2.imdecode
    # below yields BGR again; the two swaps cancel out, so the returned array
    # ends up in RGB order. Keep both steps in sync if either one changes.

    # Draw on a dedicated figure and always close it: the implicit global
    # pyplot figure would otherwise keep accumulating artists across calls.
    fig, ax = plt.subplots()
    try:
        ax.imshow(im)
        ax.set_title(f"{pred} / {label}")
        ax.axis("off")

        with BytesIO() as buffer:
            fig.savefig(buffer, format="png")
            img_array = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    finally:
        plt.close(fig)

    # Decode the PNG bytes back into an image array using OpenCV.
    return cv2.imdecode(img_array, cv2.IMREAD_COLOR)
62
+
63
+
64
@torch.inference_mode()
def predict_img(img: torch.Tensor, top_k: int = 10):
    """Return the ``top_k`` most likely classes for a single image.

    Args:
        img: Preprocessed input batch holding one image; the model output is
            flattened before the softmax, so only a single image is supported.
        top_k: Number of highest-scoring classes to keep.

    Returns:
        A dict mapping class name to probability, ordered from most to least
        likely and holding at most ``top_k`` entries.
    """
    probs = softmax(model(img).flatten()).tolist()
    # Pair names with scores instead of indexing a hard-coded range(10).
    scores = dict(zip(classes, probs))
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    # UI sliders may deliver the value as a float; slicing needs an int.
    return dict(ranked[: int(top_k)])
74
+
75
+
76
def display_cam(cam: GradCAM, org_img: np.ndarray, img: torch.Tensor, transparency: float):
    """Overlay the Grad-CAM heat-map computed for ``img`` on ``org_img``.

    Args:
        cam: Grad-CAM object used to compute the heat-map.
        org_img: Image the heat-map is drawn on; it is divided by 255 before
            blending.
        img: Model input tensor the heat-map is computed for.
        transparency: Passed to ``show_cam_on_image`` as ``image_weight``.

    Returns:
        The blended visualisation produced by ``show_cam_on_image``.
    """
    # Only the first heat-map of the returned batch is used.
    heatmap = cam(input_tensor=img, targets=None)[0, :]
    scaled_img = org_img / 255
    return show_cam_on_image(
        scaled_img,
        heatmap,
        use_rgb=True,
        image_weight=transparency,
    )
83
+
84
+
85
def inference(
    org_img: np.ndarray,
    top_k: int,
    show_cam: bool,
    num_cam_imgs: int,
    cam_block: str,
    target_layer_num: int,
    transparency: float,
    show_misclf: bool,
    num_misclf: int,
):
    """Run the full demo pipeline for one request.

    Args:
        org_img: Image supplied by the user.
        top_k: Number of class probabilities to report for ``org_img``.
        show_cam: Whether to produce Grad-CAM overlays for sample images.
        num_cam_imgs: How many sample images to produce overlays for.
        cam_block: Block name passed to ``get_traget_layer``; ``None`` falls
            back to ``"block3"``.
        target_layer_num: Layer selector passed to ``get_traget_layer``.
        transparency: Image weight used when blending heat-map and image.
        show_misclf: Whether to render misclassified sample images.
        num_misclf: How many misclassified sample images to render.

    Returns:
        A tuple ``(overlay, preds, cam_outputs, misclf_images_output)``: the
        Grad-CAM overlay of ``org_img``, the top-k predictions, the list of
        sample overlays and the list of rendered misclassified images (both
        lists are empty when the corresponding option is off).
    """
    input_img = transforms(org_img).unsqueeze(0)

    preds = predict_img(input_img, top_k)
    org_img = display_cam(default_cam, org_img, input_img, transparency)

    # Shuffle a private copy so requests never mutate the shared module-level
    # list in place.
    candidates = list(mis_classified_imgs)
    shuffle(candidates)

    cam_outputs = []
    if show_cam:
        # The block selector can arrive as None when nothing was picked in the
        # UI; fall back to the block that ``default_cam`` uses.
        target_layers = [get_traget_layer(cam_block or "block3", target_layer_num)]
        # The context manager releases the hooks registered on the model.
        with GradCAM(model=model, target_layers=target_layers) as cam:
            for p in candidates[: int(num_cam_imgs)]:
                # Load as RGB, like the misclassified branch below, so the
                # model and the overlay see the channel order they expect.
                im = np.array(Image.open(p).convert("RGB"))
                inp_im = transforms(im).unsqueeze(0)
                cam_outputs.append(display_cam(cam, im, inp_im, transparency))

    misclf_images_output = []
    misclf_paths = candidates[: int(num_misclf)]
    # Guard against an empty sample directory: torch.stack rejects an empty list.
    if show_misclf and misclf_paths:
        batch = torch.stack(
            [transforms(Image.open(p).convert("RGB")) for p in misclf_paths]
        )
        with torch.no_grad():
            # Take the arg-max of the raw logits per image. The module-level
            # ``softmax`` normalises along dim 0, which on a batch would mix
            # scores across images and could change the predicted class.
            pred_indices = model(batch).argmax(dim=1).tolist()

        for p, idx in zip(misclf_paths, pred_indices):
            # File names start with the ground-truth label followed by "_".
            label = p.name.split("_")[0]
            misclf_images_output.append(make_image(p, classes[idx], label))

    return org_img, preds, cam_outputs, misclf_images_output
139
 
 
 
140
 
141
title = "CIFAR10 trained on Custom Model inspired by ResNet with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
# examples = [["cat.jpg", 0.5, -1], ["dog.jpg", 0.5, -1]]

# The order of ``inputs`` must match the positional parameters of
# ``inference``; the order of ``outputs`` must match its returned tuple.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(shape=(32, 32), label="Input Image"),
        gr.Slider(1, 10, value=3, step=1, label="Top K predictions"),
        gr.Checkbox(label="Show Grad Cam"),
        gr.Slider(1, 20, value=5, step=1, label="Number of images"),
        # Pre-select a block so ``inference`` never receives None when the
        # Grad-CAM option is ticked without choosing one.
        gr.Radio(
            label="Which Block?",
            choices=["block1", "block2", "block3"],
            value="block3",
        ),
        gr.Slider(0, 1, value=1, step=1, label="Which Layer?"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
        gr.Checkbox(label="Show Misclassified Images"),
        # step=1 keeps the default of 5 on the slider grid (a step of 5
        # starting from 1 can never land on 5).
        gr.Slider(1, 20, value=5, step=1, label="Number of Misclassification Images"),
    ],
    outputs=[
        gr.Image(shape=(32, 32), label="Output", width=128, height=128),
        "label",
        gr.Gallery(label="GradCAM Output"),
        gr.Gallery(
            label="Misclassified Images Pred/G.T.",
            columns=[2],
            rows=[2],
            object_fit="contain",
            height="auto",
        ),
    ],
    title=title,
    description=description,
    # examples=examples,
)
demo.launch(share=True)
mini_resnet.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # from common import BaseNet
6
+
7
+
8
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a skip connection.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the first convolution (spatial down-sampling).
        drop: Probability used by the ``Dropout2d`` applied after each ReLU.

    When ``stride == 1`` and ``in_planes == out_planes`` the skip connection is
    the identity and the block owns exactly the parameters ``conv1``, ``bn1``,
    ``conv2`` and ``bn2``. Otherwise a 1x1 conv + batch-norm projection is
    added on the skip path so that both branches have matching shapes.
    """

    def __init__(self, in_planes: int, out_planes: int, stride: int = 1, drop: float = 0) -> None:
        super().__init__()
        self.dropout = nn.Dropout2d(drop)

        self.conv1 = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_planes)

        # Only the first convolution down-samples; striding here as well would
        # shrink the main branch twice and break the residual addition.
        self.conv2 = nn.Conv2d(
            out_planes,
            out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_planes)

        if stride != 1 or in_planes != out_planes:
            # Project the input so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # Parameter-free, so existing checkpoints keep loading unchanged.
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and return the activated residual sum."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        out = F.relu(out)
        return self.dropout(out)
42
+
43
+
44
class CustomResNet(nn.Module):
    """Small ResNet-style classifier.

    The network is a prep convolution, three conv/max-pool stages (the first
    and third followed by a ``ResBlock``), a 4x4 max-pool and a 1x1
    convolution that produces ``num_classes`` channels. The three 2x2 pools
    and the final 4x4 pool reduce a 32x32 input to 1x1, so such inputs yield
    one logit vector per sample.

    Args:
        drop: Probability used by every ``Dropout2d`` layer.
        num_classes: Number of output classes.
    """

    def __init__(self, drop: float = 0, num_classes: int = 10) -> None:
        super().__init__()
        # Remembered so ``forward`` can reshape without hard-coding 10.
        self.num_classes = num_classes

        # Prep layer. The attribute name ``perlayer`` is part of the
        # state-dict keys, so it must not be renamed.
        self.perlayer = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(drop),
        )
        self.layer1 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1, bias=False),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(drop),
            ResBlock(128, 128, drop=drop),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1, bias=False),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout2d(drop),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1, bias=False),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Dropout2d(drop),
            ResBlock(512, 512, drop=drop),
        )
        self.pool = nn.MaxPool2d(4)
        self.out = nn.Conv2d(512, num_classes, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits for ``x`` reshaped to ``(-1, num_classes)``."""
        x = self.perlayer(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        x = self.out(x)

        # Use the configured class count: a hard-coded 10 would silently
        # produce a wrongly shaped result for any other ``num_classes``.
        return x.view(-1, self.num_classes)
model_weights/weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7bfb94e78ae17040a9dba004bdf9e3ba9633cf4bb730184cf6e487458747e3a2
3
+ size 26325330