Spaces:
Sleeping
Sleeping
Commit
•
2a82b5d
1
Parent(s):
2f9b167
Update app.py
Browse files
app.py
CHANGED
@@ -13,7 +13,7 @@ from data_transforms import normal_transforms, no_shift_transforms, ig_transform
|
|
13 |
from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model, fig2img
|
14 |
from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion
|
15 |
from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad
|
16 |
-
from methods import
|
17 |
|
18 |
matplotlib.use('Agg')
|
19 |
|
@@ -148,63 +148,24 @@ def get_cams():
|
|
148 |
|
149 |
gradcam1, gradcam2 = get_gradcam(ssl_model, img1, img2)
|
150 |
intcam1_mean, intcam2_mean = get_interactioncam(ssl_model, img1, img2, reduction = 'mean')
|
151 |
-
intcam1_maxmax, intcam2_maxmax = get_interactioncam(ssl_model, img1, img2, reduction = 'max', grad_interact = True)
|
152 |
-
intcam1_attnmax, intcam2_attnmax = get_interactioncam(ssl_model, img1, img2, reduction = 'attn', grad_interact = True)
|
153 |
|
154 |
-
fig, axs = plt.subplots(2,
|
155 |
np.vectorize(lambda ax:ax.axis('off'))(axs)
|
156 |
|
157 |
axs[0,0].imshow(show_image(img1[0], squeeze = False, denormalize = denorm))
|
158 |
axs[0,1].imshow(overlay_heatmap(img1, gradcam1, denormalize = denorm))
|
159 |
axs[0,1].set_title("Grad-CAM")
|
160 |
axs[0,2].imshow(overlay_heatmap(img1, intcam1_mean, denormalize = denorm))
|
161 |
-
axs[0,2].set_title("IntCAM
|
162 |
-
axs[0,3].imshow(overlay_heatmap(img1, intcam1_maxmax, denormalize = denorm))
|
163 |
-
axs[0,3].set_title("IntCAM Max + IntGradMax")
|
164 |
-
axs[0,4].imshow(overlay_heatmap(img1, intcam1_attnmax, denormalize = denorm))
|
165 |
-
axs[0,4].set_title("IntCAM Attn + IntGradMax")
|
166 |
|
167 |
axs[1,0].imshow(show_image(img2[0], squeeze = False, denormalize = denorm))
|
168 |
axs[1,1].imshow(overlay_heatmap(img2, gradcam2, denormalize = denorm))
|
169 |
axs[1,2].imshow(overlay_heatmap(img2, intcam2_mean, denormalize = denorm))
|
170 |
-
axs[1,3].imshow(overlay_heatmap(img2, intcam2_maxmax, denormalize = denorm))
|
171 |
-
axs[1,4].imshow(overlay_heatmap(img2, intcam2_attnmax, denormalize = denorm))
|
172 |
|
173 |
plt.subplots_adjust(wspace=0.01, hspace = 0.01)
|
174 |
pil_output = fig2img(fig)
|
175 |
return pil_output
|
176 |
|
177 |
-
def get_pixel_invariance():
|
178 |
-
|
179 |
-
data_samples1, data_samples2, data_labels, labels_invariance = get_sample_dataset(img_path = img_main,
|
180 |
-
num_augments = 1000,
|
181 |
-
batch_size = 32,
|
182 |
-
no_shift_transforms = no_shift_transforms,
|
183 |
-
ssl_model = ssl_model,
|
184 |
-
n_components = 10)
|
185 |
-
|
186 |
-
inv_heatmap = pixel_invariance(data_samples1 = data_samples1, data_samples2 = data_samples2, data_labels = data_labels,
|
187 |
-
labels_invariance = labels_invariance, resize_transform = transforms.Resize, size = 64,
|
188 |
-
epochs = 1000, learning_rate = 0.1, l1_weight = 0.2, zero_small_values = True,
|
189 |
-
blur_output = True, nmf_weight = 0)
|
190 |
-
|
191 |
-
inv_heatmap_nmf = pixel_invariance(data_samples1 = data_samples1, data_samples2 = data_samples2, data_labels = data_labels,
|
192 |
-
labels_invariance = labels_invariance, resize_transform = transforms.Resize, size = 64,
|
193 |
-
epochs = 100, learning_rate = 0.1, l1_weight = 0.2, zero_small_values = True,
|
194 |
-
blur_output = True, nmf_weight = 1)
|
195 |
-
|
196 |
-
fig, axs = plt.subplots(1, 2, figsize=(10,5))
|
197 |
-
np.vectorize(lambda ax:ax.axis('off'))(axs)
|
198 |
-
|
199 |
-
axs[0].imshow(viz_map(img_main, inv_heatmap))
|
200 |
-
axs[0].set_title("Heatmap w/o NMF")
|
201 |
-
axs[1].imshow(viz_map(img_main, inv_heatmap_nmf))
|
202 |
-
axs[1].set_title("Heatmap w/ NMF")
|
203 |
-
plt.subplots_adjust(wspace=0.01, hspace = 0.01)
|
204 |
-
|
205 |
-
pil_output = fig2img(fig)
|
206 |
-
return pil_output
|
207 |
-
|
208 |
xai = gr.Blocks()
|
209 |
|
210 |
with xai:
|
@@ -240,9 +201,6 @@ with xai:
|
|
240 |
blur_output = gr.Checkbox(value = True, label = "Blur Output")
|
241 |
guided = gr.Checkbox(value = True, label = "Guided Backprop")
|
242 |
avgtransform_button = gr.Button("Get Saliency")
|
243 |
-
with gr.TabItem("Pixel Invariance"):
|
244 |
-
gr.Markdown("Note: Invariance map will be obtained for the first image")
|
245 |
-
pixel_invariance_button = gr.Button("Get Invariance Map")
|
246 |
|
247 |
with gr.Column():
|
248 |
output_image = gr.Image(type='pil', show_label = False)
|
@@ -252,7 +210,6 @@ with xai:
|
|
252 |
occlusion_button.click(run_occlusion, inputs=[w_size,stride], outputs=output_image)
|
253 |
avgtransform_button.click(get_avg_trasforms, inputs = [transform_type, add_noise, blur_output, guided], outputs = output_image)
|
254 |
cams_button.click(get_cams, inputs = [], outputs = output_image)
|
255 |
-
pixel_invariance_button.click(get_pixel_invariance, inputs = [], outputs = output_image)
|
256 |
|
257 |
xai.launch()
|
258 |
|
|
|
13 |
from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model, fig2img
|
14 |
from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion
|
15 |
from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad
|
16 |
+
from methods import get_gradcam, get_interactioncam
|
17 |
|
18 |
matplotlib.use('Agg')
|
19 |
|
|
|
148 |
|
149 |
gradcam1, gradcam2 = get_gradcam(ssl_model, img1, img2)
|
150 |
intcam1_mean, intcam2_mean = get_interactioncam(ssl_model, img1, img2, reduction = 'mean')
|
|
|
|
|
151 |
|
152 |
+
fig, axs = plt.subplots(2, 3, figsize=(20,8))
|
153 |
np.vectorize(lambda ax:ax.axis('off'))(axs)
|
154 |
|
155 |
axs[0,0].imshow(show_image(img1[0], squeeze = False, denormalize = denorm))
|
156 |
axs[0,1].imshow(overlay_heatmap(img1, gradcam1, denormalize = denorm))
|
157 |
axs[0,1].set_title("Grad-CAM")
|
158 |
axs[0,2].imshow(overlay_heatmap(img1, intcam1_mean, denormalize = denorm))
|
159 |
+
axs[0,2].set_title("IntCAM")
|
|
|
|
|
|
|
|
|
160 |
|
161 |
axs[1,0].imshow(show_image(img2[0], squeeze = False, denormalize = denorm))
|
162 |
axs[1,1].imshow(overlay_heatmap(img2, gradcam2, denormalize = denorm))
|
163 |
axs[1,2].imshow(overlay_heatmap(img2, intcam2_mean, denormalize = denorm))
|
|
|
|
|
164 |
|
165 |
plt.subplots_adjust(wspace=0.01, hspace = 0.01)
|
166 |
pil_output = fig2img(fig)
|
167 |
return pil_output
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
xai = gr.Blocks()
|
170 |
|
171 |
with xai:
|
|
|
201 |
blur_output = gr.Checkbox(value = True, label = "Blur Output")
|
202 |
guided = gr.Checkbox(value = True, label = "Guided Backprop")
|
203 |
avgtransform_button = gr.Button("Get Saliency")
|
|
|
|
|
|
|
204 |
|
205 |
with gr.Column():
|
206 |
output_image = gr.Image(type='pil', show_label = False)
|
|
|
210 |
occlusion_button.click(run_occlusion, inputs=[w_size,stride], outputs=output_image)
|
211 |
avgtransform_button.click(get_avg_trasforms, inputs = [transform_type, add_noise, blur_output, guided], outputs = output_image)
|
212 |
cams_button.click(get_cams, inputs = [], outputs = output_image)
|
|
|
213 |
|
214 |
xai.launch()
|
215 |
|