Navdeeppal Singh
commited on
Commit
•
5c65491
1
Parent(s):
852655c
feat(gradio): improvements + examples
Browse files
app.py
CHANGED
@@ -40,13 +40,13 @@ def freeze(model):
|
|
40 |
def run(
|
41 |
model_name: str,
|
42 |
input_image: Image,
|
43 |
-
do_resize: bool,
|
44 |
-
do_center_crop: bool,
|
45 |
-
normalization_mode: str,
|
46 |
-
smooth: int,
|
47 |
-
alpha_percentile: Union[int, float],
|
48 |
-
plot_dpi: int,
|
49 |
-
topk: int =
|
50 |
) -> Tuple[dict, plt.Figure]:
|
51 |
# cleanup previous stuff
|
52 |
plt.close("all")
|
@@ -128,7 +128,9 @@ with gr.Blocks() as demo:
|
|
128 |
# basic info
|
129 |
gr.Markdown(
|
130 |
"""# B-cos Explanation Generation Demo
|
131 |
-
|
|
|
|
|
132 |
"""
|
133 |
)
|
134 |
|
@@ -147,7 +149,7 @@ with gr.Blocks() as demo:
|
|
147 |
normalization_mode = gr.Radio(
|
148 |
NormalizationMode.all(),
|
149 |
value=NormalizationMode.WRT_PREDICTION,
|
150 |
-
label="Normalization Mode",
|
151 |
)
|
152 |
|
153 |
smooth = gr.Slider(1, 51, value=15, step=2, label="Smoothing kernel size")
|
@@ -178,6 +180,44 @@ with gr.Blocks() as demo:
|
|
178 |
scroll_to_output=True,
|
179 |
)
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
demo.launch()
|
183 |
|
|
|
40 |
def run(
|
41 |
model_name: str,
|
42 |
input_image: Image,
|
43 |
+
do_resize: bool = True,
|
44 |
+
do_center_crop: bool = False,
|
45 |
+
normalization_mode: str = NormalizationMode.WRT_PREDICTION,
|
46 |
+
smooth: int = 15,
|
47 |
+
alpha_percentile: Union[int, float] = 99.99,
|
48 |
+
plot_dpi: int = 120,
|
49 |
+
topk: int = 3,
|
50 |
) -> Tuple[dict, plt.Figure]:
|
51 |
# cleanup previous stuff
|
52 |
plt.close("all")
|
|
|
128 |
# basic info
|
129 |
gr.Markdown(
|
130 |
"""# B-cos Explanation Generation Demo
|
131 |
+
This demo generates explanations for images using the B-cos models.
|
132 |
+
|
133 |
+
GitHub: [link](https://github.com/B-cos/B-cos-v2/)
|
134 |
"""
|
135 |
)
|
136 |
|
|
|
149 |
normalization_mode = gr.Radio(
|
150 |
NormalizationMode.all(),
|
151 |
value=NormalizationMode.WRT_PREDICTION,
|
152 |
+
label="Explanation Normalization Mode",
|
153 |
)
|
154 |
|
155 |
smooth = gr.Slider(1, 51, value=15, step=2, label="Smoothing kernel size")
|
|
|
180 |
scroll_to_output=True,
|
181 |
)
|
182 |
|
183 |
+
gr.Examples(
|
184 |
+
fn=run,
|
185 |
+
examples=[
|
186 |
+
[
|
187 |
+
"resnet50",
|
188 |
+
"./examples/polizeifahrzeug-zebra.png",
|
189 |
+
True,
|
190 |
+
False,
|
191 |
+
NormalizationMode.WRT_PREDICTION,
|
192 |
+
15,
|
193 |
+
99.99,
|
194 |
+
120,
|
195 |
+
],
|
196 |
+
[
|
197 |
+
"resnet50",
|
198 |
+
"./examples/cat-dog.png",
|
199 |
+
True,
|
200 |
+
False,
|
201 |
+
NormalizationMode.WRT_PREDICTION,
|
202 |
+
15,
|
203 |
+
99.99,
|
204 |
+
120,
|
205 |
+
]
|
206 |
+
],
|
207 |
+
inputs=[
|
208 |
+
selected_model,
|
209 |
+
input_image,
|
210 |
+
do_resize,
|
211 |
+
do_center_crop,
|
212 |
+
normalization_mode,
|
213 |
+
smooth,
|
214 |
+
alpha_percentile,
|
215 |
+
plot_dpi,
|
216 |
+
],
|
217 |
+
outputs=[output_labels, output],
|
218 |
+
cache_examples=True,
|
219 |
+
)
|
220 |
+
|
221 |
|
222 |
demo.launch()
|
223 |
|