AnnonSubmission commited on
Commit
2a82b5d
1 Parent(s): 2f9b167

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -46
app.py CHANGED
@@ -13,7 +13,7 @@ from data_transforms import normal_transforms, no_shift_transforms, ig_transform
13
  from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model, fig2img
14
  from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion
15
  from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad
16
- from methods import get_sample_dataset, pixel_invariance, get_gradcam, get_interactioncam
17
 
18
  matplotlib.use('Agg')
19
 
@@ -148,63 +148,24 @@ def get_cams():
148
 
149
  gradcam1, gradcam2 = get_gradcam(ssl_model, img1, img2)
150
  intcam1_mean, intcam2_mean = get_interactioncam(ssl_model, img1, img2, reduction = 'mean')
151
- intcam1_maxmax, intcam2_maxmax = get_interactioncam(ssl_model, img1, img2, reduction = 'max', grad_interact = True)
152
- intcam1_attnmax, intcam2_attnmax = get_interactioncam(ssl_model, img1, img2, reduction = 'attn', grad_interact = True)
153
 
154
- fig, axs = plt.subplots(2, 5, figsize=(20,8))
155
  np.vectorize(lambda ax:ax.axis('off'))(axs)
156
 
157
  axs[0,0].imshow(show_image(img1[0], squeeze = False, denormalize = denorm))
158
  axs[0,1].imshow(overlay_heatmap(img1, gradcam1, denormalize = denorm))
159
  axs[0,1].set_title("Grad-CAM")
160
  axs[0,2].imshow(overlay_heatmap(img1, intcam1_mean, denormalize = denorm))
161
- axs[0,2].set_title("IntCAM Mean")
162
- axs[0,3].imshow(overlay_heatmap(img1, intcam1_maxmax, denormalize = denorm))
163
- axs[0,3].set_title("IntCAM Max + IntGradMax")
164
- axs[0,4].imshow(overlay_heatmap(img1, intcam1_attnmax, denormalize = denorm))
165
- axs[0,4].set_title("IntCAM Attn + IntGradMax")
166
 
167
  axs[1,0].imshow(show_image(img2[0], squeeze = False, denormalize = denorm))
168
  axs[1,1].imshow(overlay_heatmap(img2, gradcam2, denormalize = denorm))
169
  axs[1,2].imshow(overlay_heatmap(img2, intcam2_mean, denormalize = denorm))
170
- axs[1,3].imshow(overlay_heatmap(img2, intcam2_maxmax, denormalize = denorm))
171
- axs[1,4].imshow(overlay_heatmap(img2, intcam2_attnmax, denormalize = denorm))
172
 
173
  plt.subplots_adjust(wspace=0.01, hspace = 0.01)
174
  pil_output = fig2img(fig)
175
  return pil_output
176
 
177
- def get_pixel_invariance():
178
-
179
- data_samples1, data_samples2, data_labels, labels_invariance = get_sample_dataset(img_path = img_main,
180
- num_augments = 1000,
181
- batch_size = 32,
182
- no_shift_transforms = no_shift_transforms,
183
- ssl_model = ssl_model,
184
- n_components = 10)
185
-
186
- inv_heatmap = pixel_invariance(data_samples1 = data_samples1, data_samples2 = data_samples2, data_labels = data_labels,
187
- labels_invariance = labels_invariance, resize_transform = transforms.Resize, size = 64,
188
- epochs = 1000, learning_rate = 0.1, l1_weight = 0.2, zero_small_values = True,
189
- blur_output = True, nmf_weight = 0)
190
-
191
- inv_heatmap_nmf = pixel_invariance(data_samples1 = data_samples1, data_samples2 = data_samples2, data_labels = data_labels,
192
- labels_invariance = labels_invariance, resize_transform = transforms.Resize, size = 64,
193
- epochs = 100, learning_rate = 0.1, l1_weight = 0.2, zero_small_values = True,
194
- blur_output = True, nmf_weight = 1)
195
-
196
- fig, axs = plt.subplots(1, 2, figsize=(10,5))
197
- np.vectorize(lambda ax:ax.axis('off'))(axs)
198
-
199
- axs[0].imshow(viz_map(img_main, inv_heatmap))
200
- axs[0].set_title("Heatmap w/o NMF")
201
- axs[1].imshow(viz_map(img_main, inv_heatmap_nmf))
202
- axs[1].set_title("Heatmap w/ NMF")
203
- plt.subplots_adjust(wspace=0.01, hspace = 0.01)
204
-
205
- pil_output = fig2img(fig)
206
- return pil_output
207
-
208
  xai = gr.Blocks()
