AnnonSubmission commited on
Commit
0992122
1 Parent(s): ff3029d

Update methods.py

Browse files
Files changed (1) hide show
  1. methods.py +1 -146
methods.py CHANGED
@@ -60,83 +60,13 @@ def occlusion(img1, img2, model, w_size = 64, stride = 8, batch_size = 32):
60
 
61
  heatmap1 = (heatmap1 - base_score.item()) * -1 # or base_score.item() - heatmap1. The higher the drop, the better
62
  heatmap2 = (heatmap2 - base_score.item()) * -1 # or base_score.item() - heatmap2. The higher the drop, the better
63
-
64
- return heatmap1, heatmap2
65
-
66
- def occlusion_context_agnositc(img1, img2, model, w_size = 64, stride = 8, batch_size = 32):
67
-
68
- measure = nn.CosineSimilarity(dim=-1)
69
- output_size = int(((img2.size(-1) - w_size) / stride) + 1)
70
- out1_condition, out2_condition = model(img1), model(img2)
71
-
72
- images1_occlude_mask = []
73
- images2_occlude_mask = []
74
-
75
- for i in range(output_size):
76
- for j in range(output_size):
77
- start_i, start_j = i * stride, j * stride
78
- image1 = img1.clone().detach()
79
- image2 = img2.clone().detach()
80
- image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
81
- image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
82
- images1_occlude_mask.append(image1)
83
- images2_occlude_mask.append(image2)
84
-
85
- images1_occlude_mask = torch.cat(images1_occlude_mask, dim=0).to(device)
86
- images2_occlude_mask = torch.cat(images2_occlude_mask, dim=0).to(device)
87
-
88
- images1_occlude_backround = []
89
- images2_occlude_backround = []
90
-
91
- copy_img1 = img1.clone().detach()
92
- copy_img2 = img2.clone().detach()
93
-
94
- for i in range(output_size):
95
- for j in range(output_size):
96
- start_i, start_j = i * stride, j * stride
97
-
98
- image1 = torch.zeros_like(img1)
99
- image2 = torch.zeros_like(img2)
100
-
101
- image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img1[:, :, start_i : start_i + w_size, start_j : start_j + w_size]
102
- image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img2[:, :, start_i : start_i + w_size, start_j : start_j + w_size]
103
-
104
- images1_occlude_backround.append(image1)
105
- images2_occlude_backround.append(image2)
106
-
107
- images1_occlude_backround = torch.cat(images1_occlude_backround, dim=0).to(device)
108
- images2_occlude_backround = torch.cat(images2_occlude_backround, dim=0).to(device)
109
-
110
- score_map1 = []
111
- score_map2 = []
112
-
113
- assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0]
114
-
115
- for b in range(0, images1_occlude_mask.shape[0], batch_size):
116
-
117
- with torch.no_grad():
118
- out1_mask = model(images1_occlude_mask[b : b + batch_size, :])
119
- out2_mask = model(images2_occlude_mask[b : b + batch_size, :])
120
- out1_backround = model(images1_occlude_backround[b : b + batch_size, :])
121
- out2_backround = model(images2_occlude_backround[b : b + batch_size, :])
122
-
123
- out1 = out1_backround - out1_mask
124
- out2 = out2_backround - out2_mask
125
- score_map1.append(measure(out1, out2_condition)) # or torch.mm(out2_condition, out1.t())[0]
126
- score_map2.append(measure(out1_condition, out2)) # or torch.mm(out1_condition, out2.t())[0]
127
-
128
- score_map1 = torch.cat(score_map1, dim = 0)
129
- score_map2 = torch.cat(score_map2, dim = 0)
130
- assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0] == score_map2.shape[0] == score_map1.shape[0]
131
-
132
- heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy()
133
- heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy()
134
 
135
  heatmap1 = (heatmap1 - heatmap1.min()) / (heatmap1.max() - heatmap1.min())
136
  heatmap2 = (heatmap2 - heatmap2.min()) / (heatmap2.max() - heatmap2.min())
137
 
138
  return heatmap1, heatmap2
139
 
 
140
  def pairwise_occlusion(img1, img2, model, batch_size, erase_scale, erase_ratio, num_erases):
141
 
142
  measure = nn.CosineSimilarity(dim=-1)
@@ -338,81 +268,6 @@ def smooth_grad(guided, ssl_model, img1, img2, blur_output, steps = 50):
338
 
339
  return sailency1, sailency2
340
 
