# SuperFeatures / app.py
# Commit 700c051 ("merge") by YannisK, 7.47 kB
import gradio as gr
import cv2
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import colors
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision import transforms
import fire_network
import numpy as np
from PIL import Image
# Possible scales for multiscale inference; the UI "Scale" slider is an index into this list.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
device = 'cpu'

# Load the FIRE network once, at import time, from the local checkpoint.
state = torch.load('fire.pth', map_location=device)
state['net_params']['pretrained'] = None  # no need for the ImageNet-pretrained model; weights come from the checkpoint
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
# The net is only used for inference: leave training mode so that any dropout /
# batch-norm layers behave deterministically.
net.eval()

# Preprocessing: resize, convert to a tensor, then normalise with the mean/std stored in the network.
transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std']))),
])

# Colormap used to paint the selected super-features.
col = plt.get_cmap('tab10')
# Super-feature ids shown when the user leaves the "SF IDs" textbox empty.
_DEFAULT_SF_IDS = (55, 14, 5, 4, 52, 57, 40, 9)

# tab10 palette indices for the j-th displayed super-feature. This reproduces the
# original hand-picked order (index 9 first, index 7 skipped) and cycles when more
# SFs are requested, instead of indexing past the end of the 10-colour map.
_SF_PALETTE_IDX = (9, 1, 2, 3, 4, 5, 6, 8, 9)


def _parse_sf_ids(sf_ids):
    """Turn the user's SF-id text into a list of super-feature indices.

    '' or None -> the default ids; 'rX' -> X random ids in [0, 256);
    anything else -> a comma-separated list of integers (empty items are ignored).
    """
    text = (sf_ids or '').strip()
    if not text:
        return list(_DEFAULT_SF_IDS)
    if text.lower().startswith('r'):
        return list(np.random.randint(256, size=int(text[1:])))
    return [int(tok) for tok in text.split(',') if tok.strip()]


def _sf_color_bgr(j):
    """Return the uint8 BGR colour used for the j-th displayed super-feature."""
    rgb = colors.to_rgba(col.colors[_SF_PALETTE_IDX[j % len(_SF_PALETTE_IDX)]])[:3]
    # astype truncates, matching the implicit float->uint8 cast of a per-pixel assignment.
    return (255 * np.array(rgb))[::-1].astype(np.uint8)


def _binarize_attention(attns, sf_idx, threshold):
    """Return one uint8 {0, 255} mask per requested super-feature.

    Each map ``attns[0, i]`` is rescaled so that its maximum becomes 255 and is
    then binarised with ``> threshold``.
    """
    masks = []
    for i in sf_idx:
        heat = attns[0, i, :, :].cpu().numpy().astype(np.float32)
        peak = heat.max()
        if peak > 0:
            heat = np.uint8(heat / peak * 255.0)
        else:
            # An all-zero map would otherwise divide by zero and produce NaNs.
            heat = np.zeros(heat.shape, dtype=np.uint8)
        masks.append(np.where(heat > threshold, 255, 0).astype(np.uint8))
    return masks


def _overlay_masks(img_bgr, masks, size):
    """Paint each mask, in its super-feature colour, onto a copy of ``img_bgr``.

    ``size`` is the PIL ``(width, height)`` of the image, which is also the
    ``dsize`` order that ``cv2.resize`` expects. Later masks overwrite earlier ones.
    """
    out = np.copy(img_bgr)
    for j, mask in enumerate(masks):
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        out[mask == 255] = _sf_color_bgr(j)
    return out


def _render_figure(num, img_bgr):
    """Draw ``img_bgr`` (shown as RGB, without axes) on pyplot figure ``num``.

    The figure is cleared first because the numbered figure is reused on every
    call; otherwise each call would stack another image on the old axes.
    """
    fig = plt.figure(num, clear=True)
    ax = fig.gca()
    ax.imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    ax.axis('off')
    fig.tight_layout()
    return fig


