import csv
import functools
import sys
import threading

import gradio as gr
import numpy as np
import skimage.transform
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from numpy import matlib as mb  # no longer used below; kept so the module namespace is unchanged
from PIL import Image

# Allow arbitrarily large CSV fields. On platforms where a C long is 32-bit
# (e.g. Windows) sys.maxsize does not fit and this call raises OverflowError,
# so fall back to the largest 32-bit signed value there.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

# Run the feature extractor on the GPU when one is available.
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def compute_spatial_similarity(conv1, conv2):
    """Compute spatial similarity maps for two convolutional feature maps.

    Takes the last convolutional layer's output for two images, derives the
    average-pooled feature of each, and splits the cosine similarity of the
    two pooled features into per-position contributions.

    Args:
        conv1, conv2: feature maps whose trailing two axes are the spatial
            grid (e.g. ``(C, H, W)``). Inputs with fewer than 3 dimensions are
            treated as flattened 7x7 maps, which was the only layout
            previously supported.

    Returns:
        ``(similarity1, similarity2)``: two arrays with the spatial shape of
        ``conv1``. Entry ``i`` of ``similarity1`` is how much position ``i``
        of image 1 contributes to the pooled cosine similarity (likewise for
        ``similarity2`` and image 2). Each map sums to that cosine similarity.
    """
    conv1 = np.asarray(conv1, dtype=np.float64)
    conv2 = np.asarray(conv2, dtype=np.float64)

    # Take the spatial grid from the trailing axes when they are present.
    # Otherwise use 7x7, which is layer4 of ResNet-50 at 224x224 input.
    out_sz = tuple(conv1.shape[-2:]) if conv1.ndim >= 3 else (7, 7)
    n_positions = out_sz[0] * out_sz[1]

    # Rearrange to (positions, channels).
    conv1 = conv1.reshape(-1, n_positions).T
    conv2 = conv2.reshape(-1, n_positions).T

    pool1 = conv1.mean(axis=0)
    pool2 = conv2.mean(axis=0)

    # Scale so that summing every pairwise dot product below gives
    # pool1 . pool2 / (|pool1| |pool2|), i.e. the pooled cosine similarity.
    conv1_normed = conv1 / np.linalg.norm(pool1) / conv1.shape[0]
    conv2_normed = conv2 / np.linalg.norm(pool2) / conv2.shape[0]

    # im_similarity[i, j] = <conv1_normed[i], conv2_normed[j]>
    im_similarity = conv1_normed @ conv2_normed.T

    similarity1 = im_similarity.sum(axis=1).reshape(out_sz)
    similarity2 = im_similarity.sum(axis=0).reshape(out_sz)
    return similarity1, similarity2


# Resize + center-crop only, for showing images at the model's input size.
display_transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop((224, 224))]
)

# Standard ImageNet preprocessing fed to the ResNet-50 feature extractor.
imagenet_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


