Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import torch.nn as nn | |
import torchvision.transforms as transforms | |
import matplotlib | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import cv2 | |
import gradio as gr | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
from data_transforms import normal_transforms, no_shift_transforms, ig_transforms, modify_transforms | |
from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model, fig2img | |
from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion | |
from methods import get_difference | |
from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad | |
from methods import get_sample_dataset, pixel_invariance, get_gradcam, get_interactioncam | |
matplotlib.use('Agg') | |
# UI dropdown label -> (network name, model variant, whether images must be
# de-normalised before display).
_MODEL_CONFIGS = {
    "simclrv2 (1X)": ("simclrv2", "1x", False),
    "simclrv2 (2X)": ("simclrv2", "2x", False),
    "Barlow Twins": ("barlow_twins", None, True),
}

# Deep copy of the transform pipelines as they were before the first model
# load, so every load starts from the unmodified pipelines instead of
# re-modifying whatever a previous load left behind.
_pristine_transforms = None


def load_model(model_name):
    """Load the self-supervised model chosen in the UI and configure transforms.

    On success, sets the module globals ``network``, ``ssl_model`` and
    ``denorm`` that the explanation callbacks read. It also rebuilds
    ``normal_transforms``, ``no_shift_transforms`` and ``ig_transforms``:
    SimCLRv2 gets the unmodified pipelines, and every other network gets the
    output of ``modify_transforms`` applied to fresh copies of them.

    Args:
        model_name: One of the dropdown labels in ``_MODEL_CONFIGS``; ``None``
            when nothing has been selected yet.

    Returns:
        A status string shown in the UI.
    """
    import copy

    global network, ssl_model, denorm, _pristine_transforms
    global normal_transforms, no_shift_transforms, ig_transforms

    config = _MODEL_CONFIGS.get(model_name)
    if config is None:
        return "Unknown model: {!r}. Please choose a model from the list.".format(model_name)
    new_network, variant, new_denorm = config

    # Load first so that a failure leaves the previous model and globals intact.
    new_model = get_ssl_model(new_network, variant)

    if _pristine_transforms is None:
        _pristine_transforms = copy.deepcopy((normal_transforms, no_shift_transforms, ig_transforms))
    # Work on a fresh copy each time: repeated loads must not compound the
    # modification, and switching back to SimCLRv2 must undo it.
    transform_set = copy.deepcopy(_pristine_transforms)
    if new_network != 'simclrv2':
        transform_set = modify_transforms(*transform_set)

    network, ssl_model, denorm = new_network, new_model, new_denorm
    normal_transforms, no_shift_transforms, ig_transforms = transform_set
    return "Loaded Model Successfully"
def load_or_augment_images(img1_input, img2_input, use_aug):
    """Prepare the image pair used by every explanation method.

    Sets the module globals ``img_main`` (the first image converted to RGB),
    ``img1`` and ``img2`` (transformed, batched with ``unsqueeze(0)`` and moved
    to ``device``). ``img2`` is either an augmented view of the first image
    (when ``use_aug`` is true) or the second uploaded image.

    Args:
        img1_input: First uploaded PIL image, or ``None`` if none was uploaded.
        img2_input: Second uploaded PIL image, or ``None``; ignored when
            ``use_aug`` is true.
        use_aug: Whether to derive the second image by augmenting the first.

    Returns:
        ``(figure_image, message)``: a PIL rendering of both inputs and a
        ``"Similarity: x.xxx"`` string. Returns ``(None, reason)`` when no
        model is loaded or a required image is missing.
    """
    global img_main, img1, img2
    if globals().get('ssl_model') is None:
        return None, "Please load a model first."
    if img1_input is None:
        return None, "Please upload the first image."
    if not use_aug and img2_input is None:
        return None, "Please upload a second image or check \"Use Augmentations\"."

    img_main = img1_input.convert('RGB')
    img1 = normal_transforms['pure'](img_main).unsqueeze(0).to(device)
    if use_aug:
        second = normal_transforms['aug'](img_main)
    else:
        second = normal_transforms['pure'](img2_input.convert('RGB'))
    img2 = second.unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed for the similarity score.
    with torch.no_grad():
        score = nn.CosineSimilarity(dim=-1)(ssl_model(img1), ssl_model(img2)).item()
    similarity = "Similarity: {:.3f}".format(score)

    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
    for ax, image in zip(axs, (img1, img2)):
        ax.axis('off')
        ax.imshow(show_image(image, denormalize=denorm))
    plt.subplots_adjust(wspace=0.1, hspace=0)
    pil_output = fig2img(fig)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
    return pil_output, similarity
