Anonymous committed on
Commit
835894d
1 Parent(s): 3d52de0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +283 -0
app.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ import matplotlib
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+ import cv2
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ from data_transforms import normal_transforms, no_shift_transforms, ig_transforms, modify_transforms
12
+ from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model, fig2img
13
+ from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion
14
+ from methods import get_difference
15
+ from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad
16
+ from methods import get_sample_dataset, pixel_invariance, get_gradcam, get_interactioncam
17
+
18
+ matplotlib.use('Agg')
19
+
20
# Maps each dropdown label to (network, variant, denorm):
#   network - backbone family name passed to get_ssl_model
#   variant - width variant passed to get_ssl_model (None when not applicable)
#   denorm  - value forwarded as `denormalize=` to the display helpers
_MODEL_CONFIGS = {
    "simclrv2 (1X)": ("simclrv2", "1x", False),
    "simclrv2 (2X)": ("simclrv2", "2x", False),
    "Barlow Twins": ("barlow_twins", None, True),
}


def load_model(model_name):
    """Load the selected self-supervised model into module-level state.

    Sets the globals ``network``, ``ssl_model`` and ``denorm`` that the other
    callbacks read. For every network other than ``'simclrv2'`` the global
    transform dictionaries are additionally passed through
    ``modify_transforms`` and rebound.

    Args:
        model_name: Label chosen in the model dropdown; one of the keys of
            ``_MODEL_CONFIGS``.

    Returns:
        A status string for the UI: ``"Loaded Model Successfully"`` on
        success, or an explanatory message when ``model_name`` is not a known
        model (for example ``None`` when nothing has been selected). In the
        latter case no global state is modified.
    """
    global network, ssl_model, denorm
    global normal_transforms, no_shift_transforms, ig_transforms

    config = _MODEL_CONFIGS.get(model_name)
    if config is None:
        # Previously this path crashed with an unbound-variable error.
        return "Unknown model: {!r}. Please choose a model from the list.".format(model_name)

    network, variant, denorm = config
    ssl_model = get_ssl_model(network, variant)

    if network != 'simclrv2':
        # NOTE(review): this is re-applied on every non-simclrv2 load and is
        # never undone when switching back to simclrv2; confirm that
        # modify_transforms is idempotent or keep pristine copies to start from.
        normal_transforms, no_shift_transforms, ig_transforms = modify_transforms(
            normal_transforms, no_shift_transforms, ig_transforms)

    return "Loaded Model Successfully"
45
+
46
def load_or_augment_images(img1_input, img2_input, use_aug):
    """Prepare the image pair, store it globally and render a preview.

    Sets the globals ``img_main`` (the first image converted to RGB) and
    ``img1`` / ``img2`` (outputs of ``normal_transforms`` with a leading batch
    dimension added, moved to ``device``) that the other callbacks read.

    Args:
        img1_input: First image (PIL image from the UI), or ``None``.
        img2_input: Second image (PIL image from the UI), or ``None``. Only
            used when ``use_aug`` is false.
        use_aug: When true, the second image is produced by applying
            ``normal_transforms['aug']`` to the first image instead of
            loading ``img2_input``.

    Returns:
        A ``(preview, status)`` tuple: the side-by-side preview as returned
        by ``fig2img`` plus a ``"Similarity: x.xxx"`` string holding the
        cosine similarity of the two model outputs. If a required image is
        missing, ``(None, message)`` is returned and no global is modified.
    """
    global img_main, img1, img2

    # Validate before touching any global so a bad click leaves state intact.
    if img1_input is None:
        return None, "Please provide the first image."
    if not use_aug and img2_input is None:
        return None, "Please provide a second image or check \"Use Augmentations\"."

    img_main = img1_input.convert('RGB')
    img1 = normal_transforms['pure'](img_main).unsqueeze(0).to(device)
    if use_aug:
        img2 = normal_transforms['aug'](img_main).unsqueeze(0).to(device)
    else:
        img2 = normal_transforms['pure'](img2_input.convert('RGB')).unsqueeze(0).to(device)

    score = nn.CosineSimilarity(dim=-1)(ssl_model(img1), ssl_model(img2)).item()
    similarity = "Similarity: {:.3f}".format(score)

    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
    try:
        for ax in axs.flat:
            ax.axis('off')

        axs[0].imshow(show_image(img1, denormalize=denorm))
        axs[1].imshow(show_image(img2, denormalize=denorm))
        # Adjust this figure explicitly rather than pyplot's "current" figure.
        fig.subplots_adjust(wspace=0.1, hspace=0)
        return fig2img(fig), similarity
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running server accumulates one figure per click.
        plt.close(fig)
