import torch import numpy as np import torch.nn as nn import torchvision.transforms as transforms import matplotlib import matplotlib.pyplot as plt from PIL import Image import cv2 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") from data_transforms import normal_transforms, no_shift_transforms, ig_transforms, modify_transforms from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model, fig2img from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion from methods import get_difference from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad from methods import get_sample_dataset, pixel_invariance, get_gradcam, get_interactioncam matplotlib.use('Agg') def load_model(model_name): global network, ssl_model, denorm if model_name == "simclrv2 (1X)": variant = '1x' network = 'simclrv2' denorm = False elif model_name == "simclrv2 (2X)": variant = '2x' network = 'simclrv2' denorm = False elif model_name == "Barlow Twins": network = 'barlow_twins' variant = None denorm = True ssl_model = get_ssl_model(network, variant) if network != 'simclrv2': global normal_transforms, no_shift_transforms, ig_transforms normal_transforms, no_shift_transforms, ig_transforms = modify_transforms(normal_transforms, no_shift_transforms, ig_transforms) return "Loaded Model Successfully" def load_or_augment_images(img1_input, img2_input, use_aug): global img_main, img1, img2 img_main = img1_input.convert('RGB') if use_aug: img1 = normal_transforms['pure'](img_main).unsqueeze(0).to(device) img2 = normal_transforms['aug'](img_main).unsqueeze(0).to(device) else: img1 = normal_transforms['pure'](img_main).unsqueeze(0).to(device) img2 = img2_input.convert('RGB') img2 = normal_transforms['pure'](img2).unsqueeze(0).to(device) similarity = "Similarity: {:.3f}".format(nn.CosineSimilarity(dim=-1)(ssl_model(img1), ssl_model(img2)).item()) fig, axs = plt.subplots(1, 2, figsize=(10,10)) np.vectorize(lambda ax:ax.axis('off'))(axs) axs[0].imshow(show_image(img1, denormalize = denorm)) axs[1].imshow(show_image(img2, denormalize = denorm)) plt.subplots_adjust(wspace=0.1, hspace = 0) pil_output = fig2img(fig) return pil_output, similarity def run_occlusion(w_size, stride): heatmap1, heatmap2 = occlusion(img1, img2, ssl_model, w_size = 64, stride = 8, batch_size = 32) heatmap1_ca, heatmap2_ca = occlusion_context_agnositc(img1, img2, ssl_model, w_size = 64, stride = 8, batch_size = 32) heatmap1_po, heatmap2_po = pairwise_occlusion(img1, img2, ssl_model, batch_size = 32, erase_scale = (0.1, 0.3), erase_ratio = (1, 1.5), num_erases = 100) added_image1 = overlay_heatmap(img1, heatmap1, denormalize = denorm) added_image2 = overlay_heatmap(img2, heatmap2, denormalize = denorm) added_image1_ca = overlay_heatmap(img1, heatmap1_ca, denormalize = denorm) added_image2_ca = overlay_heatmap(img2, heatmap2_ca, denormalize = denorm) fig, axs = plt.subplots(2, 4, figsize=(20,10)) np.vectorize(lambda ax:ax.axis('off'))(axs) axs[0, 0].imshow(show_image(img1, denormalize = denorm)) axs[0, 1].imshow(added_image1) axs[0, 1].set_title("Conditional Occlusion") axs[0, 2].imshow(added_image1_ca) axs[0, 2].set_title("CA Cond. Occlusion") axs[0, 3].imshow((deprocess(img1, denormalize = denorm) * heatmap1_po[:,:,None]).astype('uint8')) axs[0, 3].set_title("Pairwise Occlusion") axs[1, 0].imshow(show_image(img2, denormalize = denorm)) axs[1, 1].imshow(added_image2) axs[1, 2].imshow(added_image2_ca) axs[1, 3].imshow((deprocess(img2, denormalize = denorm) * heatmap2_po[:,:,None]).astype('uint8')) plt.subplots_adjust(wspace=0, hspace = 0.01) pil_output = fig2img(fig) return pil_output def get_model_difference(later): imagenet_images, ssl_images = get_difference(ssl_model = ssl_model, baseline = 'imagenet', image = img2, lr = 1e4, l2_weight = 0.1, alpha_weight = 1e-7, alpha_power = 6, tv_weight = 1e-8, init_scale = 0.1, network = network) fig, axs = plt.subplots(3, 3, figsize=(10,10)) np.vectorize(lambda ax:ax.axis('off'))(axs) for aa, (in_img, ssl_img) in enumerate(zip(imagenet_images, ssl_images)): axs[aa,0].imshow(deprocess(img2, denormalize = denorm)) axs[aa,1].imshow(deprocess(in_img)) axs[aa,2].imshow(deprocess(ssl_img)) axs[0,0].set_title("Original Image") axs[0,1].set_title("Synthesized (cls)") axs[0,2].set_title("Synthesized (contastive)") plt.subplots_adjust(wspace=0.01, hspace = 0.01) pil_output = fig2img(fig) return pil_output def get_avg_trasforms(transform_type, add_noise, blur_output, guided): mixed_images = create_mixed_images(transform_type = transform_type, ig_transforms = ig_transforms, step = 0.1, img_path = img_main, add_noise = add_noise) # vanilla gradients (for comparison purposes) sailency1_van, sailency2_van = sailency(guided = guided, ssl_model = ssl_model, img1 = mixed_images[0], img2 = mixed_images[-1], blur_output = blur_output) # smooth gradients (for comparison purposes) sailency1_s, sailency2_s = smooth_grad(guided = guided, ssl_model = ssl_model, img1 = mixed_images[0], img2 = mixed_images[-1], blur_output = blur_output, steps = 50) # integrated transform sailency1, sailency2 = averaged_transforms(guided = guided, ssl_model = ssl_model, mixed_images = mixed_images, blur_output = blur_output) fig, axs = plt.subplots(2, 4, figsize=(20,10)) np.vectorize(lambda ax:ax.axis('off'))(axs) axs[0,0].imshow(show_image(mixed_images[0], denormalize = denorm)) axs[0,1].imshow(show_image(sailency1_van.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet) axs[0,1].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5) axs[0,1].set_title("Vanilla Gradients") axs[0,2].imshow(show_image(sailency1_s.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet) axs[0,2].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5) axs[0,2].set_title("Smooth Gradients") axs[0,3].imshow(show_image(sailency1.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet) axs[0,3].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5) axs[0,3].set_title("Integrated Transform") axs[1,0].imshow(show_image(mixed_images[-1], denormalize = denorm)) axs[1,1].imshow(show_image(sailency2_van.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet) axs[1,1].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5) axs[1,2].imshow(show_image(sailency2_s.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet) axs[1,2].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5) axs[1,3].imshow(show_image(sailency2.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet) axs[1,3].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5) plt.subplots_adjust(wspace=0.02, hspace = 0.02) pil_output = fig2img(fig) return pil_output def get_cams(): gradcam1, gradcam2 = get_gradcam(ssl_model, img1, img2) intcam1_mean, intcam2_mean = get_interactioncam(ssl_model, img1, img2, reduction = 'mean') intcam1_maxmax, intcam2_maxmax = get_interactioncam(ssl_model, img1, img2, reduction = 'max', grad_interact = True) intcam1_attnmax, intcam2_attnmax = get_interactioncam(ssl_model, img1, img2, reduction = 'attn', grad_interact = True) fig, axs = plt.subplots(2, 5, figsize=(20,8)) np.vectorize(lambda ax:ax.axis('off'))(axs) axs[0,0].imshow(show_image(img1[0], squeeze = False, denormalize = denorm)) axs[0,1].imshow(overlay_heatmap(img1, gradcam1, denormalize = denorm)) axs[0,1].set_title("Grad-CAM") axs[0,2].imshow(overlay_heatmap(img1, intcam1_mean, denormalize = denorm)) axs[0,2].set_title("IntCAM Mean") axs[0,3].imshow(overlay_heatmap(img1, intcam1_maxmax, denormalize = denorm)) axs[0,3].set_title("IntCAM Max + IntGradMax") axs[0,4].imshow(overlay_heatmap(img1, intcam1_attnmax, denormalize = denorm)) axs[0,4].set_title("IntCAM Attn + IntGradMax") axs[1,0].imshow(show_image(img2[0], squeeze = False, denormalize = denorm)) axs[1,1].imshow(overlay_heatmap(img2, gradcam2, denormalize = denorm)) axs[1,2].imshow(overlay_heatmap(img2, intcam2_mean, denormalize = denorm)) axs[1,3].imshow(overlay_heatmap(img2, intcam2_maxmax, denormalize = denorm)) axs[1,4].imshow(overlay_heatmap(img2, intcam2_attnmax, denormalize = denorm)) plt.subplots_adjust(wspace=0.01, hspace = 0.01) pil_output = fig2img(fig) return pil_output def get_pixel_invariance(): data_samples1, data_samples2, data_labels, labels_invariance = get_sample_dataset(img_path = img_main, num_augments = 1000, batch_size = 32, no_shift_transforms = no_shift_transforms, ssl_model = ssl_model, n_components = 10) inv_heatmap = pixel_invariance(data_samples1 = data_samples1, data_samples2 = data_samples2, data_labels = data_labels, labels_invariance = labels_invariance, resize_transform = transforms.Resize, size = 64, epochs = 1000, learning_rate = 0.1, l1_weight = 0.2, zero_small_values = True, blur_output = True, nmf_weight = 0) inv_heatmap_nmf = pixel_invariance(data_samples1 = data_samples1, data_samples2 = data_samples2, data_labels = data_labels, labels_invariance = labels_invariance, resize_transform = transforms.Resize, size = 64, epochs = 100, learning_rate = 0.1, l1_weight = 0.2, zero_small_values = True, blur_output = True, nmf_weight = 1) fig, axs = plt.subplots(1, 2, figsize=(10,5)) np.vectorize(lambda ax:ax.axis('off'))(axs) axs[0].imshow(viz_map(img_main, inv_heatmap)) axs[0].set_title("Heatmap w/o NMF") axs[1].imshow(viz_map(img_main, inv_heatmap_nmf)) axs[1].set_title("Heatmap w/ NMF") plt.subplots_adjust(wspace=0.01, hspace = 0.01) pil_output = fig2img(fig) return pil_output xai = gr.Blocks() with xai: gr.Markdown("