def run_occlusion(w_size, stride):
    """Visualise occlusion-based explanations for the current image pair.

    Reads the module globals ``img1``, ``img2``, ``ssl_model`` and ``denorm``.

    Args:
        w_size: Occlusion window size from the UI (cast to ``int``).
        stride: Occlusion stride from the UI (cast to ``int``).

    Returns:
        PIL image of a 2x4 grid. Each row holds one image, its conditional
        occlusion heatmap, its context-agnostic conditional occlusion heatmap
        and the image weighted by its pairwise-occlusion map.

    Raises:
        ValueError: if ``w_size`` or ``stride`` is not positive.
    """
    # Honour the UI values; these were previously ignored (hard-coded 64 / 8).
    w_size, stride = int(w_size), int(stride)
    if w_size <= 0 or stride <= 0:
        raise ValueError("Occlusion window size and stride must be positive integers")

    heatmap1, heatmap2 = occlusion(img1, img2, ssl_model, w_size=w_size, stride=stride, batch_size=32)
    heatmap1_ca, heatmap2_ca = occlusion_context_agnositc(img1, img2, ssl_model, w_size=w_size,
                                                          stride=stride, batch_size=32)
    heatmap1_po, heatmap2_po = pairwise_occlusion(img1, img2, ssl_model, batch_size=32,
                                                  erase_scale=(0.1, 0.3), erase_ratio=(1, 1.5),
                                                  num_erases=100)

    rows = (
        (img1, heatmap1, heatmap1_ca, heatmap1_po),
        (img2, heatmap2, heatmap2_ca, heatmap2_po),
    )
    titles = (None, "Conditional Occlusion", "CA Cond. Occlusion", "Pairwise Occlusion")

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))
    for ax in axs.flat:
        ax.axis('off')
    for r, (image, hm, hm_ca, hm_po) in enumerate(rows):
        panels = (
            show_image(image, denormalize=denorm),
            overlay_heatmap(image, hm, denormalize=denorm),
            overlay_heatmap(image, hm_ca, denormalize=denorm),
            # Scale the de-processed image by the pairwise map, broadcast over
            # the last (channel) axis; assumes hm_po is 2-D and matches
            # deprocess's spatial size -- TODO confirm.
            (deprocess(image, denormalize=denorm) * hm_po[:, :, None]).astype('uint8'),
        )
        for c, panel in enumerate(panels):
            axs[r, c].imshow(panel)
            # Titles only on the top row, as column headers.
            if r == 0 and titles[c] is not None:
                axs[r, c].set_title(titles[c])
    plt.subplots_adjust(wspace=0, hspace=0.01)
    pil_output = fig2img(fig)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
    return pil_output
def get_model_difference(later):
    """Compare images synthesised for a supervised classifier and the SSL model.

    Reads the module globals ``img2``, ``ssl_model``, ``network`` and ``denorm``.

    Args:
        later: Value of the "Compare With" radio. It is unused because the call
            below always passes ``baseline='imagenet'``.

    Returns:
        PIL image of a 3x3 grid. Row ``i`` shows the second image, the i-th
        image synthesised for the classifier, and the i-th image synthesised
        for the contrastive model. The fixed 3-row layout assumes
        ``get_difference`` yields at most three pairs -- TODO confirm.
    """
    imagenet_images, ssl_images = get_difference(ssl_model=ssl_model, baseline='imagenet', image=img2, lr=1e4,
                                                 l2_weight=0.1, alpha_weight=1e-7, alpha_power=6, tv_weight=1e-8,
                                                 init_scale=0.1, network=network)
    # The original image is the same on every row; compute it once.
    original = deprocess(img2, denormalize=denorm)

    fig, axs = plt.subplots(3, 3, figsize=(10, 10))
    for ax in axs.flat:
        ax.axis('off')
    for row, (in_img, ssl_img) in enumerate(zip(imagenet_images, ssl_images)):
        axs[row, 0].imshow(original)
        axs[row, 1].imshow(deprocess(in_img))
        axs[row, 2].imshow(deprocess(ssl_img))
    axs[0, 0].set_title("Original Image")
    axs[0, 1].set_title("Synthesized (cls)")
    axs[0, 2].set_title("Synthesized (contrastive)")
    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    pil_output = fig2img(fig)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
    return pil_output