70
+
71
def run_occlusion(w_size, stride):
    """Render the three occlusion-based heatmaps for the loaded image pair.

    Reads the globals ``img1``, ``img2``, ``ssl_model`` and ``denorm``.

    Args:
        w_size: Occlusion window size; forwarded to ``occlusion`` and
            ``occlusion_context_agnositc``.
        stride: Occlusion stride; forwarded to the same two helpers.

    Returns:
        The 2x4 figure (one row per image: original, conditional occlusion,
        context-agnostic conditional occlusion, pairwise occlusion) as
        returned by ``fig2img``.

    Raises:
        ValueError: If ``w_size`` or ``stride`` is not a positive integer.
    """
    # The UI values used to be ignored in favour of hard-coded 64 / 8.
    w_size, stride = int(w_size), int(stride)
    if w_size < 1 or stride < 1:
        raise ValueError("Occlusion window size and stride must be positive integers.")

    heatmap1, heatmap2 = occlusion(img1, img2, ssl_model,
                                   w_size=w_size, stride=stride, batch_size=32)
    heatmap1_ca, heatmap2_ca = occlusion_context_agnositc(img1, img2, ssl_model,
                                                          w_size=w_size, stride=stride,
                                                          batch_size=32)
    heatmap1_po, heatmap2_po = pairwise_occlusion(img1, img2, ssl_model, batch_size=32,
                                                  erase_scale=(0.1, 0.3),
                                                  erase_ratio=(1, 1.5),
                                                  num_erases=100)

    rows = (
        (img1, heatmap1, heatmap1_ca, heatmap1_po),
        (img2, heatmap2, heatmap2_ca, heatmap2_po),
    )

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))
    try:
        for ax in axs.flat:
            ax.axis('off')

        for row, (image, heatmap, heatmap_ca, heatmap_po) in enumerate(rows):
            axs[row, 0].imshow(show_image(image, denormalize=denorm))
            axs[row, 1].imshow(overlay_heatmap(image, heatmap, denormalize=denorm))
            axs[row, 2].imshow(overlay_heatmap(image, heatmap_ca, denormalize=denorm))
            # The pairwise map is applied as a per-pixel multiplier on the
            # deprocessed image instead of being overlaid.
            masked = deprocess(image, denormalize=denorm) * heatmap_po[:, :, None]
            axs[row, 3].imshow(masked.astype('uint8'))

        # Column titles go on the first row only.
        axs[0, 1].set_title("Conditional Occlusion")
        axs[0, 2].set_title("CA Cond. Occlusion")
        axs[0, 3].set_title("Pairwise Occlusion")

        fig.subplots_adjust(wspace=0, hspace=0.01)
        return fig2img(fig)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
98
+
99
def get_model_difference(later):
    """Render images synthesized by ``get_difference`` next to the original.

    Reads the globals ``ssl_model``, ``img2``, ``network`` and ``denorm``.

    Args:
        later: Value of the "Compare With" radio button. Currently unused:
            the baseline passed to ``get_difference`` is always
            ``'imagenet'``.

    Returns:
        A 3x3 figure (columns: original image, image synthesized for the
        baseline, image synthesized for the contrastive model) as returned
        by ``fig2img``.
    """
    imagenet_images, ssl_images = get_difference(ssl_model=ssl_model, baseline='imagenet',
                                                 image=img2, lr=1e4,
                                                 l2_weight=0.1, alpha_weight=1e-7,
                                                 alpha_power=6, tv_weight=1e-8,
                                                 init_scale=0.1, network=network)

    fig, axs = plt.subplots(3, 3, figsize=(10, 10))
    try:
        for ax in axs.flat:
            ax.axis('off')

        # NOTE(review): the grid has exactly three rows; this presumes
        # get_difference yields at most three image pairs - confirm.
        for row, (in_img, ssl_img) in enumerate(zip(imagenet_images, ssl_images)):
            axs[row, 0].imshow(deprocess(img2, denormalize=denorm))
            axs[row, 1].imshow(deprocess(in_img))
            axs[row, 2].imshow(deprocess(ssl_img))

        axs[0, 0].set_title("Original Image")
        axs[0, 1].set_title("Synthesized (cls)")
        axs[0, 2].set_title("Synthesized (contastive)")

        # Adjust this figure explicitly rather than pyplot's "current" figure.
        fig.subplots_adjust(wspace=0.01, hspace=0.01)
        return fig2img(fig)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
