import gradio as gr
import cv2
import torch
import torch.utils.data as data
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import colors
from mpl_toolkits.axes_grid1 import ImageGrid
import fire_network
import numpy as np
from PIL import Image

# Possible scales for multiscale inference; the "Scale" slider picks one of them by index.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]

device = 'cpu'

# Maximum number of SFs shown at once in random mode (also the size of the tab10 palette).
_MAX_SHOWN_SFS = 10

# Load nets
# NOTE(review): torch.load unpickles arbitrary objects; only ever point it at trusted local checkpoints.
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for imagenet pretrained model
net_sfm = fire_network.init_network(**state['net_params']).to(device)
net_sfm.load_state_dict(state['state_dict'])
# Inference only: a freshly constructed nn.Module is in training mode, which would make
# normalization/dropout layers (if any) behave as during training.
net_sfm.eval()

# Copy the SfM model's dimensionality-reduction weights into the second checkpoint's state
# dict (overriding same-named entries) before loading it non-strictly into the second network.
dim_red_params_dict = {
    name: param
    for name, param in net_sfm.named_parameters()
    if 'dim_reduction' in name
}

state2 = torch.load('fire_imagenet.pth', map_location='cpu')
state2['net_params'] = state['net_params']
state2['state_dict'] = dict(state2['state_dict'], **dim_red_params_dict)
net_imagenet = fire_network.init_network(**state['net_params']).to(device)
net_imagenet.load_state_dict(state2['state_dict'], strict=False)
net_imagenet.eval()

transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net_sfm.runtime['mean_std']))),
])


def match(query_feat, pos_feat, LoweRatioTh=0.9):
    """Return the ids of super-features that match themselves across two images.

    A column index j of `pos_feat` is kept when:
      * row j of `query_feat` and row j of `pos_feat` are reciprocal nearest neighbours, and
      * Lowe's ratio test passes: (nearest distance) / (second-nearest distance)
        <= `LoweRatioTh`, both distances measured from pos feature j to the query features.

    Args:
        query_feat: 2-D tensor, one feature per row (first image).
        pos_feat: 2-D tensor, one feature per row (second image).
        LoweRatioTh: maximum accepted best/second-best distance ratio.

    Returns:
        1-D int64 tensor of the kept indices.
    """
    # first perform reciprocal nn
    dist = torch.cdist(query_feat, pos_feat)
    best1 = torch.argmin(dist, dim=1)  # nearest pos row for each query row
    best2 = torch.argmin(dist, dim=0)  # nearest query row for each pos row
    arange = torch.arange(best2.size(0))
    reciprocal = best1[best2] == arange

    # Lowe ratio test: mask out each column's best entry, then take the column minimum
    # as the second-best *distance* (the original divided by the argmin index instead).
    dist2 = dist.clone()
    dist2[best2, arange] = float('Inf')
    second_best = torch.min(dist2, dim=0).values
    ratio1to2 = dist[best2, arange] / second_best
    valid = torch.logical_and(reciprocal, ratio1to2 <= LoweRatioTh)

    pindices = torch.where(valid)[0]
    qindices = best2[pindices]
    # keep only the ones with same indices (SF j of image 1 <-> SF j of image 2)
    return pindices[pindices == qindices]


def clear_figures():
    """Close every open pyplot figure so each request starts from fresh figures.

    Figures returned by a previous request stay open until the next call; without this,
    re-opening figure 1/2/3 by number would draw on top of the stale content.
    """
    plt.close('all')


def _extract_superfeatures(net, image, scale):
    """Run `net.get_superfeatures` on one PIL image at a single scale.

    Returns:
        (feats, attns): the first feature output transposed and L2-normalised along dim 1
        (assumed one row per super-feature -- confirm against fire_network), and the first
        attention output (indexed below as attns[0, sf_id, :, :]).
    """
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = net.get_superfeatures(tensor, scales=[scale])
        feats = F.normalize(torch.t(torch.squeeze(output[0][0])), dim=1)
    attns = output[1][0]
    return feats, attns