class Wrapper(torch.nn.Module):
    """Run ``model`` and return the output of its ``layer4`` submodule.

    A forward hook on ``model.layer4`` captures that submodule's output while
    the whole model runs. The model's own final output is discarded.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Latest output captured by the hook. The attribute name is kept
        # as-is for compatibility.
        self.layer4_ouputs = None
        # The hook hands its result over through instance state. Serialize
        # forward passes so concurrent callers sharing one wrapper cannot
        # read each other's activations.
        self._forward_lock = threading.Lock()

        def fw_hook(module, input, output):
            self.layer4_ouputs = output

        self.model.layer4.register_forward_hook(fw_hook)

    def forward(self, input):
        with self._forward_lock:
            self.model(input)
            return self.layer4_ouputs

    def __repr__(self):
        return "Wrapper"


@functools.lru_cache(maxsize=1)
def _get_layer4_model():
    """Build the pretrained, eval-mode ResNet-50 layer4 extractor once."""
    l4_model = models.resnet50(pretrained=True)
    l4_model.eval()
    l4_model = l4_model.to(_DEVICE)
    return Wrapper(l4_model)


def get_layer4(input_image):
    """Return ResNet-50 ``layer4`` activations for a PIL image as a numpy array.

    The output has a leading batch axis of size 1. The model is loaded on the
    first call and reused after that, instead of being rebuilt every time.
    """
    wrapped_model = _get_layer4_model()
    data = imagenet_transform(input_image).unsqueeze(0).to(_DEVICE)
    with torch.no_grad():
        reference_layer4 = wrapped_model(data)
    return reference_layer4.detach().cpu().numpy()


def NormalizeData(data):
    """Min-max scale ``data`` into [0, 1].

    A constant input (max == min) returns all zeros instead of NaNs from a
    0/0 division.
    """
    data = np.asarray(data)
    lo = np.min(data)
    span = np.max(data) - lo
    if span == 0:
        return np.zeros(data.shape)
    return (data - lo) / span


# Visualization
def visualize_similarities(q, n):
    """Build the similarity visualizations for a query image and a second image.

    Args:
        q, n: images as numpy arrays, as delivered by the Gradio "image" input.

    Returns:
        ``(fig, q_image, nearest_image, fig2)``:
        - ``fig``: side-by-side heatmaps of normalized spatial similarity
          overlaid on both images.
        - ``q_image``, ``nearest_image``: the resized, cropped PIL images.
        - ``fig2``: the query image covered by the thresholded (> 0.5)
          similarity mask.

    Raises:
        ValueError: if either image is missing.
    """
    if q is None or n is None:
        raise ValueError("Please provide two images to compare.")

    # Force 3 channels. RGBA or grayscale uploads would otherwise break the
    # 3-channel ImageNet normalization and ResNet's first conv layer.
    image1 = Image.fromarray(q).convert("RGB")
    image2 = Image.fromarray(n).convert("RGB")

    sim1, sim2 = compute_spatial_similarity(
        get_layer4(image1).squeeze(), get_layer4(image2).squeeze()
    )
    sim1 = NormalizeData(sim1)
    sim2 = NormalizeData(sim2)

    q_image = display_transform(image1)
    nearest_image = display_transform(image2)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, shown, sim in ((axes[0], q_image, sim1), (axes[1], nearest_image, sim2)):
        ax.imshow(shown)
        overlay = ax.imshow(
            skimage.transform.resize(sim, (224, 224)),
            alpha=0.5,
            cmap="jet",
            vmin=0,
            vmax=1,
        )
        ax.set_axis_off()
        fig.colorbar(overlay, ax=ax)
    fig.tight_layout()

    # Binarized version of the query's map: values above 0.5 -> 1, others -> 0.
    sim1_bin = np.where(sim1 > 0.5, 1, 0)
    fig2, ax = plt.subplots(1, figsize=(5, 5))
    ax.imshow(q_image)
    # NOTE(review): alpha=1 makes the mask fully opaque over the image.
    # Confirm that is the intended look.
    ax.imshow(
        skimage.transform.resize(sim1_bin, (224, 224)),
        alpha=1,
        cmap="binary",
        vmin=0,
        vmax=1,
    )

    return fig, q_image, nearest_image, fig2


# GRADIO APP
main = gr.Interface(
    fn=visualize_similarities,
    inputs=["image", "image"],
    allow_flagging="never",
    outputs=["plot", "image", "image", "plot"],
    cache_examples=True,
    enable_queue=False,
    examples=[
        [
            "./examples/Red_Winged_Blackbird_0012_6015.jpg",
            "./examples/Red_Winged_Blackbird_0025_5342.jpg",
        ],
    ],
)

blocks = gr.Blocks()

with blocks:
    gr.Markdown(
        """
# Visualizing Deep Similarity Networks

A quick demo to visualize the similarity between two images.

[Original Paper](https://arxiv.org/pdf/1901.00536.pdf) - [Github Page](https://github.com/GWUvision/Similarity-Visualization)
"""
    )

    gr.TabbedInterface([main], ["Main"])

blocks.launch(debug=True)