Navdeeppal Singh committed on
Commit cbef243
1 Parent(s): 98e242b

feat: add app

Files changed (2)
  1. app.py +184 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,184 @@
+ from typing import Tuple, Union
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import torch
+ from PIL import Image
+
+ import bcos.models.pretrained as pretrained
+ from bcos.data.categories import IMAGENET_CATEGORIES
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def get_model(model_name):
+     model = getattr(pretrained, model_name)(pretrained=True)
+     model = model.to(device)
+     model.eval()
+     return model
+
+
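+ # every pretrained B-cos model name exposed by the bcos package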
+ MODEL_NAMES = pretrained.list_available()
+
+
+ class NormalizationMode:
+     # how the alpha (opacity) channel of each explanation is scaled in run()
+     INDIVIDUAL = "individual"
+     WRT_PREDICTION = "wrt prediction's confidence"
+     INDIVIDUAL_X_CONFIDENCE = "individual×confidence"
+
+     @classmethod
+     def all(cls):
+         return [cls.WRT_PREDICTION, cls.INDIVIDUAL_X_CONFIDENCE, cls.INDIVIDUAL]
+
+
+ def freeze(model):
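+     # the explanations only need gradients w.r.t. the input, not the weights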
+     for param in model.parameters():
+         param.requires_grad = False
+
+
+ def run(
+     model_name: str,
+     input_image: Image.Image,
+     do_resize: bool,
+     do_center_crop: bool,
+     normalization_mode: str,
+     smooth: int,
+     alpha_percentile: Union[int, float],
+     plot_dpi: int,
+     topk: int = 5,
+ ) -> Tuple[dict, plt.Figure]:
+     # cleanup previous stuff
+     plt.close("all")
+     torch.cuda.empty_cache()
+
+     # preprocess - get model and transform input image
+     model = get_model(model_name)
+     freeze(model)
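+     # the bcos transform encodes the image as 6 channels [r, g, b, 1-r, 1-g, 1-b];
+     # gradients w.r.t. this input are what the explanations are built from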
+     x = model.transform.transform_with_options(
+         input_image,
+         center_crop=do_center_crop,
+         resize=do_resize,
+     )
+     x = x.unsqueeze(0).to(device).requires_grad_()
+
+     # predict and explain
+     with model.explanation_mode():
+         out = model(x)
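+         # in explanation mode the input gradient of a logit equals the model's
+         # dynamic linear weights for that class, i.e. its B-cos contribution map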
+
+         topk_values, topk_preds = torch.topk(out, topk, dim=1)
+         topk_values, topk_preds = topk_values[0], topk_preds[0]
+
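+         # one backward pass per top-k class; retain_graph keeps the graph
+         # alive for every pass except the last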
+         dynamic_weights = []  # list of grad tensors of shape (C, H, W)
+         for i in range(topk):
+             topk_values[i].backward(inputs=[x], retain_graph=i < topk - 1)
+             dynamic_weights.append(
+                 x.grad.detach().cpu()[0],
+             )
+             x.grad = None  # reset
+
+     # prepare output labels+confidences
+     topk_probabilities = (
+         model.to_probabilities(out.detach()).topk(topk, dim=1).values[0].cpu()
+     )
+     confidences = {
+         IMAGENET_CATEGORIES[i]: v.item() for i, v in zip(topk_preds, topk_probabilities)
+     }
+
+     # output plot of images
+     output_fig, axs = plt.subplots(
+         1, topk + 1, dpi=plot_dpi, figsize=((topk + 1) * 2.1, 2)
+     )
+
+     # visualize input image
+     x = x.detach().cpu()[0]
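+     # show only the first 3 of the 6 encoded channels, i.e. the original RGB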
+     axs[0].imshow(x[:3].permute(1, 2, 0).numpy())
+     axs[0].set_xlabel("Input Image")
+
+     # visualize explanations
+     pred_confidence = topk_probabilities[0]  # first one is pred
+     for i, ax in enumerate(axs[1:]):
+         expl = model.gradient_to_image(
+             x,
+             dynamic_weights[i],
+             smooth=smooth,
+             alpha_percentile=alpha_percentile,
+         )
+
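+         # scale the explanation's alpha channel according to the chosen mode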
+         if normalization_mode == NormalizationMode.INDIVIDUAL_X_CONFIDENCE:
+             expl[:, :, -1] *= topk_probabilities[i].item()
+         elif normalization_mode == NormalizationMode.WRT_PREDICTION and i > 0:
+             expl[:, :, -1] *= (topk_probabilities[i] / pred_confidence).item()
+         else:  # NormalizationMode.INDIVIDUAL
+             pass
+
+         ax.imshow(expl)
+         ax.set_xlabel(IMAGENET_CATEGORIES[topk_preds[i]])
+
+     for ax in axs:
+         ax.set_xticks([])
+         ax.set_yticks([])
+
+     output_fig.tight_layout()
+
+     return confidences, output_fig
+
+
+ with gr.Blocks() as demo:
+     # basic info
+     gr.Markdown(
+         """# B-cos Explanation Generation Demo
+ [Repository](https://github.com/B-cos/B-cos-v2/)
+ """
+     )
+
+     with gr.Row():
+         selected_model = gr.Dropdown(
+             MODEL_NAMES, value="densenet121_long", label="Select model"
+         )
+
+     with gr.Accordion("Options", open=False):
+         do_resize = gr.Checkbox(
+             label="Resize input image's shorter side to 256", value=True
+         )
+         do_center_crop = gr.Checkbox(
+             label="Center crop input image to 224x224", value=False
+         )
+         normalization_mode = gr.Radio(
+             NormalizationMode.all(),
+             value=NormalizationMode.WRT_PREDICTION,
+             label="Normalization Mode",
+         )
+
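+         # step=2 keeps the smoothing kernel size odd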
+         smooth = gr.Slider(1, 51, value=15, step=2, label="Smoothing kernel size")
+         alpha_percentile = gr.Number(value=99.99, label="Percentile")
+         plot_dpi = gr.Number(value=100, label="Plot DPI")
+
+     input_image = gr.Image(type="pil", label="Image")
+     run_button = gr.Button("Predict and Explain", variant="primary")
+
+     # will contain all outputs in a plot
+     output = gr.Plot(label="Explanations")
+     # top-5 labels with confidences
+     output_labels = gr.Label(label="Top-5 Predictions")
+
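+     # wire the button: inputs are passed to run() in signature order
+     # (topk keeps its default of 5)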
+     run_button.click(
+         fn=run,
+         inputs=[
+             selected_model,
+             input_image,
+             do_resize,
+             do_center_crop,
+             normalization_mode,
+             smooth,
+             alpha_percentile,
+             plot_dpi,
+         ],
+         outputs=[output_labels, output],
+         scroll_to_output=True,
+     )
+
+
+ demo.launch(
+     queue=True,
+ )
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ bcos
+ einops
+ torch>=1.13
+ torchvision
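
A minimal sketch for sanity-checking run() outside the Gradio UI (run alongside the definitions above, e.g. in a notebook; "cat.jpg" is a hypothetical local image, and the chosen checkpoint is downloaded on first use):

    from PIL import Image

    confidences, fig = run(
        model_name="densenet121_long",
        input_image=Image.open("cat.jpg").convert("RGB"),
        do_resize=True,
        do_center_crop=False,
        normalization_mode=NormalizationMode.WRT_PREDICTION,
        smooth=15,
        alpha_percentile=99.99,
        plot_dpi=100,
    )
    fig.savefig("explanations.png")  # input image + top-5 contribution maps
    print(confidences)  # {class name: probability}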