import torch import numpy as np import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms import torchvision from PIL import Image from sklearn.decomposition import NMF device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def relu_hook_function(module, grad_in, grad_out): if isinstance(module, nn.ReLU): return (F.relu(grad_in[0]),) def blur_sailency(input_image): return torchvision.transforms.functional.gaussian_blur(input_image, kernel_size=[11, 11], sigma=[5,5]) def occlusion(img1, img2, model, w_size = 64, stride = 8, batch_size = 32): measure = nn.CosineSimilarity(dim=-1) output_size = int(((img2.size(-1) - w_size) / stride) + 1) out1_condition, out2_condition = model(img1), model(img2) images1 = [] images2 = [] for i in range(output_size): for j in range(output_size): start_i, start_j = i * stride, j * stride image1 = img1.clone().detach() image2 = img2.clone().detach() image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0 image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0 images1.append(image1) images2.append(image2) images1 = torch.cat(images1, dim=0).to(device) images2 = torch.cat(images2, dim=0).to(device) score_map1 = [] score_map2 = [] assert images1.shape[0] == images2.shape[0] for b in range(0, images2.shape[0], batch_size): with torch.no_grad(): out1 = model(images1[b : b + batch_size, :]) out2 = model(images2[b : b + batch_size, :]) score_map1.append(measure(out1, out2_condition)) # try torch.mm(out2_condition, out1.t())[0] score_map2.append(measure(out1_condition, out2)) # try torch.mm(out1_condition, out2.t())[0] score_map1 = torch.cat(score_map1, dim = 0) score_map2 = torch.cat(score_map2, dim = 0) assert images2.shape[0] == score_map2.shape[0] == score_map1.shape[0] heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy() heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy() base_score = measure(out1_condition, out2_condition) heatmap1 = (heatmap1 - base_score.item()) * -1 # or base_score.item() - heatmap1. The higher the drop, the better heatmap2 = (heatmap2 - base_score.item()) * -1 # or base_score.item() - heatmap2. The higher the drop, the better return heatmap1, heatmap2 def occlusion_context_agnositc(img1, img2, model, w_size = 64, stride = 8, batch_size = 32): measure = nn.CosineSimilarity(dim=-1) output_size = int(((img2.size(-1) - w_size) / stride) + 1) out1_condition, out2_condition = model(img1), model(img2) images1_occlude_mask = [] images2_occlude_mask = [] for i in range(output_size): for j in range(output_size): start_i, start_j = i * stride, j * stride image1 = img1.clone().detach() image2 = img2.clone().detach() image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0 image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0 images1_occlude_mask.append(image1) images2_occlude_mask.append(image2) images1_occlude_mask = torch.cat(images1_occlude_mask, dim=0).to(device) images2_occlude_mask = torch.cat(images2_occlude_mask, dim=0).to(device) images1_occlude_backround = [] images2_occlude_backround = [] copy_img1 = img1.clone().detach() copy_img2 = img2.clone().detach() for i in range(output_size): for j in range(output_size): start_i, start_j = i * stride, j * stride image1 = torch.zeros_like(img1) image2 = torch.zeros_like(img2) image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] images1_occlude_backround.append(image1) images2_occlude_backround.append(image2) images1_occlude_backround = torch.cat(images1_occlude_backround, dim=0).to(device) images2_occlude_backround = torch.cat(images2_occlude_backround, dim=0).to(device) score_map1 = [] score_map2 = [] assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0] for b in range(0, images1_occlude_mask.shape[0], batch_size): with torch.no_grad(): out1_mask = model(images1_occlude_mask[b : b + batch_size, :]) out2_mask = model(images2_occlude_mask[b : b + batch_size, :]) out1_backround = model(images1_occlude_backround[b : b + batch_size, :]) out2_backround = model(images2_occlude_backround[b : b + batch_size, :]) out1 = out1_backround - out1_mask out2 = out2_backround - out2_mask score_map1.append(measure(out1, out2_condition)) # or torch.mm(out2_condition, out1.t())[0] score_map2.append(measure(out1_condition, out2)) # or torch.mm(out1_condition, out2.t())[0] score_map1 = torch.cat(score_map1, dim = 0) score_map2 = torch.cat(score_map2, dim = 0) assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0] == score_map2.shape[0] == score_map1.shape[0] heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy() heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy() heatmap1 = (heatmap1 - heatmap1.min()) / (heatmap1.max() - heatmap1.min()) heatmap2 = (heatmap2 - heatmap2.min()) / (heatmap2.max() - heatmap2.min()) return heatmap1, heatmap2 def pairwise_occlusion(img1, img2, model, batch_size, erase_scale, erase_ratio, num_erases): measure = nn.CosineSimilarity(dim=-1) out1_condition, out2_condition = model(img1), model(img2) baseline = measure(out1_condition, out2_condition).detach() # a bit sensitive to scale and ratio. erase_scale is from (scale[0] * 100) % to (scale[1] * 100) % random_erase = transforms.RandomErasing(p=1.0, scale=erase_scale, ratio=erase_ratio) image1 = img1.clone().detach() image2 = img2.clone().detach() images1 = [] images2 = [] for _ in range(num_erases): images1.append(random_erase(image1)) images2.append(random_erase(image2)) images1 = torch.cat(images1, dim=0).to(device) images2 = torch.cat(images2, dim=0).to(device) sims = [] weights1 = [] weights2 = [] for b in range(0, images2.shape[0], batch_size): with torch.no_grad(): out1 = model(images1[b : b + batch_size, :]) out2 = model(images2[b : b + batch_size, :]) sims.append(measure(out1, out2)) weights1.append(out1.norm(dim=-1)) weights2.append(out2.norm(dim=-1)) sims = torch.cat(sims, dim = 0) weights1, weights2 = torch.cat(weights1, dim = 0).cpu().numpy(), torch.cat(weights2, dim = 0).cpu().numpy() weights = list(zip(weights1, weights2)) sims = baseline - sims # the higher the drop, the better sims = F.softmax(sims, dim = -1) sims = sims.cpu().numpy() assert sims.shape[0] == images1.shape[0] == images2.shape[0] A1 = np.zeros((224, 224)) A2 = np.zeros((224, 224)) for n in range(images1.shape[0]): im1_2d = images1[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1) im2_2d = images2[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1) joint_similarity = sims[n] weight = weights[n] if weight[0] < weight[1]: A1[im1_2d == 0] += joint_similarity else: A2[im2_2d == 0] += joint_similarity A1 = A1 / (np.max(A1) + 1e-9) A2 = A2 / (np.max(A2) + 1e-9) return A1, A2 def tv_reg(img, l1 = True): diff_i = (img[:, :, :, 1:] - img[:, :, :, :-1]) diff_j = (img[:, :, 1:, :] - img[:, :, :-1, :]) if l1: return diff_i.abs().sum() + diff_j.abs().sum() else: return diff_i.pow(2).sum() + diff_j.pow(2).sum() def synthesize(ssl_model, model_type, img1, img_cls_layer, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network): if model_type == 'imagenet': reduce_lr = False model = torchvision.models.resnet50(pretrained=True) model = list(model.children())[:img_cls_layer] model = nn.Sequential(*model).to(device) model.eval() else: reduce_lr = True shift_layer = 3 if network == 'simclrv2' else 0 equivalent_layer = img_cls_layer - shift_layer model = list(ssl_model.encoder.net.children())[:equivalent_layer] model = nn.Sequential(*model).to(device) model.eval() opt_img = (init_scale * torch.randn(1, 3, 224, 224)).to(device).requires_grad_() target_feats = model(img1).detach() optimizer = torch.optim.SGD([opt_img], lr=lr, momentum=0.9) for i in range(201): opt_img.data = opt_img.data.clip(0,1) optimizer.zero_grad() output = model(opt_img) l2_loss = l2_weight * ((output - target_feats) ** 2).sum() / (target_feats ** 2).sum() reg_alpha = alpha_weight * (opt_img ** alpha_power).sum() reg_total_variation = tv_weight * tv_reg(opt_img, l1 = False) loss = l2_loss + reg_alpha + reg_total_variation loss.backward() optimizer.step() if reduce_lr and i % 40 == 0: for param_group in optimizer.param_groups: param_group['lr'] *= 1/10 return opt_img def get_difference(ssl_model, baseline, image, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network): imagenet_images = [] ssl_images = [] for lay in range(4,7): image_net_image = synthesize(ssl_model, baseline, image, lay, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network).detach().clone() ssl_image = synthesize(ssl_model, 'ssl', image, lay, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network).detach().clone() imagenet_images.append(image_net_image) ssl_images.append(ssl_image) return imagenet_images, ssl_images def create_mixed_images(transform_type, ig_transforms, step, img_path, add_noise): img = Image.open(img_path).convert('RGB') if isinstance(img_path, str) else img_path img1 = ig_transforms['pure'](img).unsqueeze(0).to(device) img2 = ig_transforms[transform_type](img).unsqueeze(0).to(device) lambdas = np.arange(1,0,-step) mixed_images = [] for l,lam in enumerate(lambdas): mixed_img = lam * img1 + (1 - lam) * img2 mixed_images.append(mixed_img) if add_noise: sigma = 0.15 / (torch.max(img1) - torch.min(img1)).item() mixed_images = [im + torch.zeros_like(im).normal_(0, sigma) if (n>0) and (n