import torch import numpy as np import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms import torchvision from PIL import Image from sklearn.decomposition import NMF device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def relu_hook_function(module, grad_in, grad_out): if isinstance(module, nn.ReLU): return (F.relu(grad_in[0]),) def blur_sailency(input_image): return torchvision.transforms.functional.gaussian_blur(input_image, kernel_size=[11, 11], sigma=[5,5]) def occlusion(img1, img2, model, w_size = 64, stride = 8, batch_size = 32): measure = nn.CosineSimilarity(dim=-1) output_size = int(((img2.size(-1) - w_size) / stride) + 1) out1_condition, out2_condition = model(img1), model(img2) images1 = [] images2 = [] for i in range(output_size): for j in range(output_size): start_i, start_j = i * stride, j * stride image1 = img1.clone().detach() image2 = img2.clone().detach() image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0 image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0 images1.append(image1) images2.append(image2) images1 = torch.cat(images1, dim=0).to(device) images2 = torch.cat(images2, dim=0).to(device) score_map1 = [] score_map2 = [] assert images1.shape[0] == images2.shape[0] for b in range(0, images2.shape[0], batch_size): with torch.no_grad(): out1 = model(images1[b : b + batch_size, :]) out2 = model(images2[b : b + batch_size, :]) score_map1.append(measure(out1, out2_condition)) # try torch.mm(out2_condition, out1.t())[0] score_map2.append(measure(out1_condition, out2)) # try torch.mm(out1_condition, out2.t())[0] score_map1 = torch.cat(score_map1, dim = 0) score_map2 = torch.cat(score_map2, dim = 0) assert images2.shape[0] == score_map2.shape[0] == score_map1.shape[0] heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy() heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy() base_score = measure(out1_condition, out2_condition) heatmap1 = (heatmap1 - base_score.item()) * -1 # or base_score.item() - heatmap1. The higher the drop, the better heatmap2 = (heatmap2 - base_score.item()) * -1 # or base_score.item() - heatmap2. The higher the drop, the better heatmap1 = (heatmap1 - heatmap1.min()) / (heatmap1.max() - heatmap1.min()) heatmap2 = (heatmap2 - heatmap2.min()) / (heatmap2.max() - heatmap2.min()) return heatmap1, heatmap2 def pairwise_occlusion(img1, img2, model, batch_size, erase_scale, erase_ratio, num_erases): measure = nn.CosineSimilarity(dim=-1) out1_condition, out2_condition = model(img1), model(img2) baseline = measure(out1_condition, out2_condition).detach() # a bit sensitive to scale and ratio. erase_scale is from (scale[0] * 100) % to (scale[1] * 100) % random_erase = transforms.RandomErasing(p=1.0, scale=erase_scale, ratio=erase_ratio) image1 = img1.clone().detach() image2 = img2.clone().detach() images1 = [] images2 = [] for _ in range(num_erases): images1.append(random_erase(image1)) images2.append(random_erase(image2)) images1 = torch.cat(images1, dim=0).to(device) images2 = torch.cat(images2, dim=0).to(device) sims = [] weights1 = [] weights2 = [] for b in range(0, images2.shape[0], batch_size): with torch.no_grad(): out1 = model(images1[b : b + batch_size, :]) out2 = model(images2[b : b + batch_size, :]) sims.append(measure(out1, out2)) weights1.append(out1.norm(dim=-1)) weights2.append(out2.norm(dim=-1)) sims = torch.cat(sims, dim = 0) weights1, weights2 = torch.cat(weights1, dim = 0).cpu().numpy(), torch.cat(weights2, dim = 0).cpu().numpy() weights = list(zip(weights1, weights2)) sims = baseline - sims # the higher the drop, the better sims = F.softmax(sims, dim = -1) sims = sims.cpu().numpy() assert sims.shape[0] == images1.shape[0] == images2.shape[0] A1 = np.zeros((224, 224)) A2 = np.zeros((224, 224)) for n in range(images1.shape[0]): im1_2d = images1[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1) im2_2d = images2[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1) joint_similarity = sims[n] weight = weights[n] if weight[0] < weight[1]: A1[im1_2d == 0] += joint_similarity else: A2[im2_2d == 0] += joint_similarity A1 = A1 / (np.max(A1) + 1e-9) A2 = A2 / (np.max(A2) + 1e-9) return A1, A2 def create_mixed_images(transform_type, ig_transforms, step, img_path, add_noise): img = Image.open(img_path).convert('RGB') if isinstance(img_path, str) else img_path img1 = ig_transforms['pure'](img).unsqueeze(0).to(device) img2 = ig_transforms[transform_type](img).unsqueeze(0).to(device) lambdas = np.arange(1,0,-step) mixed_images = [] for l,lam in enumerate(lambdas): mixed_img = lam * img1 + (1 - lam) * img2 mixed_images.append(mixed_img) if add_noise: sigma = 0.15 / (torch.max(img1) - torch.min(img1)).item() mixed_images = [im + torch.zeros_like(im).normal_(0, sigma) if (n>0) and (n