# DeepSimilarity / app.py
# taesiri — "Update app.py" (commit df1b5bc)
import csv
import functools
import sys

import gradio as gr
import numpy as np
import skimage.transform
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from numpy import matlib as mb
from PIL import Image
def _raise_csv_field_size_limit():
    """Raise the csv module's per-field size limit as far as the platform allows.

    ``csv.field_size_limit`` stores its argument in a C ``long``. Where that
    type is 32-bit (e.g. 64-bit Windows), ``sys.maxsize`` does not fit and
    the call raises ``OverflowError``. Halve the candidate until it is
    accepted.

    Returns:
        The limit that was finally installed.
    """
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
            return limit
        except OverflowError:
            limit //= 2


_raise_csv_field_size_limit()
def compute_spatial_similarity(conv1, conv2, spatial_size=(7, 7)):
    """
    Compute spatial similarity maps for two convolutional feature maps.

    Each input is reshaped to ``(channels, H * W)`` and transposed to
    ``(positions, channels)``, where ``(H, W) == spatial_size``. Each
    position's feature vector is divided by two things:

    * the L2 norm of that image's average-pooled feature;
    * the number of positions.

    With this scaling, the sum of all position-pair dot products equals the
    cosine similarity of the two pooled features. Summing the pairwise
    matrix along one axis therefore gives each image's per-position
    contribution to that similarity.

    Args:
        conv1: feature map of the first image. Its total size must be a
            multiple of ``H * W``; e.g. a squeezed ``(C, 7, 7)`` array.
        conv2: feature map of the second image, with the same channel count
            and spatial size as ``conv1``.
        spatial_size: ``(H, W)`` of the feature maps. It defaults to the
            original hard-coded 7x7.

    Returns:
        ``(similarity1, similarity2)``, each an ``(H, W)`` array.
    """
    height, width = spatial_size
    positions = height * width

    # (positions, channels): one row per spatial location.
    feats1 = np.asarray(conv1).reshape(-1, positions).T
    feats2 = np.asarray(conv2).reshape(-1, positions).T

    pool1 = feats1.mean(axis=0)
    pool2 = feats2.mean(axis=0)
    normed1 = feats1 / np.linalg.norm(pool1) / feats1.shape[0]
    normed2 = feats2 / np.linalg.norm(pool2) / feats2.shape[0]

    # im_similarity[i, j] = <normed1[i], normed2[j]>. A single matrix
    # product replaces the per-row repmat/multiply/sum loop.
    im_similarity = normed1 @ normed2.T

    similarity1 = im_similarity.sum(axis=1).reshape(height, width)
    similarity2 = im_similarity.sum(axis=0).reshape(height, width)
    return similarity1, similarity2
# Image preprocessing.
# display_transform: resize to 256 then center-crop to 224x224 (no tensor
# conversion), used for the images shown under the heatmaps.
display_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
    ]
)

