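# DeepSimilarity / app.py
# Last commit a4e6581 by taesiri: "Updated Gradio" (file size 3.68 kB).
# (Header text captured from the hosting page's file view; its
# "raw / history / blame" links and "No virus" scan badge are omitted.)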
import csv
import functools
import sys
import threading

import gradio as gr
import numpy as np
import skimage.transform
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from numpy import matlib as mb
from PIL import Image
# Raise the csv module's per-field size cap as far as the platform allows.
# NOTE(review): csv is not used in the visible code; presumably this is for
# Gradio's CSV-backed example cache (cache_examples=True below), whose rows can
# hold large serialized outputs — confirm.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    # The limit is stored in a C long, which is 32-bit on Windows, so
    # sys.maxsize does not fit there; use the largest 32-bit value instead.
    csv.field_size_limit(2**31 - 1)
def _as_pixel_rows(conv):
    """Return (pixels x channels matrix, (H, W)) for one conv activation map."""
    conv = np.asarray(conv)
    # (..., H, W) inputs keep their own spatial size; lower-rank inputs are
    # read the way this code always read them: as 7x7 maps.
    spatial = tuple(conv.shape[-2:]) if conv.ndim >= 3 else (7, 7)
    return conv.reshape(-1, spatial[0] * spatial[1]).T, spatial


def compute_spatial_similarity(conv1, conv2):
    """
    Takes in the last convolutional layer from two images, computes the pooled output
    feature, and then generates the spatial similarity map for both images.

    Args:
        conv1, conv2: activation maps shaped (C, H, W) (a leading batch axis of
            size 1 is also accepted). Both must have the same channel count C;
            their spatial sizes may differ. Arrays of lower rank are treated as
            7x7 maps, matching the original hard-coded size.

    Returns:
        (similarity1, similarity2): per-pixel maps shaped like the spatial grid
        of conv1 and conv2 respectively. Each pixel's value is its share of the
        cosine similarity between the two average-pooled feature vectors, so
        each map sums to that cosine similarity.
    """
    pixels1, out_sz1 = _as_pixel_rows(conv1)
    pixels2, out_sz2 = _as_pixel_rows(conv2)

    pool1 = pixels1.mean(axis=0)
    pool2 = pixels2.mean(axis=0)
    # Scale so that summing all pairwise dot products yields the cosine
    # similarity of pool1 and pool2.
    normed1 = pixels1 / np.linalg.norm(pool1) / pixels1.shape[0]
    normed2 = pixels2 / np.linalg.norm(pool2) / pixels2.shape[0]

    # im_similarity[i, j] = <normed1[i], normed2[j]>. A single matrix product
    # replaces the former per-row repmat loop. float64 matches the dtype the
    # original np.zeros buffer gave the result.
    im_similarity = normed1.astype(np.float64) @ normed2.astype(np.float64).T

    similarity1 = im_similarity.sum(axis=1).reshape(out_sz1)
    similarity2 = im_similarity.sum(axis=0).reshape(out_sz2)
    return similarity1, similarity2
# Layer-4 features: image preprocessing plus a ResNet-50 wrapper that exposes layer4's output
# Geometry shared by the plotted image and the network input: shorter side
# resized to 256 px, then a centered 224x224 crop. Keeping one pipeline for
# both means the resized heat-maps line up with what is shown.
display_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
    ]
)

# The same crop, followed by conversion to a tensor and per-channel
# normalization with the ImageNet mean/std values.
imagenet_transform = transforms.Compose(
    [
        display_transform,
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
class Wrapper(torch.nn.Module):
    """Run a model but hand back the activation produced by its ``layer4``.

    A forward hook on ``model.layer4`` stores that submodule's output on
    ``self.layer4_ouputs`` (attribute name kept as-is so existing readers keep
    working). ``forward`` runs the whole model, discards its final result and
    returns the stored activation instead.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.layer4_ouputs = None  # overwritten by the hook on each forward pass
        model.layer4.register_forward_hook(self._store_layer4)

    def _store_layer4(self, module, input, output):
        # Forward-hook signature is (module, input, output); returning None
        # leaves layer4's output unchanged for the rest of the network.
        self.layer4_ouputs = output

    def forward(self, input):
        self.model(input)
        return self.layer4_ouputs

    def __repr__(self):
        return "Wrapper"
# Serializes use of the shared extractor: Wrapper keeps the latest layer4
# activation on the instance, so two overlapping forward passes on the same
# object could otherwise read each other's result.
_LAYER4_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _layer4_extractor():
    """Build the pretrained ResNet-50 layer4 extractor once and reuse it."""
    # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour
    # of `weights=models.ResNet50_Weights.IMAGENET1K_V1`; kept unchanged for
    # compatibility with whichever torchvision version this app pins.
    l4_model = models.resnet50(pretrained=True)
    l4_model.eval()
    return Wrapper(l4_model)


def get_layer4(input_image):
    """Return ResNet-50 ``layer4`` activations for one image as a NumPy array.

    The image goes through ``imagenet_transform`` and gets a leading batch axis
    of size 1, which the returned array keeps (the visible caller squeezes it).
    The model is loaded on first use and cached, rather than being rebuilt, with
    its weights reloaded, on every call.
    """
    data = imagenet_transform(input_image).unsqueeze(0)
    with _LAYER4_LOCK, torch.no_grad():
        reference_layer4 = _layer4_extractor()(data)
    return reference_layer4.cpu().numpy()
# Visualization
def visualize_similarities(image1, image2):
    """Plot both images with their spatial-similarity heat-maps overlaid.

    Each image's ResNet-50 layer4 activations (``get_layer4``) feed
    ``compute_spatial_similarity``. Each resulting map is resized to the
    224x224 crop produced by ``display_transform`` and drawn semi-transparently
    (``alpha=0.5``, ``jet`` colormap) over that crop, with one colorbar per panel.

    Returns:
        A ``matplotlib.figure.Figure`` with the two panels side by side.
    """
    sim1, sim2 = compute_spatial_similarity(
        get_layer4(image1).squeeze(), get_layer4(image2).squeeze()
    )

    # Build the Figure directly instead of via pyplot. pyplot keeps every figure
    # it creates alive in a global registry (so this app would leak one per
    # request), and its "current figure" state is shared by all callers.
    fig = Figure(figsize=(12, 5))
    axes = fig.subplots(1, 2)
    for ax, image, sim in zip(axes, (image1, image2), (sim1, sim2)):
        ax.imshow(display_transform(image))
        overlay = ax.imshow(
            skimage.transform.resize(sim, (224, 224)), alpha=0.5, cmap="jet"
        )
        fig.colorbar(overlay, ax=ax)
    fig.tight_layout()
    return fig
# GRADIO APP
# Two PIL-image inputs -> one matplotlib plot. The example pair is resolved
# relative to the current working directory; cache_examples=True asks Gradio to
# precompute the example's output.
# NOTE(review): `gr.Plot(type=...)` and `enable_queue=` are older Gradio
# arguments; left unchanged because the pinned Gradio version is not visible here.
iface = gr.Interface(
    fn=visualize_similarities,
    inputs=[
        gr.Image(type="pil"),
        gr.Image(type="pil"),
    ],
    allow_flagging="never",
    outputs=[gr.Plot(type="matplotlib")],
    cache_examples=True,
    enable_queue=False,
    examples=[
        [
            "./examples/Red_Winged_Blackbird_0012_6015.jpg",
            "./examples/Red_Winged_Blackbird_0025_5342.jpg",
        ]
    ],
)

# Launch only when run as a script (`python app.py`), so importing this module
# (tests, reload tooling) does not start and block on a web server.
if __name__ == "__main__":
    iface.launch()