Navdeeppal Singh committed
Commit cbef243
Parent(s): 98e242b

feat: add app

Files changed:
- app.py (+184, -0)
- requirements.txt (+4, -0)
app.py (ADDED)
@@ -0,0 +1,184 @@
```python
from typing import Tuple, Union

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image

import bcos.models.pretrained as pretrained
from bcos.data.categories import IMAGENET_CATEGORIES

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model(model_name):
    model = getattr(pretrained, model_name)(pretrained=True)
    model = model.to(device)
    model.eval()
    return model


MODEL_NAMES = pretrained.list_available()


class NormalizationMode:
    # this is normalization for the explanations!
    INDIVIDUAL = "individual"
    WRT_PREDICTION = "wrt prediction's confidence"
    INDIVIDUAL_X_CONFIDENCE = "individual×confidence"

    @classmethod
    def all(cls):
        return [cls.WRT_PREDICTION, cls.INDIVIDUAL_X_CONFIDENCE, cls.INDIVIDUAL]


def freeze(model):
    for param in model.parameters():
        param.requires_grad = False


def run(
    model_name: str,
    input_image: Image,
    do_resize: bool,
    do_center_crop: bool,
    normalization_mode: str,
    smooth: int,
    alpha_percentile: Union[int, float],
    plot_dpi: int,
    topk: int = 5,
) -> Tuple[dict, plt.Figure]:
    # cleanup previous stuff
    plt.close("all")
    torch.cuda.empty_cache()

    # preprocess - get model and transform input image
    model = get_model(model_name)
    freeze(model)
    x = model.transform.transform_with_options(
        input_image,
        center_crop=do_center_crop,
        resize=do_resize,
    )
    x = x.unsqueeze(0).to(device).requires_grad_()

    # predict and explain
    with model.explanation_mode():
        out = model(x)

        topk_values, topk_preds = torch.topk(out, topk, dim=1)
        topk_values, topk_preds = topk_values[0], topk_preds[0]

        dynamic_weights = []  # list of grad tensors of shape (C, H, W)
        for i in range(topk):
            topk_values[i].backward(inputs=[x], retain_graph=i < topk - 1)
            dynamic_weights.append(
                x.grad.detach().cpu()[0],
            )
            x.grad = None  # reset

    # prepare output labels+confidences
    topk_probabilities = (
        model.to_probabilities(out.detach()).topk(topk, dim=1).values[0].cpu()
    )
    confidences = {
        IMAGENET_CATEGORIES[i]: v.item() for i, v in zip(topk_preds, topk_probabilities)
    }

    # output plot of images
    output_fig, axs = plt.subplots(
        1, topk + 1, dpi=plot_dpi, figsize=((topk + 1) * 2.1, 2)
    )

    # visualize input image
    x = x.detach().cpu()[0]
    axs[0].imshow(x[:3].permute(1, 2, 0).numpy())
    axs[0].set_xlabel("Input Image")

    # visualize explanations
    pred_confidence = topk_probabilities[0]  # first one is pred
    for i, ax in enumerate(axs[1:]):
        expl = model.gradient_to_image(
            x,
            dynamic_weights[i],
            smooth=smooth,
            alpha_percentile=alpha_percentile,
        )

        if normalization_mode == NormalizationMode.INDIVIDUAL_X_CONFIDENCE:
            expl[:, :, -1] *= topk_probabilities[i].item()
        elif normalization_mode == NormalizationMode.WRT_PREDICTION and i > 0:
            expl[:, :, -1] *= (topk_probabilities[i] / pred_confidence).item()
        else:  # NormalizationMode.INDIVIDUAL
            pass

        ax.imshow(expl)
        ax.set_xlabel(IMAGENET_CATEGORIES[topk_preds[i]])

    for ax in axs:
        ax.set_xticks([])
        ax.set_yticks([])

    output_fig.tight_layout()

    return confidences, output_fig


with gr.Blocks() as demo:
    # basic info
    gr.Markdown(
        """# B-cos Explanation Generation Demo
        [Repository](https://github.com/B-cos/B-cos-v2/)
        """
    )

    with gr.Row():
        selected_model = gr.Dropdown(
            MODEL_NAMES, value="densenet121_long", label="Select model"
        )

    with gr.Accordion("Options", open=False):
        do_resize = gr.Checkbox(
            label="Resize input image's shorter side to 256", value=True
        )
        do_center_crop = gr.Checkbox(
            label="Center crop input image to 224x224", value=False
        )
        normalization_mode = gr.Radio(
            NormalizationMode.all(),
            value=NormalizationMode.WRT_PREDICTION,
            label="Normalization Mode",
        )

        smooth = gr.Slider(1, 51, value=15, step=2, label="Smoothing kernel size")
        alpha_percentile = gr.Number(value=99.99, label="Percentile")
        plot_dpi = gr.Number(value=100, label="Plot DPI")

    input_image = gr.Image(type="pil", label="Image")
    run_button = gr.Button("Predict and Explain", variant="primary")

    # will contain all outputs in a plot
    output = gr.Plot(label="Explanations")
    # labels
    output_labels = gr.Label(label="Top-5 Predictions")

    run_button.click(
        fn=run,
        inputs=[
            selected_model,
            input_image,
            do_resize,
            do_center_crop,
            normalization_mode,
            smooth,
            alpha_percentile,
            plot_dpi,
        ],
        outputs=[output_labels, output],
        scroll_to_output=True,
    )


demo.launch(
    queue=True,
)
```
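For reference, the explanation path in `run()` can be exercised outside Gradio. Below is a minimal sketch, not part of the commit: it assumes the `bcos` package from the linked repository is installed and that `img` is an RGB PIL image, and it reuses only calls that appear in app.py (`explanation_mode`, `transform_with_options`, `gradient_to_image`), computing a single explanation for the top-1 class.

```python
# Minimal sketch of the core explanation loop from run(), for one class.
# Assumptions: bcos is installed, `img` is an RGB PIL.Image, and
# "densenet121_long" (the app's default) is among the pretrained models.
import torch
import bcos.models.pretrained as pretrained

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = getattr(pretrained, "densenet121_long")(pretrained=True).to(device).eval()
for p in model.parameters():
    p.requires_grad = False  # only the input needs gradients

x = model.transform.transform_with_options(img, center_crop=False, resize=True)
x = x.unsqueeze(0).to(device).requires_grad_()

with model.explanation_mode():
    out = model(x)
    out[0, out.argmax()].backward(inputs=[x])  # dynamic weights land in x.grad

# Render the input-space contribution map as an RGBA image.
expl = model.gradient_to_image(x.detach().cpu()[0], x.grad.detach().cpu()[0])
```

`run()` generalizes this to the top-k classes by calling `backward()` once per class logit, retaining the graph for all but the last pass and clearing `x.grad` in between.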
requirements.txt (ADDED)
@@ -0,0 +1,4 @@
```
bcos
einops
torch>=1.13
torchvision
```
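Note that `gradio` itself is not pinned here; on Hugging Face Spaces it is provided by the Space's SDK configuration. `bcos` supplies the pretrained models, transforms, and `IMAGENET_CATEGORIES` used by app.py, `torch>=1.13` sets the minimum PyTorch version, and `einops` is presumably required by some of the model implementations. A local run should only need `pip install -r requirements.txt` (plus `gradio`) followed by `python app.py`.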