def _select_superfeatures(ind_match, num_sf, random_mode, sf_ids, n_sf_ids=_MAX_SHOWN_SFS):
    """Choose which super-feature ids to display.

    Only ids contained in `ind_match` are returned, in the order they were chosen and
    without duplicates.

    * random_mode: up to `n_sf_ids` distinct ids drawn uniformly from the matches.
    * otherwise: the comma-separated ids in `sf_ids` (spaces allowed, e.g. "170, 15");
      if `sf_ids` is blank/None, `n_sf_ids` random ids in [0, num_sf) are tried instead.

    Raises:
        ValueError: if `sf_ids` contains a token that is not an integer.
    """
    matches = [int(i) for i in ind_match.tolist()]

    if random_mode:
        if not matches:
            return []
        picked = np.random.choice(matches, size=min(n_sf_ids, len(matches)), replace=False)
        return [int(i) for i in picked]

    if sf_ids is None or not sf_ids.strip():
        candidates = np.random.randint(num_sf, size=n_sf_ids).tolist()
    else:
        try:
            candidates = [int(token) for token in sf_ids.split(',') if token.strip()]
        except ValueError as err:
            raise ValueError(f"Invalid super-feature id list: {sf_ids!r}") from err

    matched = set(matches)
    return [i for i in dict.fromkeys(candidates) if i in matched]


def _binarize_attention(attn_map, threshold):
    """Scale one 2-D attention map by its maximum to [0, 255] and binarize it.

    Returns a uint8 array holding 255 where the scaled value exceeds `threshold`, else 0.
    """
    att = np.asarray(attn_map, dtype=np.float32)
    peak = att.max()
    if peak > 0:
        att = np.uint8(att / peak * 255.0)
    else:
        # Nothing to scale by: avoid 0/0 and mark no pixel.
        att = np.zeros(att.shape, dtype=np.uint8)
    return np.where(att > threshold, 255, 0).astype(np.uint8)


def _sf_color(palette, j):
    """Colour of the j-th displayed SF, cycling through the palette's listed colours."""
    return palette.colors[j % len(palette.colors)]


def _paint_masks(image_rgb, masks, palette):
    """Return a copy of an RGB uint8 image with each binary mask painted in its colour.

    Each mask is resized (nearest neighbour) to the image size; where masks overlap, the
    later one wins.
    """
    painted = image_rgb.copy()
    height, width = painted.shape[:2]
    for j, mask in enumerate(masks):
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        rgb = 255 * np.array(colors.to_rgba(_sf_color(palette, j)))[:3]
        painted[mask == 255] = rgb.astype(np.uint8)
    return painted


def _show_image(num, image_rgb):
    """Draw an RGB image without axes in (cleared) pyplot figure `num`; return the figure."""
    fig, ax = plt.subplots(num=num, clear=True)
    ax.imshow(image_rgb)
    ax.axis('off')
    fig.tight_layout()
    return fig


def _legend_figure(num, sf_idx, palette):
    """Return (cleared) pyplot figure `num` containing only a legend of SF ids and colours."""
    fig, ax = plt.subplots(num=num, clear=True)
    # Empty square-marker lines serve purely as legend handles.
    handles = [
        ax.plot([], [], marker='s', color=_sf_color(palette, j), ls='none')[0]
        for j in range(len(sf_idx))
    ]
    ax.legend(handles, [str(i) for i in sf_idx], framealpha=1, frameon=False,
              facecolor='w', fontsize=25, loc='center')
    ax.axis('off')
    fig.tight_layout()
    return fig