209
 
210
  with xai:
@@ -240,9 +201,6 @@ with xai:
240
  blur_output = gr.Checkbox(value = True, label = "Blur Output")
241
  guided = gr.Checkbox(value = True, label = "Guided Backprop")
242
  avgtransform_button = gr.Button("Get Saliency")
243
- with gr.TabItem("Pixel Invariance"):
244
- gr.Markdown("Note: Invariance map will be obtained for the first image")
245
- pixel_invariance_button = gr.Button("Get Invariance Map")
246
 
247
  with gr.Column():
248
  output_image = gr.Image(type='pil', show_label = False)
@@ -252,7 +210,6 @@ with xai:
252
  occlusion_button.click(run_occlusion, inputs=[w_size,stride], outputs=output_image)
253
  avgtransform_button.click(get_avg_trasforms, inputs = [transform_type, add_noise, blur_output, guided], outputs = output_image)
254
  cams_button.click(get_cams, inputs = [], outputs = output_image)
255
- pixel_invariance_button.click(get_pixel_invariance, inputs = [], outputs = output_image)
256
 
257
  xai.launch()
258
 
 
13
  from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model, fig2img
14
  from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion
15
  from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad
16
+ from methods import get_gradcam, get_interactioncam
17
 
18
  matplotlib.use('Agg')
19
 
 
148
 
149
  gradcam1, gradcam2 = get_gradcam(ssl_model, img1, img2)
150
  intcam1_mean, intcam2_mean = get_interactioncam(ssl_model, img1, img2, reduction = 'mean')
 
 
151
 
152
+ fig, axs = plt.subplots(2, 3, figsize=(20,8))
153
  np.vectorize(lambda ax:ax.axis('off'))(axs)
154
 
155
  axs[0,0].imshow(show_image(img1[0], squeeze = False, denormalize = denorm))
156
  axs[0,1].imshow(overlay_heatmap(img1, gradcam1, denormalize = denorm))
157
  axs[0,1].set_title("Grad-CAM")
158
  axs[0,2].imshow(overlay_heatmap(img1, intcam1_mean, denormalize = denorm))
159
+ axs[0,2].set_title("IntCAM")
 
 
 
 
160
 
161
  axs[1,0].imshow(show_image(img2[0], squeeze = False, denormalize = denorm))
162
  axs[1,1].imshow(overlay_heatmap(img2, gradcam2, denormalize = denorm))
163
  axs[1,2].imshow(overlay_heatmap(img2, intcam2_mean, denormalize = denorm))
 
 
164
 
165
  plt.subplots_adjust(wspace=0.01, hspace = 0.01)
166
  pil_output = fig2img(fig)
167
  return pil_output
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  xai = gr.Blocks()
170
 
171
  with xai:
 
201
  blur_output = gr.Checkbox(value = True, label = "Blur Output")
202
  guided = gr.Checkbox(value = True, label = "Guided Backprop")
203
  avgtransform_button = gr.Button("Get Saliency")
 
 
 
204
 
205
  with gr.Column():
206
  output_image = gr.Image(type='pil', show_label = False)
 
210
  occlusion_button.click(run_occlusion, inputs=[w_size,stride], outputs=output_image)
211
  avgtransform_button.click(get_avg_trasforms, inputs = [transform_type, add_noise, blur_output, guided], outputs = output_image)
212
  cams_button.click(get_cams, inputs = [], outputs = output_image)
 
213
 
214
  xai.launch()
215