def get_avg_trasforms(transform_type, add_noise, blur_output, guided):
    """Compare vanilla, smooth and averaged-transform saliency maps.

    Builds a sequence of mixed images from ``img_main`` with
    ``create_mixed_images``. It then computes saliency between the first and
    last mixed image with three methods. Reads the module globals
    ``img_main``, ``ig_transforms``, ``ssl_model`` and ``denorm``.

    Args:
        transform_type: Augmentation family chosen in the UI (e.g. ``'blur'``).
        add_noise: Forwarded to ``create_mixed_images``.
        blur_output: Forwarded to every saliency method.
        guided: Whether the saliency methods use guided backprop.

    Returns:
        PIL image of a 2x4 grid. Each row holds one endpoint image followed
        by its vanilla-gradient, smooth-gradient and integrated-transform
        maps, each overlaid with the image at 50% opacity.
    """
    mixed_images = create_mixed_images(transform_type=transform_type,
                                       ig_transforms=ig_transforms,
                                       step=0.1,
                                       img_path=img_main,
                                       add_noise=add_noise)
    first, last = mixed_images[0], mixed_images[-1]

    # Vanilla and smooth gradients are computed for comparison purposes.
    sailency1_van, sailency2_van = sailency(guided=guided, ssl_model=ssl_model,
                                            img1=first, img2=last,
                                            blur_output=blur_output)
    sailency1_s, sailency2_s = smooth_grad(guided=guided, ssl_model=ssl_model,
                                           img1=first, img2=last,
                                           blur_output=blur_output, steps=50)
    # Integrated transform (the method being demonstrated).
    sailency1, sailency2 = averaged_transforms(guided=guided, ssl_model=ssl_model,
                                               mixed_images=mixed_images,
                                               blur_output=blur_output)

    rows = (
        (first, (sailency1_van, sailency1_s, sailency1)),
        (last, (sailency2_van, sailency2_s, sailency2)),
    )
    titles = ("Vanilla Gradients", "Smooth Gradients", "Integrated Transform")

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))
    for ax in axs.flat:
        ax.axis('off')
    for r, (image, maps) in enumerate(rows):
        shown = show_image(image, denormalize=denorm)
        axs[r, 0].imshow(shown)
        for c, (sal_map, title) in enumerate(zip(maps, titles), start=1):
            # Heatmap first, then the image on top at 50% opacity.
            axs[r, c].imshow(show_image(sal_map.detach(), squeeze=False).squeeze(), cmap=plt.cm.jet)
            axs[r, c].imshow(shown, alpha=0.5)
            if r == 0:
                axs[r, c].set_title(title)
    plt.subplots_adjust(wspace=0.02, hspace=0.02)
    pil_output = fig2img(fig)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
    return pil_output
def get_cams():
    """Visualise Grad-CAM and three Interaction-CAM variants for the image pair.

    Reads the module globals ``img1``, ``img2``, ``ssl_model`` and ``denorm``.

    Returns:
        PIL image of a 2x5 grid. Each row holds one image followed by its
        Grad-CAM, IntCAM (mean), IntCAM (max + grad interaction) and
        IntCAM (attn + grad interaction) overlays.
    """
    gradcam1, gradcam2 = get_gradcam(ssl_model, img1, img2)
    intcam1_mean, intcam2_mean = get_interactioncam(ssl_model, img1, img2, reduction='mean')
    intcam1_maxmax, intcam2_maxmax = get_interactioncam(ssl_model, img1, img2, reduction='max',
                                                        grad_interact=True)
    intcam1_attnmax, intcam2_attnmax = get_interactioncam(ssl_model, img1, img2, reduction='attn',
                                                          grad_interact=True)

    rows = (
        (img1, (gradcam1, intcam1_mean, intcam1_maxmax, intcam1_attnmax)),
        (img2, (gradcam2, intcam2_mean, intcam2_maxmax, intcam2_attnmax)),
    )
    titles = ("Grad-CAM", "IntCAM Mean", "IntCAM Max + IntGradMax", "IntCAM Attn + IntGradMax")

    fig, axs = plt.subplots(2, 5, figsize=(20, 8))
    for ax in axs.flat:
        ax.axis('off')
    for r, (image, cams) in enumerate(rows):
        axs[r, 0].imshow(show_image(image[0], squeeze=False, denormalize=denorm))
        for c, (cam, title) in enumerate(zip(cams, titles), start=1):
            axs[r, c].imshow(overlay_heatmap(image, cam, denormalize=denorm))
            # Titles only on the top row, as column headers.
            if r == 0:
                axs[r, c].set_title(title)
    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    pil_output = fig2img(fig)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
    return pil_output
