Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from models.david_page import DavidPageNet
|
9 |
+
from PIL import Image
|
10 |
+
from pytorch_grad_cam import GradCAM
|
11 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
12 |
+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
13 |
+
from torchvision import transforms
|
14 |
+
|
15 |
+
|
16 |
+
# imagenet mean and std
|
17 |
+
mean = [0.485, 0.456, 0.406]
|
18 |
+
std = [0.229, 0.224, 0.225]
|
19 |
+
|
20 |
+
inv_mean = [-mean / std for mean, std in zip(mean, std)]
|
21 |
+
inv_std = [1 / s for s in std]
|
22 |
+
|
23 |
+
# transforms
|
24 |
+
transform = transforms.Compose([
|
25 |
+
transforms.ToTensor(),
|
26 |
+
transforms.Normalize(mean=mean, std=std),
|
27 |
+
|
28 |
+
]
|
29 |
+
)
|
30 |
+
|
31 |
+
inv_normalize = transforms.Normalize(mean=inv_mean, std=inv_std)
|
32 |
+
|
33 |
+
classes = [
|
34 |
+
"plane",
|
35 |
+
"car",
|
36 |
+
"bird",
|
37 |
+
"cat",
|
38 |
+
"deer",
|
39 |
+
"dog",
|
40 |
+
"frog",
|
41 |
+
"horse",
|
42 |
+
"ship",
|
43 |
+
"truck",
|
44 |
+
]
|
45 |
+
|
46 |
+
|
47 |
+
class Gradio:
|
48 |
+
def __init__(self, model_path: str):
|
49 |
+
use_cuda = torch.cuda.is_available()
|
50 |
+
self.device = torch.device("cuda" if use_cuda else "cpu")
|
51 |
+
self.model = self.load_model(model_path)
|
52 |
+
self.temperature = 2
|
53 |
+
|
54 |
+
def load_model(self, model_path: str):
|
55 |
+
model = DavidPageNet().to(self.device)
|
56 |
+
|
57 |
+
if os.path.isfile(model_path):
|
58 |
+
model.load_state_dict(
|
59 |
+
torch.load(model_path)["model_state_dict"], strict=False
|
60 |
+
)
|
61 |
+
|
62 |
+
return model
|
63 |
+
|
64 |
+
def cam(
|
65 |
+
self,
|
66 |
+
input_tensor: torch.Tensor,
|
67 |
+
target_class_id: int,
|
68 |
+
layer_nums: List,
|
69 |
+
transparency: float = 0.7,
|
70 |
+
):
|
71 |
+
targets = [ClassifierOutputTarget(target_class_id)]
|
72 |
+
target_layers = [getattr(self.model, f"block{layer-1}") for layer in layer_nums]
|
73 |
+
|
74 |
+
with GradCAM(
|
75 |
+
model=self.model,
|
76 |
+
target_layers=target_layers,
|
77 |
+
use_cuda=self.device == torch.device("cuda"),
|
78 |
+
) as cam:
|
79 |
+
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
|
80 |
+
grayscale_cam = grayscale_cam[0, :]
|
81 |
+
|
82 |
+
img = inv_normalize(input_tensor)
|
83 |
+
rgb_img = img[0].permute(1, 2, 0).cpu().numpy()
|
84 |
+
|
85 |
+
visualization = show_cam_on_image(
|
86 |
+
rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency
|
87 |
+
)
|
88 |
+
return visualization
|
89 |
+
|
90 |
+
def inference(
|
91 |
+
self,
|
92 |
+
input_img: np.array,
|
93 |
+
transparency: float,
|
94 |
+
ntop_classes: int,
|
95 |
+
layer_nums: List,
|
96 |
+
cam_for_class: str,
|
97 |
+
):
|
98 |
+
self.model.eval()
|
99 |
+
input_img = transform(input_img)
|
100 |
+
|
101 |
+
input_img = input_img.to(self.device)
|
102 |
+
input_img = input_img.unsqueeze(0)
|
103 |
+
|
104 |
+
with torch.no_grad():
|
105 |
+
outputs = self.model(input_img).squeeze(0)
|
106 |
+
outputs = F.softmax(outputs / self.temperature, dim=-1)
|
107 |
+
|
108 |
+
probability, prediction = torch.sort(outputs, descending=True)
|
109 |
+
prediction = list(zip(prediction.tolist(), probability.tolist()))
|
110 |
+
|
111 |
+
class_id = (
|
112 |
+
prediction[0][0]
|
113 |
+
if cam_for_class in ["default", ""]
|
114 |
+
else classes.index(cam_for_class)
|
115 |
+
)
|
116 |
+
visualization = self.cam(
|
117 |
+
input_tensor=input_img,
|
118 |
+
target_class_id=class_id,
|
119 |
+
layer_nums=layer_nums,
|
120 |
+
transparency=transparency,
|
121 |
+
)
|
122 |
+
top_nclass_result = [
|
123 |
+
(classes[class_id], round(score, 2))
|
124 |
+
for class_id, score in prediction[:ntop_classes]
|
125 |
+
]
|
126 |
+
return visualization, dict(top_nclass_result)
|
127 |
+
|
128 |
+
|
129 |
+
method = Gradio(model_path="./model.pt")
|
130 |
+
demo = gr.Interface(
|
131 |
+
method.inference,
|
132 |
+
[
|
133 |
+
gr.Image(shape=(32, 32), label="Input Image", value="./samples/dog_cat.jpeg"),
|
134 |
+
gr.Slider(
|
135 |
+
minimum=0,
|
136 |
+
maximum=1,
|
137 |
+
value=0.5,
|
138 |
+
label="Transparency",
|
139 |
+
info="Transparency of the CAM-Attention Output",
|
140 |
+
),
|
141 |
+
gr.Slider(
|
142 |
+
minimum=1,
|
143 |
+
maximum=10,
|
144 |
+
step=1,
|
145 |
+
value=2,
|
146 |
+
label="Top Classes",
|
147 |
+
info="Number of Top Predicted Classes",
|
148 |
+
),
|
149 |
+
gr.CheckboxGroup(
|
150 |
+
choices=[1, 2, 3, 4],
|
151 |
+
value=[3, 4],
|
152 |
+
label="Network Layers",
|
153 |
+
info="Network Layers for CAM-Attention Extraction",
|
154 |
+
),
|
155 |
+
gr.Dropdown(
|
156 |
+
choices=["default"] + classes,
|
157 |
+
multiselect=False,
|
158 |
+
value="default",
|
159 |
+
label="Class Activation Map (CAM) Focus Visualization",
|
160 |
+
info="This section showcases the specific region of interest within the input image that the Class Activation Map (CAM) algorithm emphasizes to make predictions based on the selected class from the dropdown menu. The 'default' value serves as the default choice, representing the top class predicted by the model.",
|
161 |
+
),
|
162 |
+
],
|
163 |
+
[
|
164 |
+
gr.Image(shape=(32, 32)).style(width=128, height=128),
|
165 |
+
gr.Label(label="Top Classes"),
|
166 |
+
],
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
demo.launch()
|