def generate_matching_superfeatures(im1, im2, Imagenet_model=False, scale_id=6, threshold=50,
                                    random_mode=False, sf_ids=''):
    """Visualise the super-features (SFs) that match between two images.

    Args:
        im1, im2: input PIL images (converted to RGB).
        Imagenet_model: use the network built from 'fire_imagenet.pth' instead of 'fire.pth'.
        scale_id: index into the module-level `scales` list; a single scale is used.
        threshold: 0-255 threshold applied to each max-normalised attention map.
        random_mode: show up to 10 randomly chosen matching SFs.
        sf_ids: comma-separated SF ids to show when `random_mode` is False; only ids that
            match between the two images are displayed.

    Returns:
        (fig1, fig2, fig_leg): matplotlib figures with the SF masks painted over each
        image, and a legend mapping each colour to its SF id.

    Raises:
        ValueError: if an image is missing or `sf_ids` cannot be parsed.
    """
    if im1 is None or im2 is None:
        raise ValueError("Please provide two images.")

    clear_figures()
    palette = plt.get_cmap('tab10')
    net = net_imagenet if Imagenet_model else net_sfm
    scale = scales[int(scale_id)]  # sliders may deliver floats

    im1 = im1.convert('RGB')
    im2 = im2.convert('RGB')
    feats1, attns1 = _extract_superfeatures(net, im1, scale)
    feats2, attns2 = _extract_superfeatures(net, im2, scale)

    ind_match = match(feats1, feats2)
    sf_idx = _select_superfeatures(ind_match, attns1.shape[1], random_mode, sf_ids)

    masks1 = [_binarize_attention(attns1[0, i, :, :].numpy(), threshold) for i in sf_idx]
    masks2 = [_binarize_attention(attns2[0, i, :, :].numpy(), threshold) for i in sf_idx]
    painted1 = _paint_masks(np.array(im1), masks1, palette)
    painted2 = _paint_masks(np.array(im2), masks2, palette)

    fig1 = _show_image(1, painted1)
    fig2 = _show_image(2, painted2)
    fig_leg = _legend_figure(3, sf_idx, palette)
    return fig1, fig2, fig_leg


# GRADIO APP
title = "Visualizing Super-features"
# NOTE(review): these strings look like HTML whose markup (e.g. links) was lost; restore if available.
description = "This is a visualization demo for the ICLR 2022 paper Learning Super-Features for Image Retrieval\n\n"
article = "\n\nOriginal Github Repo\n\n"

iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(shape=(1024, 1024), type="pil", label="First Image"),
        gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
        gr.inputs.Checkbox(default=False, label="ImageNet Model (Default: SfM-120k)"),
        gr.inputs.Slider(minimum=0, maximum=6, step=1, default=4, label="Scale"),
        gr.inputs.Slider(minimum=0, maximum=255, step=25, default=150, label="Binarization Threshold"),
        gr.inputs.Checkbox(default=True, label="Show random (matching) SFs"),
        gr.inputs.Textbox(lines=1, default="", label="...or show specific SF IDs:", optional=True),
    ],
    outputs=[
        gr.outputs.Image(type="plot", label="First Image SFs"),
        gr.outputs.Image(type="plot", label="Second Image SFs"),
        gr.outputs.Image(type="plot", label="SF legend"),
    ],
    title=title,
    theme='peach',
    layout="horizontal",
    description=description,
    article=article,
    # NOTE(review): these SF ids were chosen while `match` used the index-based ratio
    # test; ids that no longer pass the corrected test are silently filtered out.
    examples=[
        ["chateau_1.png", "chateau_2.png", False, 3, 150, False, '170,15,25,63,193,125,92,214,107'],
        ["areopoli1.jpeg", "areopoli2.jpeg", False, 4, 150, False, '205,2,163,130'],
        ["jaipur1.jpeg", "jaipur2.jpeg", False, 4, 50, False, '51,206,216,49,27'],
        ["basil1.jpeg", "basil2.jpeg", True, 4, 100, False, '75,152,19,36,156'],
        ["mill1.jpeg", "mill2.jpeg", False, 4, 100, False, '177,88,170,190,151,155'],
    ],
)
iface.launch(enable_queue=True)