Navdeeppal Singh commited on
Commit
5c65491
1 Parent(s): 852655c

feat(gradio): improvements + examples

Browse files
Files changed (1) hide show
  1. app.py +49 -9
app.py CHANGED
@@ -40,13 +40,13 @@ def freeze(model):
40
  def run(
41
  model_name: str,
42
  input_image: Image,
43
- do_resize: bool,
44
- do_center_crop: bool,
45
- normalization_mode: str,
46
- smooth: int,
47
- alpha_percentile: Union[int, float],
48
- plot_dpi: int,
49
- topk: int = 5,
50
  ) -> Tuple[dict, plt.Figure]:
51
  # cleanup previous stuff
52
  plt.close("all")
@@ -128,7 +128,9 @@ with gr.Blocks() as demo:
128
  # basic info
129
  gr.Markdown(
130
  """# B-cos Explanation Generation Demo
131
- [Repository](https://github.com/B-cos/B-cos-v2/)
 
 
132
  """
133
  )
134
 
@@ -147,7 +149,7 @@ with gr.Blocks() as demo:
147
  normalization_mode = gr.Radio(
148
  NormalizationMode.all(),
149
  value=NormalizationMode.WRT_PREDICTION,
150
- label="Normalization Mode",
151
  )
152
 
153
  smooth = gr.Slider(1, 51, value=15, step=2, label="Smoothing kernel size")
@@ -178,6 +180,44 @@ with gr.Blocks() as demo:
178
  scroll_to_output=True,
179
  )
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  demo.launch()
183
 
 
40
  def run(
41
  model_name: str,
42
  input_image: Image,
43
+ do_resize: bool = True,
44
+ do_center_crop: bool = False,
45
+ normalization_mode: str = NormalizationMode.WRT_PREDICTION,
46
+ smooth: int = 15,
47
+ alpha_percentile: Union[int, float] = 99.99,
48
+ plot_dpi: int = 120,
49
+ topk: int = 3,
50
  ) -> Tuple[dict, plt.Figure]:
51
  # cleanup previous stuff
52
  plt.close("all")
 
128
  # basic info
129
  gr.Markdown(
130
  """# B-cos Explanation Generation Demo
131
+ This demo generates explanations for images using the B-cos models.
132
+
133
+ GitHub: [link](https://github.com/B-cos/B-cos-v2/)
134
  """
135
  )
136
 
 
149
  normalization_mode = gr.Radio(
150
  NormalizationMode.all(),
151
  value=NormalizationMode.WRT_PREDICTION,
152
+ label="Explanation Normalization Mode",
153
  )
154
 
155
  smooth = gr.Slider(1, 51, value=15, step=2, label="Smoothing kernel size")
 
180
  scroll_to_output=True,
181
  )
182
 
183
+ gr.Examples(
184
+ fn=run,
185
+ examples=[
186
+ [
187
+ "resnet50",
188
+ "./examples/polizeifahrzeug-zebra.png",
189
+ True,
190
+ False,
191
+ NormalizationMode.WRT_PREDICTION,
192
+ 15,
193
+ 99.99,
194
+ 120,
195
+ ],
196
+ [
197
+ "resnet50",
198
+ "./examples/cat-dog.png",
199
+ True,
200
+ False,
201
+ NormalizationMode.WRT_PREDICTION,
202
+ 15,
203
+ 99.99,
204
+ 120,
205
+ ]
206
+ ],
207
+ inputs=[
208
+ selected_model,
209
+ input_image,
210
+ do_resize,
211
+ do_center_crop,
212
+ normalization_mode,
213
+ smooth,
214
+ alpha_percentile,
215
+ plot_dpi,
216
+ ],
217
+ outputs=[output_labels, output],
218
+ cache_examples=True,
219
+ )
220
+
221
 
222
  demo.launch()
223