# Hugging Face Spaces demo: visualize spatial similarity between two images
# using ResNet-50 layer4 features.
import gradio as gr
import numpy as np
from numpy import matlib as mb
import skimage.transform
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image
def compute_spatial_similarity(conv1, conv2):
    """
    Takes in the last convolutional layer from two images, computes the pooled
    output feature, and then generates the spatial similarity map for both
    images.

    Args:
        conv1: features of image 1, shape (C, H, W) (anything reshapeable to
            (C, H*W)). H*W must be a perfect square for the final reshape.
        conv2: features of image 2, same shape as conv1.

    Returns:
        (similarity1, similarity2): two (H, W) maps. similarity1[i, j] is the
        total similarity of spatial position (i, j) in image 1 to all of
        image 2, and symmetrically for similarity2.
    """
    # (C, H, W) -> (H*W, C): one row per spatial position. The spatial size is
    # inferred from the channel dimension instead of hard-coding 7*7, so any
    # square feature map works (identical result for the original 7x7 case).
    conv1 = conv1.reshape(conv1.shape[0], -1).T
    conv2 = conv2.reshape(conv2.shape[0], -1).T
    pool1 = np.mean(conv1, axis=0)
    pool2 = np.mean(conv2, axis=0)
    out_sz = (int(np.sqrt(conv1.shape[0])), int(np.sqrt(conv1.shape[0])))
    # Normalize each image's features by the norm of its pooled feature and by
    # the number of spatial positions.
    conv1_normed = conv1 / np.linalg.norm(pool1) / conv1.shape[0]
    conv2_normed = conv2 / np.linalg.norm(pool2) / conv2.shape[0]
    # Pairwise dot products between every position of image 1 and every
    # position of image 2. This single matrix product replaces the original
    # O(n^2) Python loop built on numpy.matlib.repmat (numpy.matlib is
    # deprecated and removed in NumPy >= 2.0) and computes the same matrix.
    im_similarity = conv1_normed @ conv2_normed.T
    similarity1 = np.reshape(np.sum(im_similarity, axis=1), out_sz)
    similarity2 = np.reshape(np.sum(im_similarity, axis=0), out_sz)
    return similarity1, similarity2
# Get Layer 4
# Resize + center-crop only, used to show the raw image in the output figure.
display_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop((224, 224)),
])
# Full preprocessing pipeline used to feed the network: same geometry as the
# display transform, then tensorize and normalize with the ImageNet
# channel means / standard deviations.
imagenet_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
class Wrapper(torch.nn.Module):
    """Make a model return its ``layer4`` activations instead of its output.

    A forward hook on ``model.layer4`` captures that submodule's output on
    every forward pass; calling the wrapper runs the wrapped model and
    returns the captured activations, discarding the model's own result.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.layer4_ouputs = None  # filled in by the hook on each forward

        def _capture(_module, _inputs, output):
            self.layer4_ouputs = output

        self.model.layer4.register_forward_hook(_capture)

    def forward(self, input):
        # Run the full model purely for its side effect of firing the hook.
        _ = self.model(input)
        return self.layer4_ouputs

    def __repr__(self):
        return "Wrapper"
def get_layer4(input_image):
    """Run *input_image* (a PIL image) through a pretrained ResNet-50 and
    return the ``layer4`` activations as a numpy array.

    Args:
        input_image: PIL image; preprocessed with ``imagenet_transform``.

    Returns:
        numpy array of the layer4 feature map for the single image
        (presumably (1, 2048, 7, 7) for ResNet-50 at 224x224 input —
        downstream code squeezes and assumes a square spatial map).
    """
    # Fix: the original hard-coded .cuda(), which crashes on CPU-only
    # machines; fall back to CPU when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    l4_model = models.resnet50(pretrained=True).to(device)
    l4_model.eval()
    wrapped_model = Wrapper(l4_model)
    with torch.no_grad():
        data = imagenet_transform(input_image).unsqueeze(0).to(device)
        reference_layer4 = wrapped_model(data)
    return reference_layer4.detach().to('cpu').numpy()
# Visualization
def visualize_similarities(image1, image2):
    """Plot both images with their spatial-similarity heatmaps overlaid.

    Extracts layer4 features for each image, computes the two similarity
    maps, and overlays each (resized to 224x224) on its source image with a
    'jet' colormap and a colorbar.

    Returns:
        The matplotlib figure (consumed by the Gradio "plot" output).
    """
    feats1 = get_layer4(image1).squeeze()
    feats2 = get_layer4(image2).squeeze()
    heatmaps = compute_spatial_similarity(feats1, feats2)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, image, heatmap in zip(axes, (image1, image2), heatmaps):
        ax.imshow(display_transform(image))
        overlay = ax.imshow(skimage.transform.resize(heatmap, (224, 224)),
                            alpha=0.6, cmap='jet')
        fig.colorbar(overlay, ax=ax)
    plt.tight_layout()
    return fig
# GRADIO APP
# NOTE(review): gr.inputs.* is the legacy Gradio 2.x component API; newer
# Gradio versions expose gr.Image(...) directly — confirm installed version.
image_a = gr.inputs.Image(shape=(300, 300), type='pil')
image_b = gr.inputs.Image(shape=(300, 300), type='pil')
iface = gr.Interface(
    fn=visualize_similarities,
    inputs=[image_a, image_b],
    outputs="plot",
)
iface.launch()