andy-wyx committed
Commit dd58475 · 1 Parent(s): 0c61c42

show more xai output

Files changed (3)
  1. app.py +38 -11
  2. closest_sample.py +0 -1
  3. explanations.py +35 -30
app.py CHANGED
@@ -121,11 +121,14 @@ def explain_image(input_image,model_name):
     else:
         size = 600
     #saliency, integrated, smoothgrad,
-    rise,avg = explain(model,input_image,size = size, n_classes=n_classes)
+    exp_list = explain(model,input_image,size = size, n_classes=n_classes)
     #original = saliency + integrated + smoothgrad
     print('done')
-    rise1,rise2,rise3,rise4,rise5,avg = rise[0],rise[1],rise[2],rise[3],rise[4],avg[0]
-    return rise1,rise2,rise3,rise4,rise5,avg
+    sobol1,sobol2,sobol3,sobol4,sobol5 = exp_list[0],exp_list[1],exp_list[2],exp_list[3],exp_list[4]
+    rise1,rise2,rise3,rise4,rise5 = exp_list[5],exp_list[6],exp_list[7],exp_list[8],exp_list[9]
+    hsic1,hsic2,hsic3,hsic4,hsic5 = exp_list[10],exp_list[11],exp_list[12],exp_list[13],exp_list[14]
+    saliency1,saliency2,saliency3,saliency4,saliency5 = exp_list[15],exp_list[16],exp_list[17],exp_list[18],exp_list[19]
+    return sobol1,sobol2,sobol3,sobol4,sobol5,rise1,rise2,rise3,rise4,rise5,hsic1,hsic2,hsic3,hsic4,hsic5,saliency1,saliency2,saliency3,saliency4,saliency5
 
 #minimalist theme
 with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
@@ -173,18 +176,42 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
 
     with gr.Accordion("Explanations "):
         gr.Markdown("Computing Explanations from the model")
-        with gr.Row():
+        with gr.Column():
+            with gr.Row():
+
         #original_input = gr.Image(label="Original Frame")
         #saliency = gr.Image(label="saliency")
         #gradcam = gr.Image(label='integraged gradients')
         #guided_gradcam = gr.Image(label='gradcam')
         #guided_backprop = gr.Image(label='guided backprop')
-            rise1 = gr.Image(label = 'Rise1')
-            rise2 = gr.Image(label = 'Rise2')
-            rise3 = gr.Image(label = 'Rise3')
-            rise4 = gr.Image(label = 'Rise4')
-            rise5 = gr.Image(label = 'Rise5')
-            avg = gr.Image(label = 'Avg')
+                sobol1 = gr.Image(label = 'Sobol1')
+                sobol2 = gr.Image(label = 'Sobol2')
+                sobol3 = gr.Image(label = 'Sobol3')
+                sobol4 = gr.Image(label = 'Sobol4')
+                sobol5 = gr.Image(label = 'Sobol5')
+
+            with gr.Row():
+                rise1 = gr.Image(label = 'Rise1')
+                rise2 = gr.Image(label = 'Rise2')
+                rise3 = gr.Image(label = 'Rise3')
+                rise4 = gr.Image(label = 'Rise4')
+                rise5 = gr.Image(label = 'Rise5')
+
+            with gr.Row():
+                hsic1 = gr.Image(label = 'HSIC1')
+                hsic2 = gr.Image(label = 'HSIC2')
+                hsic3 = gr.Image(label = 'HSIC3')
+                hsic4 = gr.Image(label = 'HSIC4')
+                hsic5 = gr.Image(label = 'HSIC5')
+
+            with gr.Row():
+                saliency1 = gr.Image(label = 'Saliency1')
+                saliency2 = gr.Image(label = 'Saliency2')
+                saliency3 = gr.Image(label = 'Saliency3')
+                saliency4 = gr.Image(label = 'Saliency4')
+                saliency5 = gr.Image(label = 'Saliency5')
+
+
         generate_explanations = gr.Button("Generate Explanations")
 
         # with gr.Accordion('Closest Images'):
@@ -217,7 +244,7 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
 
     #segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
     classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted)
-    generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[rise1,rise2,rise3,rise4,rise5,avg]) #
+    generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[sobol1,sobol2,sobol3,sobol4,sobol5,rise1,rise2,rise3,rise4,rise5,hsic1,hsic2,hsic3,hsic4,hsic5,saliency1,saliency2,saliency3,saliency4,saliency5]) #
     #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
     def update_outputs(input_image,model_name):
         labels, images = find_closest(input_image,model_name)
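For readers following the UI change: a minimal sketch (not the app's actual code) of how the same 4×5 grid of outputs could be declared in a loop instead of twenty separate variables. It assumes only Gradio's public Blocks/Column/Row/Image/Button API plus the explain_image function defined in app.py; METHODS and the placeholder model_name textbox are illustrative.

import gradio as gr

METHODS = ["Sobol", "Rise", "HSIC", "Saliency"]  # same order as explain()'s return values

with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    input_image = gr.Image(label="Input image")
    model_name = gr.Textbox(label="Model")  # placeholder input; the real app uses its own selector
    with gr.Accordion("Explanations"):
        gr.Markdown("Computing Explanations from the model")
        explanation_outputs = []
        with gr.Column():
            for method in METHODS:
                with gr.Row():
                    # five images per method: one heatmap for each of the top-5 predicted classes
                    explanation_outputs += [gr.Image(label=f"{method}{k}") for k in range(1, 6)]
        generate_explanations = gr.Button("Generate Explanations")
    # explain_image must return exactly len(explanation_outputs) == 20 images, in matching order
    generate_explanations.click(explain_image,
                                inputs=[input_image, model_name],
                                outputs=explanation_outputs)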
closest_sample.py CHANGED
@@ -35,7 +35,6 @@ def pca_distance(pca,sample,embedding):
     s = pca.transform(sample.reshape(1,-1))
     all = pca.transform(embedding[:,-1])
     distances = np.linalg.norm(all - s, axis=1)
-    #print(distances)
     return np.argsort(distances)[:5]
 
 def return_paths(argsorted,files):
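The unchanged context around the deleted print shows pca_distance projecting one query sample and all stored embeddings into PCA space and returning the indices of the five nearest neighbours. A minimal usage sketch follows; the array shapes and PCA setup here are hypothetical, not taken from the repo.

import numpy as np
from sklearn.decomposition import PCA
from closest_sample import pca_distance

# Hypothetical shapes: 100 stored samples, each keeping a 2048-d feature vector
# in the last slot along axis 1, matching the embedding[:,-1] indexing above.
embedding = np.random.rand(100, 3, 2048)
pca = PCA(n_components=50).fit(embedding[:, -1])

sample = np.random.rand(2048)                        # query feature vector
closest_idx = pca_distance(pca, sample, embedding)   # indices of the 5 nearest samples
print(closest_idx)                                   # e.g. array of 5 indices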
explanations.py CHANGED
@@ -2,8 +2,8 @@ import xplique
 import tensorflow as tf
 from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                   SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
-                                  GradCAMPP, Lime, KernelShap)
-
+                                  GradCAMPP, Lime, KernelShap, SobolAttributionMethod, HsicAttributionMethod)
+from xplique.attributions.global_sensitivity_analysis import LatinHypercube
 import numpy as np
 import matplotlib.pyplot as plt
 from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
@@ -50,16 +50,19 @@ def explain(model, input_image,size=600, n_classes=171) :
     class_model = tf.keras.Model(model.input, model.output[1])
 
     explainers = [
-        #Saliency(class_model),
+        #Sobol, RISE, HSIC, Saliency
         #IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
         #SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
         #GradCAM(class_model),
-        Rise(class_model, nb_samples=50, batch_size=BATCH_SIZE, grid_size=15,
-             preservation_probability=0.5)
+        SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
+        Rise(class_model, nb_samples=5000, batch_size=BATCH_SIZE, grid_size=15,
+             preservation_probability=0.5),
+        HsicAttributionMethod(class_model,
+                              grid_size=7, nb_design=1500,
+                              sampler=LatinHypercube(binary=True)),
+        Saliency(class_model),
         #
         ]
-    explainer = Rise(class_model, nb_samples=50, batch_size=BATCH_SIZE, grid_size=15,
-                     preservation_probability=0.5)
 
     cropped,repetitions = _clever_crop(input_image,(size,size))
     size_repetitions = int(size//(repetitions.numpy()+1))
@@ -70,30 +73,32 @@ def explain(model, input_image,size=600, n_classes=171) :
     #print(top_5_indices)
     X = np.expand_dims(X, 0)
     explanations = []
-    for i,Y in enumerate(top_5_indices):
-        Y = tf.one_hot([Y], n_classes)
-        print(f'{i}/{len(top_5_indices)}')
-        phi = np.abs(explainer(X, Y))[0]
-        if len(phi.shape) == 3:
-            phi = np.mean(phi, -1)
-        show(X[0][:,size_repetitions:2*size_repetitions,:])
-        show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
-        plt.savefig(f'phi_{i}.png')
-        explanations.append(f'phi_{i}.png')
-    avg=[]
-    for i,Y in enumerate(top_5_indices):
-        Y = tf.one_hot([Y], n_classes)
-        print(f'{i}/{len(top_5_indices)}')
-        phi = np.abs(explainer(X, Y))[0]
-        if len(phi.shape) == 3:
-            phi = np.mean(phi, -1)
-        show(X[0][:,size_repetitions:2*size_repetitions,:])
-        show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
-        plt.savefig(f'phi_6.png')
-        avg.append(f'phi_6.png')
+    for e,explainer in enumerate(explainers):
+        print(f'{e}/{len(explainers)}')
+        for i,Y in enumerate(top_5_indices):
+            Y = tf.one_hot([Y], n_classes)
+            print(f'{i}/{len(top_5_indices)}')
+            phi = np.abs(explainer(X, Y))[0]
+            if len(phi.shape) == 3:
+                phi = np.mean(phi, -1)
+            show(X[0][:,size_repetitions:2*size_repetitions,:])
+            show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
+            plt.savefig(f'phi_{e}{i}.png')
+            explanations.append(f'phi_{e}{i}.png')
+    # avg=[]
+    # for i,Y in enumerate(top_5_indices):
+    #     Y = tf.one_hot([Y], n_classes)
+    #     print(f'{i}/{len(top_5_indices)}')
+    #     phi = np.abs(explainer(X, Y))[0]
+    #     if len(phi.shape) == 3:
+    #         phi = np.mean(phi, -1)
+    #     show(X[0][:,size_repetitions:2*size_repetitions,:])
+    #     show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
+    #     plt.savefig(f'phi_6.png')
+    #     avg.append(f'phi_6.png')
 
     print('Done')
     if len(explanations)==1:
         explanations = explanations[0]
-
-    return explanations,avg
+    # return explanations,avg
+    return explanations
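The new nested loop produces len(explainers) × 5 = 20 heatmap files, saved as phi_{e}{i}.png and returned as one flat list, which explain_image in app.py then unpacks positionally. A small sketch of that index layout; unpack_explanations is a hypothetical helper for illustration, not part of the repo.

# Order follows the explainers list in explain(): Sobol, Rise, HSIC, Saliency,
# each contributing one heatmap per top-5 predicted class.
METHODS = ["sobol", "rise", "hsic", "saliency"]
TOP_K = 5

def unpack_explanations(exp_list):
    """Group the flat list of saved heatmap paths by attribution method."""
    assert len(exp_list) == len(METHODS) * TOP_K
    return {m: exp_list[j * TOP_K:(j + 1) * TOP_K] for j, m in enumerate(METHODS)}

# Example: the first HSIC heatmap (explainer index 2, class index 0) is 'phi_20.png', so
# unpack_explanations(explain(model, input_image))["hsic"][0]  ->  'phi_20.png'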