120
+
121
def get_avg_trasforms(transform_type, add_noise, blur_output, guided):
    """Render vanilla, smooth and averaged-transform saliency maps.

    Reads the globals ``ig_transforms``, ``img_main``, ``ssl_model`` and
    ``denorm``. The first and last entries of the sequence returned by
    ``create_mixed_images`` are used as the image pair.

    Args:
        transform_type: Augmentation name forwarded to ``create_mixed_images``.
        add_noise: Forwarded to ``create_mixed_images``.
        blur_output: Forwarded to all three saliency helpers.
        guided: Forwarded to all three saliency helpers.

    Returns:
        A 2x4 figure (one row per image: the image itself, then each saliency
        map drawn underneath a half-transparent copy of the image) as
        returned by ``fig2img``.
    """
    mixed_images = create_mixed_images(transform_type=transform_type,
                                       ig_transforms=ig_transforms,
                                       step=0.1,
                                       img_path=img_main,
                                       add_noise=add_noise)
    first, last = mixed_images[0], mixed_images[-1]

    # vanilla gradients (for comparison purposes)
    vanilla = sailency(guided=guided, ssl_model=ssl_model,
                       img1=first, img2=last,
                       blur_output=blur_output)

    # smooth gradients (for comparison purposes)
    smooth = smooth_grad(guided=guided, ssl_model=ssl_model,
                         img1=first, img2=last,
                         blur_output=blur_output, steps=50)

    # integrated transform
    integrated = averaged_transforms(guided=guided, ssl_model=ssl_model,
                                     mixed_images=mixed_images,
                                     blur_output=blur_output)

    # (column title, (map for first image, map for last image))
    columns = (
        ("Vanilla Gradients", vanilla),
        ("Smooth Gradients", smooth),
        ("Integrated Transform", integrated),
    )

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))
    try:
        for ax in axs.flat:
            ax.axis('off')

        for row, image in enumerate((first, last)):
            background = show_image(image, denormalize=denorm)
            axs[row, 0].imshow(background)
            for col, (title, maps) in enumerate(columns, start=1):
                ax = axs[row, col]
                ax.imshow(show_image(maps[row].detach(), squeeze=False).squeeze(),
                          cmap=plt.cm.jet)
                # Half-transparent image on top keeps the scene recognisable.
                ax.imshow(background, alpha=0.5)
                if row == 0:
                    ax.set_title(title)

        # Adjust this figure explicitly rather than pyplot's "current" figure.
        fig.subplots_adjust(wspace=0.02, hspace=0.02)
        return fig2img(fig)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
168
+
169
def get_cams():
    """Render Grad-CAM and three Interaction-CAM variants for the image pair.

    Reads the globals ``ssl_model``, ``img1``, ``img2`` and ``denorm``.

    Returns:
        A 2x5 figure (one row per image: the image itself followed by the
        four heatmap overlays) as returned by ``fig2img``.
    """
    # (column title, (heatmap for img1, heatmap for img2)); evaluated in the
    # order listed.
    cams = (
        ("Grad-CAM",
         get_gradcam(ssl_model, img1, img2)),
        ("IntCAM Mean",
         get_interactioncam(ssl_model, img1, img2, reduction='mean')),
        ("IntCAM Max + IntGradMax",
         get_interactioncam(ssl_model, img1, img2, reduction='max', grad_interact=True)),
        ("IntCAM Attn + IntGradMax",
         get_interactioncam(ssl_model, img1, img2, reduction='attn', grad_interact=True)),
    )

    fig, axs = plt.subplots(2, 5, figsize=(20, 8))
    try:
        for ax in axs.flat:
            ax.axis('off')

        for row, image in enumerate((img1, img2)):
            # image[0] drops the leading dimension, hence squeeze=False.
            axs[row, 0].imshow(show_image(image[0], squeeze=False, denormalize=denorm))
            for col, (title, heatmaps) in enumerate(cams, start=1):
                axs[row, col].imshow(overlay_heatmap(image, heatmaps[row], denormalize=denorm))
                if row == 0:
                    axs[row, col].set_title(title)

        # Adjust this figure explicitly rather than pyplot's "current" figure.
        fig.subplots_adjust(wspace=0.01, hspace=0.01)
        return fig2img(fig)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
198
+
199
def get_pixel_invariance():
    """Render pixel-invariance heatmaps for the first image, with and without NMF.

    Reads the globals ``img_main``, ``no_shift_transforms`` and ``ssl_model``.
    A sample dataset is built once with ``get_sample_dataset`` and then fed
    to ``pixel_invariance`` twice: with ``nmf_weight=0`` (1000 epochs) and
    with ``nmf_weight=1`` (100 epochs).

    Returns:
        A 1x2 figure with both heatmaps visualised by ``viz_map`` over
        ``img_main``, as returned by ``fig2img``.
    """
    data_samples1, data_samples2, data_labels, labels_invariance = get_sample_dataset(
        img_path=img_main,
        num_augments=1000,
        batch_size=32,
        no_shift_transforms=no_shift_transforms,
        ssl_model=ssl_model,
        n_components=10)

    # Arguments common to both pixel_invariance runs.
    shared_kwargs = dict(data_samples1=data_samples1, data_samples2=data_samples2,
                         data_labels=data_labels, labels_invariance=labels_invariance,
                         resize_transform=transforms.Resize, size=64,
                         learning_rate=0.1, l1_weight=0.2, zero_small_values=True,
                         blur_output=True)

    inv_heatmap = pixel_invariance(epochs=1000, nmf_weight=0, **shared_kwargs)
    # NOTE(review): the NMF run uses 100 epochs versus 1000 above; kept as in
    # the original - confirm the difference is intentional.
    inv_heatmap_nmf = pixel_invariance(epochs=100, nmf_weight=1, **shared_kwargs)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    try:
        for ax in axs.flat:
            ax.axis('off')

        axs[0].imshow(viz_map(img_main, inv_heatmap))
        axs[0].set_title("Heatmap w/o NMF")
        axs[1].imshow(viz_map(img_main, inv_heatmap_nmf))
        axs[1].set_title("Heatmap w/ NMF")

        # Adjust this figure explicitly rather than pyplot's "current" figure.
        fig.subplots_adjust(wspace=0.01, hspace=0.01)
        return fig2img(fig)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