def get_pixel_invariance():
    """Compute pixel-invariance heatmaps for the first image, with and without NMF.

    Samples augmented pairs of ``img_main`` via ``get_sample_dataset`` and fits
    ``pixel_invariance`` twice on them. Reads the module globals ``img_main``,
    ``no_shift_transforms`` and ``ssl_model``.

    Returns:
        PIL image of a 1x2 grid: the heatmap without NMF (``nmf_weight=0``)
        and the heatmap with NMF (``nmf_weight=1``), both drawn with
        ``viz_map`` over ``img_main``.
    """
    data_samples1, data_samples2, data_labels, labels_invariance = get_sample_dataset(
        img_path=img_main,
        num_augments=1000,
        batch_size=32,
        no_shift_transforms=no_shift_transforms,
        ssl_model=ssl_model,
        n_components=10)

    # Settings shared by both fits; only the epoch count and NMF weight differ.
    common = dict(data_samples1=data_samples1, data_samples2=data_samples2, data_labels=data_labels,
                  labels_invariance=labels_invariance, resize_transform=transforms.Resize, size=64,
                  learning_rate=0.1, l1_weight=0.2, zero_small_values=True, blur_output=True)
    # NOTE(review): the NMF fit runs 100 epochs vs 1000 without NMF, as in the
    # original -- presumably intentional; confirm.
    inv_heatmap = pixel_invariance(epochs=1000, nmf_weight=0, **common)
    inv_heatmap_nmf = pixel_invariance(epochs=100, nmf_weight=1, **common)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    panels = ((inv_heatmap, "Heatmap w/o NMF"), (inv_heatmap_nmf, "Heatmap w/ NMF"))
    for ax, (heatmap, title) in zip(axs, panels):
        ax.axis('off')
        ax.imshow(viz_map(img_main, heatmap))
        ax.set_title(title)
    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    pil_output = fig2img(fig)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
    return pil_output
# Gradio UI. Event handlers are attached inside the Blocks context.
# The image components are named *_input so they do not shadow the
# module-level ``img1``/``img2`` tensors that the callbacks read via ``global``.
xai = gr.Blocks()
with xai:
    gr.Markdown("<h1>Methods for Explaining Contrastive Learning, CVPR 2023 Submission</h1>")
    gr.Markdown("The interface is simplified as much as possible with only necessary options to select for each method. Please use our Google Colab demo for more flexibility.")
    with gr.Row():
        model_name = gr.Dropdown(["simclrv2 (1X)", "simclrv2 (2X)", "Barlow Twins"],
                                 label="Choose Model and press \"Load Model\"")
        load_model_button = gr.Button("Load Model")
        # gr.inputs.* is deprecated (removed in Gradio 4); use the component directly.
        status_or_similarity = gr.Textbox(label="Status")
    with gr.Row():
        gr.Markdown("You can either load two images or load a single image and augment it to get the second image (in that case please check the \"Use Augmentations\" button). After that, please press on \"Show Images\"")
        img1_input = gr.Image(type='pil', label="First Image")
        img2_input = gr.Image(type='pil', label="Second Image")
    with gr.Row():
        use_aug = gr.Checkbox(value=False, label="Use Augmentations")
        load_images_button = gr.Button("Show Images")
    gr.Markdown("Choose a method from the different tabs. You may leave the default options as they are and press on \"Run\" ")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Interaction-CAM"):
                    cams_button = gr.Button("Get Heatmaps")
                with gr.TabItem("Perturbation Methods"):
                    w_size = gr.Number(value=64, label="Occlusion Window Size", precision=0)
                    stride = gr.Number(value=8, label="Occlusion Stride", precision=0)
                    occlusion_button = gr.Button("Get Heatmap")
                with gr.TabItem("Averaged Transforms"):
                    transform_type = gr.Radio(label="Data Augment",
                                              choices=['color_jitter', 'blur', 'grayscale', 'solarize', 'combine'],
                                              value="combine")
                    add_noise = gr.Checkbox(value=True, label="Add Noise")
                    blur_output = gr.Checkbox(value=True, label="Blur Output")
                    guided = gr.Checkbox(value=True, label="Guided Backprop")
                    avgtransform_button = gr.Button("Get Saliency")
                with gr.TabItem("Pixel Invariance"):
                    gr.Markdown("Note: Invariance map will be obtained for the first image")
                    pixel_invariance_button = gr.Button("Get Invariance Map")
                with gr.TabItem("Image Synthesization"):
                    baseline = gr.Radio(label="Compare With", choices=["Supervised Classification"],
                                        value="Supervised Classification")
                    modeldiff_button = gr.Button("Compare")
        with gr.Column():
            output_image = gr.Image(type='pil', show_label=False)

    load_model_button.click(load_model, inputs=model_name, outputs=status_or_similarity)
    load_images_button.click(load_or_augment_images, inputs=[img1_input, img2_input, use_aug],
                             outputs=[output_image, status_or_similarity])
    occlusion_button.click(run_occlusion, inputs=[w_size, stride], outputs=output_image)
    modeldiff_button.click(get_model_difference, inputs=baseline, outputs=output_image)
    avgtransform_button.click(get_avg_trasforms, inputs=[transform_type, add_noise, blur_output, guided],
                              outputs=output_image)
    cams_button.click(get_cams, inputs=[], outputs=output_image)
    pixel_invariance_button.click(get_pixel_invariance, inputs=[], outputs=output_image)

if __name__ == "__main__":
    xai.launch()