def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids=''):
    """Visualise selected super-features (SFs) of two images.

    Args:
        im1, im2: PIL images.
        scale_id: index into the module-level ``scales`` list; a single scale is used.
        threshold: attention maps rescaled to [0, 255] are binarised with ``> threshold``.
        sf_ids: '' or None for the default ids, 'rX' for X random ids, or a
            comma-separated list of ids.

    Returns:
        (fig1, fig2, ids): one matplotlib figure per image with the SF masks
        painted in, plus the comma-separated ids that were drawn.

    Raises:
        ValueError: if either image is missing or ``sf_ids`` cannot be parsed.
    """
    if im1 is None or im2 is None:
        raise ValueError("Two images are required.")
    # Materialised as a list: it is iterated twice (masks for each image) and echoed back.
    sf_idx_ = _parse_sf_ids(sf_ids)
    scale = [scales[int(scale_id)]]

    # Ensure 3 channels in case a non-RGB (e.g. RGBA or greyscale) image arrives;
    # both the normalisation and the BGR overlay below assume RGB input.
    im1 = im1.convert('RGB')
    im2 = im2.convert('RGB')

    with torch.no_grad():
        # output[1][0] holds the attention maps (output[0] / output[2] are the
        # features / strengths, which are not needed for this visualisation).
        attns1 = net.get_superfeatures(transform(im1).unsqueeze(0).to(device), scales=scale)[1][0]
        attns2 = net.get_superfeatures(transform(im2).unsqueeze(0).to(device), scales=scale)[1][0]

    # RGB -> BGR (contiguous copy) for OpenCV.
    img1 = _overlay_masks(np.array(im1)[:, :, ::-1].copy(),
                          _binarize_attention(attns1, sf_idx_, threshold), im1.size)
    img2 = _overlay_masks(np.array(im2)[:, :, ::-1].copy(),
                          _binarize_attention(attns2, sf_idx_, threshold), im2.size)

    fig1 = _render_figure(1, img1)
    fig2 = _render_figure(2, img2)
    return fig1, fig2, ','.join(map(str, sf_idx_))
# GRADIO APP
title = "Visualizing Super-features"
description = "This is a visualization demo for the ICLR 2022 paper <b><a href='https://github.com/naver/fire' target='_blank'>Learning Super-Features for Image Retrieval</a></b>"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"
# NOTE(review): this selector list names `.input_image` twice; it may have been meant as
# `.input_image, .output_image` — confirm before changing, since that would alter the layout.
css = ".input_image, .input_image {height: 600px !important; width: 600px !important;} "

# NOTE(review): gr.inputs / gr.outputs, theme='peach', layout= and enable_queue= belong to the
# legacy (pre-4.x) Gradio API; keep the Gradio version pinned accordingly.
iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(type="pil", label="First Image"),
        gr.inputs.Image(type="pil", label="Second Image"),
        # Index into `scales` (default 2 -> scale 1.0).
        gr.inputs.Slider(minimum=0, maximum=6, step=1, default=2, label="Scale"),
        # minimum=0 puts the step-25 grid on 0, 25, 50, 75, 100, ... so the default (100) and
        # the example thresholds (100, 50) are actual slider positions (with minimum=1 they were not).
        gr.inputs.Slider(minimum=0, maximum=255, step=25, default=100, label="Binarization Threshold"),
        gr.inputs.Textbox(lines=1, default="", label="SF IDs to show (comma separated numbers from 0-255; typing 'rX' will return X random SFs)", optional=True),
    ],
    outputs=[
        gr.outputs.Image(type="plot", label="First Image SFs"),
        gr.outputs.Image(type="plot", label="Second Image SFs"),
        gr.outputs.Textbox(label="SFs"),
    ],
    title=title,
    theme='peach',
    layout="horizontal",
    description=description,
    article=article,
    css=css,
    # Each row: image 1, image 2, scale index, threshold, SF ids.
    examples=[
        ["chateau_1.png", "chateau_2.png", 2, 100, '55,14,5,4,52,57,40,9'],
        ["anafi1.jpeg", "anafi2.jpeg", 4, 50, '99,100,142,213,236'],
        ["areopoli1.jpeg", "areopoli2.jpeg", 4, 50, '99,100,142,213,236'],
    ],
)
iface.launch(enable_queue=True)