229
+
230
# `gr` is used throughout this section but is not imported anywhere else in
# the file; without this import the script fails with NameError at start-up.
import gradio as gr

# ---------------------------------------------------------------------------
# Gradio UI: builds the Blocks layout, wires each button to its callback and
# launches the app.
# ---------------------------------------------------------------------------
xai = gr.Blocks()

with xai:
    gr.Markdown("<h1>Methods for Explaining Contrastive Learning, CVPR 2023 Submission</h1>")
    gr.Markdown("The interface is simplified as much as possible with only necessary options to select for each method. Please use our Google Colab demo for more flexibility.")

    with gr.Row():
        model_name = gr.Dropdown(["simclrv2 (1X)", "simclrv2 (2X)", "Barlow Twins"], label="Choose Model and press \"Load Model\"")
        load_model_button = gr.Button("Load Model")
        # Shared text box: shows the load status and, later, the similarity.
        status_or_similarity = gr.Textbox(label="Status")
    with gr.Row():
        gr.Markdown("You can either load two images or load a single image and augment it to get the second image (in that case please check the \"Use Augmentations\" button). After that, please press on \"Show Images\"")
        # Not named img1/img2: load_or_augment_images rebinds those module
        # globals to the preprocessed tensors the other callbacks read.
        first_image = gr.Image(type='pil', label="First Image")
        second_image = gr.Image(type='pil', label="Second Image")
    with gr.Row():
        use_aug = gr.Checkbox(value=False, label="Use Augmentations")
        load_images_button = gr.Button("Show Images")

    gr.Markdown("Choose a method from the different tabs. You may leave the default options as they are and press on \"Run\" ")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Interaction-CAM"):
                    cams_button = gr.Button("Get Heatmaps")
                with gr.TabItem("Perturbation Methods"):
                    w_size = gr.Number(value=64, label="Occlusion Window Size", precision=0)
                    stride = gr.Number(value=8, label="Occlusion Stride", precision=0)
                    occlusion_button = gr.Button("Get Heatmap")
                with gr.TabItem("Averaged Transforms"):
                    transform_type = gr.Radio(label="Data Augment", choices=['color_jitter', 'blur', 'grayscale', 'solarize', 'combine'], value="combine")
                    add_noise = gr.Checkbox(value=True, label="Add Noise")
                    blur_output = gr.Checkbox(value=True, label="Blur Output")
                    guided = gr.Checkbox(value=True, label="Guided Backprop")
                    avgtransform_button = gr.Button("Get Saliency")
                with gr.TabItem("Pixel Invariance"):
                    gr.Markdown("Note: Invariance map will be obtained for the first image")
                    pixel_invariance_button = gr.Button("Get Invariance Map")
                with gr.TabItem("Image Synthesization"):
                    baseline = gr.Radio(label="Compare With", choices=["Supervised Classification"], value="Supervised Classification")
                    modeldiff_button = gr.Button("Compare")

        with gr.Column():
            # Every method renders into this single output image.
            output_image = gr.Image(type='pil', show_label=False)

    # Event wiring must happen inside the Blocks context.
    load_model_button.click(load_model, inputs=model_name, outputs=status_or_similarity)
    load_images_button.click(load_or_augment_images, inputs=[first_image, second_image, use_aug], outputs=[output_image, status_or_similarity])
    occlusion_button.click(run_occlusion, inputs=[w_size, stride], outputs=output_image)
    modeldiff_button.click(get_model_difference, inputs=baseline, outputs=output_image)
    avgtransform_button.click(get_avg_trasforms, inputs=[transform_type, add_noise, blur_output, guided], outputs=output_image)
    cams_button.click(get_cams, inputs=[], outputs=output_image)
    pixel_invariance_button.click(get_pixel_invariance, inputs=[], outputs=output_image)

xai.launch()
283
+