# imagenet_transform: the same resize + crop steps, followed by conversion to
# a tensor and per-channel mean/std normalisation, used as the ResNet input.
# Sharing the geometry keeps the overlays aligned with what the network saw.
imagenet_transform = transforms.Compose(
    list(display_transform.transforms)
    + [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
class Wrapper(torch.nn.Module):
    """
    Make a model return the output of its ``layer4`` submodule.

    A forward hook registered on ``model.layer4`` records that submodule's
    output on the wrapper whenever the wrapped model runs. ``forward`` runs
    the full model, ignores its final result and returns the recorded
    activation.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Latest layer4 activation; None until the first forward pass.
        # (The attribute name's spelling is kept as-is for existing callers.)
        self.layer4_ouputs = None

        def _capture(_module, _inputs, output):
            self.layer4_ouputs = output

        self.model.layer4.register_forward_hook(_capture)

    def forward(self, input):
        # Only the hook's side effect matters; the model's result is dropped.
        self.model(input)
        return self.layer4_ouputs

    def __repr__(self):
        return "Wrapper"
@functools.lru_cache(maxsize=1)
def _layer4_extractor():
    """Build the pretrained ResNet-50 once and return an eval-mode module ending at layer4."""
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favour of `weights=...`; kept here for compatibility with the pinned version.
    backbone = models.resnet50(pretrained=True)
    # Same stage order as torchvision's ResNet forward pass, stopping after
    # layer4 (avgpool/fc are not needed for the layer4 activation).
    # No forward hooks are used, so this shared, cached module holds no
    # per-call state and can serve concurrent requests.
    return torch.nn.Sequential(
        backbone.conv1,
        backbone.bn1,
        backbone.relu,
        backbone.maxpool,
        backbone.layer1,
        backbone.layer2,
        backbone.layer3,
        backbone.layer4,
    ).eval()


def get_layer4(input_image):
    """
    Return ResNet-50 ``layer4`` activations for a PIL image as a numpy array.

    The image goes through ``imagenet_transform`` and gets a leading batch
    dimension of 1, so the result has a batch axis of size 1. The pretrained
    network is loaded once and cached; previously it was rebuilt, with its
    weights reloaded, on every call.

    Args:
        input_image: a PIL image accepted by ``imagenet_transform``.

    Returns:
        The layer4 feature map as a CPU numpy array with a leading batch axis of 1.
    """
    extractor = _layer4_extractor()
    with torch.no_grad():
        data = imagenet_transform(input_image).unsqueeze(0)
        reference_layer4 = extractor(data)
    return reference_layer4.to("cpu").numpy()
def NormalizeData(data):
    """
    Min-max rescale ``data`` into the range [0, 1].

    Args:
        data: array-like of numbers (e.g. a similarity map).

    Returns:
        An array of the same shape, where the minimum maps to 0 and the
        maximum maps to 1. A constant input (max == min) used to yield NaNs
        from 0/0; it now yields an all-zero array.
    """
    data = np.asarray(data)
    lo = np.min(data)
    span = np.max(data) - lo
    if span == 0:
        return np.zeros_like(data, dtype=np.result_type(data, 1.0))
    return (data - lo) / span
# Visualization
def visualize_similarities(q, n):
    """
    Build the similarity visualisations for a pair of images.

    Args:
        q: first ("query") image as an array accepted by ``Image.fromarray``.
        n: second ("nearest") image as an array accepted by ``Image.fromarray``.

    Returns:
        A tuple ``(fig, q_image, nearest_image, fig2)``:

        * ``fig``: a two-panel figure of each cropped image with its
          min-max-normalised similarity heatmap overlaid.
        * ``q_image``, ``nearest_image``: the resized and center-cropped
          PIL images.
        * ``fig2``: the query similarity map thresholded at 0.5.
    """
    # Figure objects are created directly rather than through pyplot, so they
    # are never registered in pyplot's global figure list. In a long-running
    # server, pyplot figures that are never closed accumulate.
    from matplotlib.figure import Figure

    image1 = Image.fromarray(q)
    image2 = Image.fromarray(n)

    sim1, sim2 = compute_spatial_similarity(
        get_layer4(image1).squeeze(), get_layer4(image2).squeeze()
    )
    sim1 = NormalizeData(sim1)
    sim2 = NormalizeData(sim2)

    # Crop each image once and reuse it for every panel and for the outputs.
    q_image = display_transform(image1)
    nearest_image = display_transform(image2)

    fig = Figure(figsize=(12, 5))
    axes = fig.subplots(1, 2)
    for ax, shown, sim in ((axes[0], q_image, sim1), (axes[1], nearest_image, sim2)):
        ax.imshow(shown)
        heatmap = ax.imshow(
            skimage.transform.resize(sim, (224, 224)),
            alpha=0.5,
            cmap="jet",
            vmin=0,
            vmax=1,
        )
        ax.set_axis_off()
        fig.colorbar(heatmap, ax=ax)
    fig.tight_layout()

    # Binarised query map: 1 where the normalised similarity is > 0.5, else 0.
    sim1_bin = np.where(sim1 > 0.5, 1, 0)
    fig2 = Figure(figsize=(5, 5))
    ax = fig2.subplots(1)
    ax.imshow(q_image)
    # NOTE(review): alpha=1 makes the mask fully opaque, hiding the query
    # image underneath; confirm this is intended.
    ax.imshow(
        skimage.transform.resize(sim1_bin, (224, 224)),
        alpha=1,
        cmap="binary",
        vmin=0,
        vmax=1,
    )
    return fig, q_image, nearest_image, fig2
# Gradio interface: two input images produce the heatmap figure, the two
# cropped images, and the binarised query figure. It is launched below as
# a tab of `blocks`.
main = gr.Interface(
    fn=visualize_similarities,
    inputs=["image", "image"],
    outputs=["plot", "image", "image", "plot"],
    examples=[
        [
            "./examples/Red_Winged_Blackbird_0012_6015.jpg",
            "./examples/Red_Winged_Blackbird_0025_5342.jpg",
        ],
    ],
    cache_examples=True,
    allow_flagging="never",
    enable_queue=False,
)
# Page layout: a Markdown header followed by a single "Main" tab holding `main`.
with gr.Blocks() as blocks:
    gr.Markdown(
        """
# Visualizing Deep Similarity Networks
A quick demo to visualize the similarity between two images.
[Original Paper](https://arxiv.org/pdf/1901.00536.pdf) - [Github Page](https://github.com/GWUvision/Similarity-Visualization)
"""
    )
    gr.TabbedInterface([main], ["Main"])

blocks.launch(debug=True)