mawady commited on
Commit
9c45667
1 Parent(s): dbacfc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -7,7 +7,7 @@ from alibi.explainers import IntegratedGradients
7
  from alibi.datasets import load_cats
8
  from alibi.utils.visualization import visualize_image_attr
9
  import numpy as np
10
- from PIL import Image
11
  import io
12
  import time
13
  import os
@@ -46,34 +46,42 @@ def do_process(img, baseline):
46
  lstPreds = decode_predictions(preds, top=3)[0]
47
  dctPreds = {lstPreds[i][1]: round(float(lstPreds[i][2]),2) for i in range(len(lstPreds))}
48
  predictions = preds.argmax(axis=1)
49
- if baseline is 'white':
50
  baselines = bls = np.ones(instance.shape).astype(instance.dtype)
51
- elif baseline is 'black':
 
52
  baselines = bls = np.zeros(instance.shape).astype(instance.dtype)
 
 
 
 
 
 
53
  else:
54
  baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
 
55
  explanation = ig.explain(instance,
56
  baselines=baselines,
57
  target=predictions)
58
  attrs = explanation.attributions[0]
59
- fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
60
  fig, ax = visualize_image_attr(attr=attrs.squeeze(), original_image=img, method='blended_heat_map',
61
- sign='all', show_colorbar=True, title=None,
62
- plt_fig_axis=(fig, ax), use_pyplot=False)
63
  fig.tight_layout()
64
  buf = io.BytesIO()
65
  fig.savefig(buf)
66
  buf.seek(0)
67
  img_res = Image.open(buf)
68
- return img_res, dctPreds
69
-
70
  input_im = gr.inputs.Image(shape=(224, 224), image_mode='RGB',
71
  invert_colors=False, source="upload",
72
  type="pil")
73
  input_drop = gr.inputs.Dropdown(label='Baseline (default: random)',
74
- choices=['black', 'white', 'random'], default='random', type='value')
75
 
76
  output_img = gr.outputs.Image(label='Output of Integrated Gradients', type='pil')
 
77
  output_label = gr.outputs.Label(label='Classification results', num_top_classes=3)
78
 
79
  title = "XAI - Integrated gradients"
@@ -83,7 +91,7 @@ article="<p style='text-align: center'><a href='https://github.com/mawady/colab-
83
  iface = gr.Interface(
84
  fn=do_process,
85
  inputs=[input_im, input_drop],
86
- outputs=[output_img,output_label],
87
  live=False,
88
  interpretation=None,
89
  title=title,
 
7
  from alibi.datasets import load_cats
8
  from alibi.utils.visualization import visualize_image_attr
9
  import numpy as np
10
+ from PIL import Image, ImageFilter
11
  import io
12
  import time
13
  import os
 
46
  lstPreds = decode_predictions(preds, top=3)[0]
47
  dctPreds = {lstPreds[i][1]: round(float(lstPreds[i][2]),2) for i in range(len(lstPreds))}
48
  predictions = preds.argmax(axis=1)
49
+ if baseline == 'white':
50
  baselines = bls = np.ones(instance.shape).astype(instance.dtype)
51
+ img_flt = Image.fromarray(np.uint8(np.squeeze(baselines)*255))
52
+ elif baseline == 'black':
53
  baselines = bls = np.zeros(instance.shape).astype(instance.dtype)
54
+ img_flt = Image.fromarray(np.uint8(np.squeeze(baselines)*255))
55
+ elif baseline == 'blur':
56
+ img_flt = img.filter(ImageFilter.GaussianBlur(5))
57
+ baselines = image.img_to_array(img_flt)
58
+ baselines = np.expand_dims(baselines, axis=0)
59
+ baselines = preprocess_input(baselines)
60
  else:
61
  baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
62
+ img_flt = Image.fromarray(np.uint8(np.squeeze(baselines)*255))
63
  explanation = ig.explain(instance,
64
  baselines=baselines,
65
  target=predictions)
66
  attrs = explanation.attributions[0]
 
67
  fig, ax = visualize_image_attr(attr=attrs.squeeze(), original_image=img, method='blended_heat_map',
68
+ sign='all', show_colorbar=True, title=baseline,
69
+ plt_fig_axis=None, use_pyplot=False)
70
  fig.tight_layout()
71
  buf = io.BytesIO()
72
  fig.savefig(buf)
73
  buf.seek(0)
74
  img_res = Image.open(buf)
75
+ return img_res, img_flt, dctPreds
76
+
77
  input_im = gr.inputs.Image(shape=(224, 224), image_mode='RGB',
78
  invert_colors=False, source="upload",
79
  type="pil")
80
  input_drop = gr.inputs.Dropdown(label='Baseline (default: random)',
81
+ choices=['random', 'black', 'white', 'blur'], default='random', type='value')
82
 
83
  output_img = gr.outputs.Image(label='Output of Integrated Gradients', type='pil')
84
+ output_base = gr.outputs.Image(label='Baseline image', type='pil')
85
  output_label = gr.outputs.Label(label='Classification results', num_top_classes=3)
86
 
87
  title = "XAI - Integrated gradients"
 
91
  iface = gr.Interface(
92
  fn=do_process,
93
  inputs=[input_im, input_drop],
94
+ outputs=[output_img,output_base,output_label]
95
  live=False,
96
  interpretation=None,
97
  title=title,