341
- def get_sample_dataset(img_path, num_augments, batch_size, no_shift_transforms, ssl_model, n_components):
342
-
343
- measure = nn.CosineSimilarity(dim=-1)
344
- img = Image.open(img_path).convert('RGB') if isinstance(img_path, str) else img_path
345
- no_shift_aug = transforms.Compose([no_shift_transforms['aug'],
346
- transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3))])
347
-
348
- augments2 = [no_shift_aug(img).unsqueeze(0) for _ in range(num_augments)]
349
- data_samples1 = no_shift_transforms['pure'](img).unsqueeze(0).expand(num_augments, -1, -1, -1).to(device)
350
- data_samples2 = torch.cat(augments2).to(device)
351
-
352
- labels = []
353
- feats_invariance = []
354
-
355
- for b in range(0, data_samples1.shape[0], batch_size):
356
-
357
- with torch.no_grad():
358
- out1 = ssl_model(data_samples1[b : b + batch_size, :])
359
- out2 = ssl_model(data_samples2[b : b + batch_size, :])
360
- labels.append(measure(out1, out2))
361
- feats_invariance.append(F.relu(out2))
362
-
363
- data_labels = torch.cat(labels).unsqueeze(-1).to(device)
364
- feats_invariance = torch.cat(feats_invariance).to(device)
365
- nmf_model = NMF(n_components=n_components, init='random')
366
- # (T, 2048) = W.H = (2048,N) . (N,T), where H is the matrix representing the features of each transform
367
- H = nmf_model.fit_transform(feats_invariance.cpu().numpy())
368
- labels_invariance = torch.from_numpy(H.mean(1)).unsqueeze(-1).to(device)
369
-
370
- return data_samples1, data_samples2, data_labels, labels_invariance
371
-
372
- def pixel_invariance(data_samples1, data_samples2, data_labels, labels_invariance, resize_transform, size, epochs, learning_rate, l1_weight, zero_small_values, blur_output, nmf_weight):
373
-
374
- """
375
- size: resize the image to that when training the surrogate. Later we upsize
376
- epochs: number of epochs to train the surrogate model
377
- learning_rate: learning rate to train the surrogate model
378
- l1_weight: if not None, enables l1 regularization (sparsity)
379
- """
380
- x1 = resize_transform((size, size))(data_samples1) # (num_samples, 3, size, size)
381
- x2 = resize_transform((size, size))(data_samples2) # (num_samples, 3, size, size)
382
-
383
- x1 = x1.reshape(x1.size(0), -1).to(device)
384
- x2 = x2.reshape(x2.size(0), -1).to(device)
385
-
386
- surrogate = nn.Linear(size * size * 3, 1).to(device)
387
-
388
- criterion = nn.BCEWithLogitsLoss(reduction = 'sum')
389
- invariance_criterion = nn.MSELoss()
390
- optimizer = torch.optim.SGD(surrogate.parameters(), lr=learning_rate)
391
-
392
- for epoch in range(epochs):
393
- pred1, pred2 = surrogate(x1), surrogate(x2)
394
- preds = (pred1 + pred2) / 2
395
- loss = criterion(preds, data_labels)
396
- loss += nmf_weight * invariance_criterion(torch.sigmoid(preds), labels_invariance)
397
-
398
- if l1_weight is not None:
399
- loss += l1_weight * sum(p.abs().sum() for p in surrogate.parameters())
400
-
401
- optimizer.zero_grad()
402
- loss.backward()
403
- optimizer.step()
404
-
405
- heatmap = surrogate.weight.reshape(3, size, size)
406
- heatmap, _ = torch.max(heatmap, 0)
407
- heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
408
-
409
- if zero_small_values:
410
- heatmap[heatmap < 0.5] = 0
411
-
412
- if blur_output:
413
- heatmap = blur_sailency(heatmap.unsqueeze(0)).squeeze(0)
414
-
415
- return heatmap
416
 
417
  class GradCAM(nn.Module):
418
 
 
60
 
61
  heatmap1 = (heatmap1 - base_score.item()) * -1 # or base_score.item() - heatmap1. The higher the drop, the better
62
  heatmap2 = (heatmap2 - base_score.item()) * -1 # or base_score.item() - heatmap2. The higher the drop, the better
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  heatmap1 = (heatmap1 - heatmap1.min()) / (heatmap1.max() - heatmap1.min())
65
  heatmap2 = (heatmap2 - heatmap2.min()) / (heatmap2.max() - heatmap2.min())
66
 
67
  return heatmap1, heatmap2
68
 
69
+
70
  def pairwise_occlusion(img1, img2, model, batch_size, erase_scale, erase_ratio, num_erases):
71
 
72
  measure = nn.CosineSimilarity(dim=-1)
 
268
 
269
  return sailency1, sailency2
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
  class GradCAM(nn.Module):
273