Anonymous committed on
Commit cddd431
1 Parent(s): feafd17

Upload 3 files

Files changed (3)
  1. data_transforms.py +96 -0
  2. methods.py +578 -0
  3. utils.py +101 -0
data_transforms.py ADDED
@@ -0,0 +1,96 @@
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image, ImageOps, ImageFilter
import random

def add_normalization_to_transform(unnormalized_transforms):
    """Adds ImageNet normalization to all transforms"""
    normalized_transform = {}
    for key, value in unnormalized_transforms.items():
        normalized_transform[key] = transforms.Compose([value,
                                                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                             std=[0.229, 0.224, 0.225])])
    return normalized_transform

def modify_transforms(normal_transforms, no_shift_transforms, ig_transforms):
    normal_transforms = add_normalization_to_transform(normal_transforms)
    no_shift_transforms = add_normalization_to_transform(no_shift_transforms)
    ig_transforms = add_normalization_to_transform(ig_transforms)
    return normal_transforms, no_shift_transforms, ig_transforms

class Solarization(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img

# no ImageNet normalization for SimCLRv2
pure_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor()])

aug_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                    transforms.RandomHorizontalFlip(p=0.5),
                                    transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
                                    transforms.RandomGrayscale(p=0.2),
                                    transforms.RandomApply([transforms.GaussianBlur(kernel_size=(21,21), sigma=(0.1,2.0))], p=0.5),
                                    transforms.ToTensor()])

ig_pure_transform = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor()])

ig_transform_colorjitter = transforms.Compose([transforms.Resize(256),
                                               transforms.CenterCrop(224),
                                               transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.4)], p=1),
                                               transforms.ToTensor()])

ig_transform_blur = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.RandomApply([transforms.GaussianBlur(kernel_size=(11,11), sigma=(5,5))], p=1),
                                        transforms.ToTensor()])

ig_transform_solarize = transforms.Compose([transforms.Resize(256),
                                            transforms.CenterCrop(224),
                                            Solarization(p=1.0),
                                            transforms.ToTensor()])

ig_transform_grayscale = transforms.Compose([transforms.Resize(256),
                                             transforms.CenterCrop(224),
                                             transforms.RandomGrayscale(p=1),
                                             transforms.ToTensor()])

ig_transform_combine = transforms.Compose([transforms.Resize(256),
                                           transforms.CenterCrop(224),
                                           transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
                                           transforms.RandomGrayscale(p=0.2),
                                           transforms.RandomApply([transforms.GaussianBlur(kernel_size=(21,21), sigma=(0.1, 2.0))], p=0.5),
                                           transforms.ToTensor()])

pure_transform_no_shift = transforms.Compose([transforms.Resize((224, 224)),
                                              transforms.ToTensor()])

aug_transform_no_shift = transforms.Compose([transforms.Resize((224, 224)),
                                             transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
                                             transforms.RandomGrayscale(p=0.2),
                                             transforms.ToTensor()])

normal_transforms = {'pure': pure_transform,
                     'aug': aug_transform}

no_shift_transforms = {'pure': pure_transform_no_shift,
                       'aug': aug_transform_no_shift}

ig_transforms = {'pure': ig_pure_transform,
                 'color_jitter': ig_transform_colorjitter,
                 'blur': ig_transform_blur,
                 'grayscale': ig_transform_grayscale,
                 'solarize': ig_transform_solarize,
                 'combine': ig_transform_combine}
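
A minimal usage sketch (not part of the uploaded files): the three dictionaries above are meant to be passed through modify_transforms when the backbone expects ImageNet-normalized inputs; SimCLRv2 uses the unnormalized versions directly, per the comment in the file.

# Hypothetical usage sketch, assuming data_transforms.py is on the import path.
from data_transforms import (normal_transforms, no_shift_transforms,
                             ig_transforms, modify_transforms)

# add ImageNet normalization to every transform in the three dicts
normal_t, no_shift_t, ig_t = modify_transforms(normal_transforms,
                                               no_shift_transforms,
                                               ig_transforms)
print(ig_t.keys())  # 'pure', 'color_jitter', 'blur', 'grayscale', 'solarize', 'combine'
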
methods.py ADDED
@@ -0,0 +1,578 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from PIL import Image
from sklearn.decomposition import NMF

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def relu_hook_function(module, grad_in, grad_out):
    if isinstance(module, nn.ReLU):
        return (F.relu(grad_in[0]),)

def blur_sailency(input_image):
    return torchvision.transforms.functional.gaussian_blur(input_image, kernel_size=[11, 11], sigma=[5,5])

def occlusion(img1, img2, model, w_size = 64, stride = 8, batch_size = 32):

    measure = nn.CosineSimilarity(dim=-1)
    output_size = int(((img2.size(-1) - w_size) / stride) + 1)
    out1_condition, out2_condition = model(img1), model(img2)
    images1 = []
    images2 = []

    for i in range(output_size):
        for j in range(output_size):
            start_i, start_j = i * stride, j * stride
            image1 = img1.clone().detach()
            image2 = img2.clone().detach()
            image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            images1.append(image1)
            images2.append(image2)

    images1 = torch.cat(images1, dim=0).to(device)
    images2 = torch.cat(images2, dim=0).to(device)

    score_map1 = []
    score_map2 = []

    assert images1.shape[0] == images2.shape[0]

    for b in range(0, images2.shape[0], batch_size):

        with torch.no_grad():
            out1 = model(images1[b : b + batch_size, :])
            out2 = model(images2[b : b + batch_size, :])

            score_map1.append(measure(out1, out2_condition)) # try torch.mm(out2_condition, out1.t())[0]
            score_map2.append(measure(out1_condition, out2)) # try torch.mm(out1_condition, out2.t())[0]

    score_map1 = torch.cat(score_map1, dim = 0)
    score_map2 = torch.cat(score_map2, dim = 0)
    assert images2.shape[0] == score_map2.shape[0] == score_map1.shape[0]

    heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy()
    heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy()
    base_score = measure(out1_condition, out2_condition)

    heatmap1 = (heatmap1 - base_score.item()) * -1 # or base_score.item() - heatmap1. The higher the drop, the better
    heatmap2 = (heatmap2 - base_score.item()) * -1 # or base_score.item() - heatmap2. The higher the drop, the better

    return heatmap1, heatmap2

def occlusion_context_agnositc(img1, img2, model, w_size = 64, stride = 8, batch_size = 32):

    measure = nn.CosineSimilarity(dim=-1)
    output_size = int(((img2.size(-1) - w_size) / stride) + 1)
    out1_condition, out2_condition = model(img1), model(img2)

    images1_occlude_mask = []
    images2_occlude_mask = []

    for i in range(output_size):
        for j in range(output_size):
            start_i, start_j = i * stride, j * stride
            image1 = img1.clone().detach()
            image2 = img2.clone().detach()
            image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            images1_occlude_mask.append(image1)
            images2_occlude_mask.append(image2)

    images1_occlude_mask = torch.cat(images1_occlude_mask, dim=0).to(device)
    images2_occlude_mask = torch.cat(images2_occlude_mask, dim=0).to(device)

    images1_occlude_backround = []
    images2_occlude_backround = []

    copy_img1 = img1.clone().detach()
    copy_img2 = img2.clone().detach()

    for i in range(output_size):
        for j in range(output_size):
            start_i, start_j = i * stride, j * stride

            image1 = torch.zeros_like(img1)
            image2 = torch.zeros_like(img2)

            image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img1[:, :, start_i : start_i + w_size, start_j : start_j + w_size]
            image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img2[:, :, start_i : start_i + w_size, start_j : start_j + w_size]

            images1_occlude_backround.append(image1)
            images2_occlude_backround.append(image2)

    images1_occlude_backround = torch.cat(images1_occlude_backround, dim=0).to(device)
    images2_occlude_backround = torch.cat(images2_occlude_backround, dim=0).to(device)

    score_map1 = []
    score_map2 = []

    assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0]

    for b in range(0, images1_occlude_mask.shape[0], batch_size):

        with torch.no_grad():
            out1_mask = model(images1_occlude_mask[b : b + batch_size, :])
            out2_mask = model(images2_occlude_mask[b : b + batch_size, :])
            out1_backround = model(images1_occlude_backround[b : b + batch_size, :])
            out2_backround = model(images2_occlude_backround[b : b + batch_size, :])

            out1 = out1_backround - out1_mask
            out2 = out2_backround - out2_mask
            score_map1.append(measure(out1, out2_condition)) # or torch.mm(out2_condition, out1.t())[0]
            score_map2.append(measure(out1_condition, out2)) # or torch.mm(out1_condition, out2.t())[0]

    score_map1 = torch.cat(score_map1, dim = 0)
    score_map2 = torch.cat(score_map2, dim = 0)
    assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0] == score_map2.shape[0] == score_map1.shape[0]

    heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy()
    heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy()

    heatmap1 = (heatmap1 - heatmap1.min()) / (heatmap1.max() - heatmap1.min())
    heatmap2 = (heatmap2 - heatmap2.min()) / (heatmap2.max() - heatmap2.min())

    return heatmap1, heatmap2

def pairwise_occlusion(img1, img2, model, batch_size, erase_scale, erase_ratio, num_erases):

    measure = nn.CosineSimilarity(dim=-1)
    out1_condition, out2_condition = model(img1), model(img2)
    baseline = measure(out1_condition, out2_condition).detach()
    # a bit sensitive to scale and ratio. erase_scale is from (scale[0] * 100) % to (scale[1] * 100) %
    random_erase = transforms.RandomErasing(p=1.0, scale=erase_scale, ratio=erase_ratio)

    image1 = img1.clone().detach()
    image2 = img2.clone().detach()
    images1 = []
    images2 = []

    for _ in range(num_erases):
        images1.append(random_erase(image1))
        images2.append(random_erase(image2))

    images1 = torch.cat(images1, dim=0).to(device)
    images2 = torch.cat(images2, dim=0).to(device)

    sims = []
    weights1 = []
    weights2 = []

    for b in range(0, images2.shape[0], batch_size):

        with torch.no_grad():
            out1 = model(images1[b : b + batch_size, :])
            out2 = model(images2[b : b + batch_size, :])
            sims.append(measure(out1, out2))
            weights1.append(out1.norm(dim=-1))
            weights2.append(out2.norm(dim=-1))

    sims = torch.cat(sims, dim = 0)
    weights1, weights2 = torch.cat(weights1, dim = 0).cpu().numpy(), torch.cat(weights2, dim = 0).cpu().numpy()
    weights = list(zip(weights1, weights2))
    sims = baseline - sims # the higher the drop, the better
    sims = F.softmax(sims, dim = -1)
    sims = sims.cpu().numpy()

    assert sims.shape[0] == images1.shape[0] == images2.shape[0]
    A1 = np.zeros((224, 224))
    A2 = np.zeros((224, 224))

    for n in range(images1.shape[0]):

        im1_2d = images1[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1)
        im2_2d = images2[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1)

        joint_similarity = sims[n]
        weight = weights[n]

        if weight[0] < weight[1]:
            A1[im1_2d == 0] += joint_similarity
        else:
            A2[im2_2d == 0] += joint_similarity

    A1 = A1 / (np.max(A1) + 1e-9)
    A2 = A2 / (np.max(A2) + 1e-9)

    return A1, A2

def tv_reg(img, l1 = True):

    diff_i = (img[:, :, :, 1:] - img[:, :, :, :-1])
    diff_j = (img[:, :, 1:, :] - img[:, :, :-1, :])

    if l1:
        return diff_i.abs().sum() + diff_j.abs().sum()
    else:
        return diff_i.pow(2).sum() + diff_j.pow(2).sum()


def synthesize(ssl_model, model_type, img1, img_cls_layer, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network):

    if model_type == 'imagenet':
        reduce_lr = False
        model = torchvision.models.resnet50(pretrained=True)
        model = list(model.children())[:img_cls_layer]
        model = nn.Sequential(*model).to(device)
        model.eval()
    else:
        reduce_lr = True
        shift_layer = 3 if network == 'simclrv2' else 0
        equivalent_layer = img_cls_layer - shift_layer
        model = list(ssl_model.encoder.net.children())[:equivalent_layer]
        model = nn.Sequential(*model).to(device)
        model.eval()

    opt_img = (init_scale * torch.randn(1, 3, 224, 224)).to(device).requires_grad_()
    target_feats = model(img1).detach()
    optimizer = torch.optim.SGD([opt_img], lr=lr, momentum=0.9)

    for i in range(201):
        opt_img.data = opt_img.data.clip(0,1)
        optimizer.zero_grad()
        output = model(opt_img)
        l2_loss = l2_weight * ((output - target_feats) ** 2).sum() / (target_feats ** 2).sum()
        reg_alpha = alpha_weight * (opt_img ** alpha_power).sum()
        reg_total_variation = tv_weight * tv_reg(opt_img, l1 = False)
        loss = l2_loss + reg_alpha + reg_total_variation
        loss.backward()
        optimizer.step()

        if reduce_lr and i % 40 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 1/10

    return opt_img

def get_difference(ssl_model, baseline, image, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network):

    imagenet_images = []
    ssl_images = []

    for lay in range(4,7):
        image_net_image = synthesize(ssl_model, baseline, image, lay, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network).detach().clone()
        ssl_image = synthesize(ssl_model, 'ssl', image, lay, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network).detach().clone()
        imagenet_images.append(image_net_image)
        ssl_images.append(ssl_image)

    return imagenet_images, ssl_images

def create_mixed_images(transform_type, ig_transforms, step, img_path, add_noise):

    img = Image.open(img_path).convert('RGB')
    img1 = ig_transforms['pure'](img).unsqueeze(0).to(device)
    img2 = ig_transforms[transform_type](img).unsqueeze(0).to(device)

    lambdas = np.arange(1,0,-step)
    mixed_images = []
    for l,lam in enumerate(lambdas):
        mixed_img = lam * img1 + (1 - lam) * img2
        mixed_images.append(mixed_img)

    if add_noise:
        sigma = 0.15 / (torch.max(img1) - torch.min(img1)).item()
        mixed_images = [im + torch.zeros_like(im).normal_(0, sigma) if (n>0) and (n<len(mixed_images)-1) else im for n,im in enumerate(mixed_images)]

    return mixed_images

def averaged_transforms(guided, ssl_model, mixed_images, blur_output):

    measure = nn.CosineSimilarity(dim=-1)

    if guided:
        handles = []
        for i, module in enumerate(ssl_model.modules()):
            if isinstance(module, nn.ReLU):
                handles.append(module.register_backward_hook(relu_hook_function))

    grads1 = []
    grads2 = []

    for xbar_image in mixed_images[1:]:
        input_image1 = mixed_images[0].clone().requires_grad_()
        input_image2 = xbar_image.clone().requires_grad_()

        if input_image1.grad is not None:
            input_image1.grad.data.zero_()
            input_image2.grad.data.zero_()

        score = measure(ssl_model(input_image1), ssl_model(input_image2))
        score.backward()
        grads1.append(input_image1.grad.data)
        grads2.append(input_image2.grad.data)

    grads1 = torch.cat(grads1).mean(0).unsqueeze(0)
    grads2 = torch.cat(grads2).mean(0).unsqueeze(0)

    sailency1, _ = torch.max((mixed_images[0] * grads1).abs(), dim=1)
    sailency2, _ = torch.max((mixed_images[-1] * grads2).abs(), dim=1)

    if guided: # remove handles after finishing
        for handle in handles:
            handle.remove()

    if blur_output:
        sailency1 = blur_sailency(sailency1)
        sailency2 = blur_sailency(sailency2)

    return sailency1, sailency2

def sailency(guided, ssl_model, img1, img2, blur_output):

    measure = nn.CosineSimilarity(dim=-1)

    if guided:
        handles = []
        for i, module in enumerate(ssl_model.modules()):
            if isinstance(module, nn.ReLU):
                handles.append(module.register_backward_hook(relu_hook_function))

    input_image1 = img1.clone().requires_grad_()
    input_image2 = img2.clone().requires_grad_()
    score = measure(ssl_model(input_image1), ssl_model(input_image2))
    score.backward()
    grads1 = input_image1.grad.data
    grads2 = input_image2.grad.data
    sailency1, _ = torch.max((img1 * grads1).abs(), dim=1)
    sailency2, _ = torch.max((img2 * grads2).abs(), dim=1)

    if guided: # remove handles after finishing
        for handle in handles:
            handle.remove()

    if blur_output:
        sailency1 = blur_sailency(sailency1)
        sailency2 = blur_sailency(sailency2)

    return sailency1, sailency2

def smooth_grad(guided, ssl_model, img1, img2, blur_output, steps = 50):

    measure = nn.CosineSimilarity(dim=-1)
    sigma = 0.15 / (torch.max(img1) - torch.min(img1)).item()

    if guided:
        handles = []
        for i, module in enumerate(ssl_model.modules()):
            if isinstance(module, nn.ReLU):
                handles.append(module.register_backward_hook(relu_hook_function))

    noise_images1 = []
    noise_images2 = []

    for _ in range(steps):
        noise = torch.zeros_like(img1).normal_(0, sigma)
        noise_images1.append(img1 + noise)
        noise_images2.append(img2 + noise)

    grads1 = []
    grads2 = []

    for n1, n2 in zip(noise_images1, noise_images2):
        input_image1 = n1.clone().requires_grad_()
        input_image2 = n2.clone().requires_grad_()

        if input_image1.grad is not None:
            input_image1.grad.data.zero_()
            input_image2.grad.data.zero_()

        score = measure(ssl_model(input_image1), ssl_model(input_image2))
        score.backward()
        grads1.append(input_image1.grad.data)
        grads2.append(input_image2.grad.data)

    grads1 = torch.cat(grads1).mean(0).unsqueeze(0)
    grads2 = torch.cat(grads2).mean(0).unsqueeze(0)
    sailency1, _ = torch.max((img1 * grads1).abs(), dim=1)
    sailency2, _ = torch.max((img2 * grads2).abs(), dim=1)

    if guided: # remove handles after finishing
        for handle in handles:
            handle.remove()

    if blur_output:
        sailency1 = blur_sailency(sailency1)
        sailency2 = blur_sailency(sailency2)

    return sailency1, sailency2

def get_sample_dataset(img_path, num_augments, batch_size, no_shift_transforms, ssl_model, n_components):

    measure = nn.CosineSimilarity(dim=-1)
    img = Image.open(img_path).convert('RGB')
    no_shift_aug = transforms.Compose([no_shift_transforms['aug'],
                                       transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3))])

    augments2 = [no_shift_aug(img).unsqueeze(0) for _ in range(num_augments)]
    data_samples1 = no_shift_transforms['pure'](img).unsqueeze(0).expand(num_augments, -1, -1, -1).to(device)
    data_samples2 = torch.cat(augments2).to(device)

    labels = []
    feats_invariance = []

    for b in range(0, data_samples1.shape[0], batch_size):

        with torch.no_grad():
            out1 = ssl_model(data_samples1[b : b + batch_size, :])
            out2 = ssl_model(data_samples2[b : b + batch_size, :])
            labels.append(measure(out1, out2))
            feats_invariance.append(F.relu(out2))

    data_labels = torch.cat(labels).unsqueeze(-1).to(device)
    feats_invariance = torch.cat(feats_invariance).to(device)
    nmf_model = NMF(n_components=n_components, init='random')
    # feats_invariance is (T, 2048); NMF factorizes it as W.H with W: (T, N) and H: (N, 2048).
    # fit_transform returns W (named H below), whose rows score each transformed sample
    H = nmf_model.fit_transform(feats_invariance.cpu().numpy())
    labels_invariance = torch.from_numpy(H.mean(1)).unsqueeze(-1).to(device)

    return data_samples1, data_samples2, data_labels, labels_invariance

def pixel_invariance(data_samples1, data_samples2, data_labels, labels_invariance, resize_transform, size, epochs, learning_rate, l1_weight, zero_small_values, blur_output, nmf_weight):

    """
    size: images are resized to (size, size) while training the surrogate; the map is upsampled back later
    epochs: number of epochs to train the surrogate model
    learning_rate: learning rate to train the surrogate model
    l1_weight: if not None, enables l1 regularization (sparsity)
    """
    x1 = resize_transform((size, size))(data_samples1) # (num_samples, 3, size, size)
    x2 = resize_transform((size, size))(data_samples2) # (num_samples, 3, size, size)

    x1 = x1.reshape(x1.size(0), -1).to(device)
    x2 = x2.reshape(x2.size(0), -1).to(device)

    surrogate = nn.Linear(size * size * 3, 1).to(device)

    criterion = nn.BCEWithLogitsLoss(reduction = 'sum')
    invariance_criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(surrogate.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        pred1, pred2 = surrogate(x1), surrogate(x2)
        preds = (pred1 + pred2) / 2
        loss = criterion(preds, data_labels)
        loss += nmf_weight * invariance_criterion(torch.sigmoid(preds), labels_invariance)

        if l1_weight is not None:
            loss += l1_weight * sum(p.abs().sum() for p in surrogate.parameters())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    heatmap = surrogate.weight.reshape(3, size, size)
    heatmap, _ = torch.max(heatmap, 0)
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

    if zero_small_values:
        heatmap[heatmap < 0.5] = 0

    if blur_output:
        heatmap = blur_sailency(heatmap.unsqueeze(0)).squeeze(0)

    return heatmap

class GradCAM(nn.Module):

    def __init__(self, ssl_model):
        super(GradCAM, self).__init__()

        self.gradients = {}
        self.features = {}

        self.feature_extractor = ssl_model.encoder.net
        self.contrastive_head = ssl_model.contrastive_head
        self.measure = nn.CosineSimilarity(dim=-1)

    def save_grads(self, img_index):

        def hook(grad):
            self.gradients[img_index] = grad.detach()

        return hook

    def save_features(self, img_index, feats):
        self.features[img_index] = feats.detach()

    def forward(self, img1, img2):

        features1 = self.feature_extractor(img1)
        features2 = self.feature_extractor(img2)

        self.save_features('1', features1)
        self.save_features('2', features2)

        h1 = features1.register_hook(self.save_grads('1'))
        h2 = features2.register_hook(self.save_grads('2'))

        out1, out2 = features1.mean(dim=[2, 3]), features2.mean(dim=[2, 3])
        out1, out2 = self.contrastive_head(out1), self.contrastive_head(out2)
        score = self.measure(out1, out2)

        return score

def weight_activation(feats, grads):
    cam = feats * F.relu(grads)
    cam = torch.sum(cam, dim=1).squeeze().cpu().detach().numpy()
    return cam

def get_gradcam(ssl_model, img1, img2):

    grad_cam = GradCAM(ssl_model).to(device)
    score = grad_cam(img1, img2)
    grad_cam.zero_grad()
    score.backward()

    cam1 = weight_activation(grad_cam.features['1'], grad_cam.gradients['1'])
    cam2 = weight_activation(grad_cam.features['2'], grad_cam.gradients['2'])
    return cam1, cam2

def get_interactioncam(ssl_model, img1, img2, reduction, grad_interact = False):

    grad_cam = GradCAM(ssl_model).to(device)
    score = grad_cam(img1, img2)
    grad_cam.zero_grad()
    score.backward()

    G1 = grad_cam.gradients['1']
    G2 = grad_cam.gradients['2']

    if grad_interact:
        B, D, H, W = G1.size()
        G1_ = G1.permute(0,2,3,1).view(B, H * W, D)
        G2_ = G2.permute(0,2,3,1).view(B, H * W, D)
        G_ = torch.bmm(G1_.permute(0,2,1), G2_) # (B, D, D)
        G1, _ = torch.max(G_, dim = -1) # (B, D)
        G2, _ = torch.max(G_, dim = 1) # (B, D)
        G1 = G1.unsqueeze(-1).unsqueeze(-1)
        G2 = G2.unsqueeze(-1).unsqueeze(-1)

    if reduction == 'mean':
        joint_weight = grad_cam.features['1'].mean([2,3]) * grad_cam.features['2'].mean([2,3])
    elif reduction == 'max':
        max_pooled1 = F.max_pool2d(grad_cam.features['1'], kernel_size=grad_cam.features['1'].size()[2:]).squeeze(-1).squeeze(-1)
        max_pooled2 = F.max_pool2d(grad_cam.features['2'], kernel_size=grad_cam.features['2'].size()[2:]).squeeze(-1).squeeze(-1)
        joint_weight = max_pooled1 * max_pooled2
    else:
        B, D, H, W = grad_cam.features['1'].size()
        reshaped1 = grad_cam.features['1'].permute(0,2,3,1).reshape(B, H * W, D)
        reshaped2 = grad_cam.features['2'].permute(0,2,3,1).reshape(B, H * W, D)
        features1_query, features2_query = reshaped1.mean(1).unsqueeze(1), reshaped2.mean(1).unsqueeze(1)
        attn1 = (features1_query @ reshaped1.transpose(-2, -1)).softmax(dim=-1)
        attn2 = (features2_query @ reshaped2.transpose(-2, -1)).softmax(dim=-1)
        att_reduced1 = (attn1 @ reshaped1).squeeze(1)
        att_reduced2 = (attn2 @ reshaped2).squeeze(1)
        joint_weight = att_reduced1 * att_reduced2

    joint_weight = joint_weight.unsqueeze(-1).unsqueeze(-1).expand_as(grad_cam.features['1'])

    feats1 = grad_cam.features['1'] * joint_weight
    feats2 = grad_cam.features['2'] * joint_weight

    cam1 = weight_activation(feats1, G1)
    cam2 = weight_activation(feats2, G2)

    return cam1, cam2
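
A sketch of how these methods fit together (not part of the uploaded files). It assumes a checkpoint loaded via utils.get_ssl_model that exposes encoder.net and contrastive_head as methods.py expects, the transform dicts from data_transforms.py, and a placeholder image path.

# Hypothetical driver script; 'dog.jpg' and the SimCLRv2 checkpoint are placeholders.
import torch
from utils import get_ssl_model
from data_transforms import ig_transforms
from methods import create_mixed_images, averaged_transforms, get_gradcam

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ssl_model = get_ssl_model('simclrv2', '1x').to(device)

# interpolate between the clean view and its blurred view, then average
# (guided) gradients of the cosine similarity over the interpolation path
mixed = create_mixed_images('blur', ig_transforms, step=0.1,
                            img_path='dog.jpg', add_noise=True)
sal1, sal2 = averaged_transforms(guided=True, ssl_model=ssl_model,
                                 mixed_images=mixed, blur_output=True)

# Grad-CAM style maps for the two endpoints of the mixture
cam1, cam2 = get_gradcam(ssl_model, mixed[0], mixed[-1])
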
utils.py ADDED
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import random
import cv2
import io
from ssl_models.simclr2 import get_simclr2_model
from ssl_models.barlow_twins import get_barlow_twins_model
from ssl_models.simsiam import get_simsiam
from ssl_models.dino import get_dino_model_without_loss, get_dino_model_with_loss

def get_ssl_model(network, variant):

    if network == 'simclrv2':
        if variant == '1x':
            ssl_model = get_simclr2_model('r50_1x_sk0_ema.pth').eval()
        else:
            ssl_model = get_simclr2_model('r50_2x_sk0_ema.pth').eval()
    elif network == 'barlow_twins':
        ssl_model = get_barlow_twins_model().eval()
    elif network == 'simsiam':
        ssl_model = get_simsiam().eval()
    elif network == 'dino':
        ssl_model = get_dino_model_without_loss().eval()
    elif network == 'dino+loss':
        ssl_model, dino_score = get_dino_model_with_loss()
        ssl_model = ssl_model.eval()

    return ssl_model

def overlay_heatmap(img, heatmap, denormalize = False):
    loaded_img = img.squeeze(0).cpu().numpy().transpose((1, 2, 0))

    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        loaded_img = std * loaded_img + mean

    loaded_img = (loaded_img.clip(0, 1) * 255).astype(np.uint8)
    cam = heatmap / heatmap.max()
    cam = cv2.resize(cam, (224, 224))
    cam = np.uint8(255 * cam)
    cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET) # jet: blue --> red
    cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)
    added_image = cv2.addWeighted(cam, 0.5, loaded_img, 0.5, 0)
    return added_image

def viz_map(img_path, heatmap):
    "For pixel invariance"
    img = np.array(Image.open(img_path).resize((224,224)))
    width, height, _ = img.shape
    cam = heatmap.detach().cpu().numpy()
    cam = cam / cam.max()
    cam = cv2.resize(cam, (height, width))
    heatmap = np.uint8(255 * cam)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    added_image = cv2.addWeighted(heatmap, 0.5, img, 0.7, 0)
    return added_image

def show_image(x, squeeze = True, denormalize = False):

    if squeeze:
        x = x.squeeze(0)

    x = x.cpu().numpy().transpose((1, 2, 0))

    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        x = std * x + mean

    return x.clip(0, 1)

def deprocess(inp, to_numpy = True, to_PIL = False, denormalize = False):

    if to_numpy:
        inp = inp.detach().cpu().numpy()

    inp = inp.squeeze(0).transpose((1, 2, 0))

    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean

    inp = (inp.clip(0, 1) * 255).astype(np.uint8)

    if to_PIL:
        return Image.fromarray(inp)
    return inp

def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    buf = io.BytesIO()
    fig.savefig(buf, bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    img = Image.open(buf)
    return img
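
An end-to-end sketch of visualizing an occlusion heatmap with these helpers (not part of the uploaded files). The image path and checkpoint are placeholders; it assumes the unnormalized SimCLRv2-style transforms from data_transforms.py.

# Hypothetical visualization sketch; 'dog.jpg' and the checkpoint are placeholders.
import torch
import matplotlib.pyplot as plt
from PIL import Image
from utils import get_ssl_model, overlay_heatmap
from data_transforms import ig_transforms
from methods import occlusion

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ssl_model = get_ssl_model('simclrv2', '1x').to(device)

img = Image.open('dog.jpg').convert('RGB')
img1 = ig_transforms['pure'](img).unsqueeze(0).to(device)          # clean view
img2 = ig_transforms['color_jitter'](img).unsqueeze(0).to(device)  # augmented view

# occlusion-based similarity-drop heatmaps for the two views
heatmap1, heatmap2 = occlusion(img1, img2, ssl_model, w_size=64, stride=8)

plt.imshow(overlay_heatmap(img1, heatmap1))
plt.axis('off')
plt.show()
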