"""Gradio demo visualising FIRE super-features on a pair of images."""
import gradio as gr
import cv2
import torch
import torch.utils.data as data
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import colors
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import ImageGrid
import fire_network
import numpy as np
from PIL import Image

# Possible scales for multiscale inference; the "Scale" slider selects one by index.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
device = 'cpu'

# Super-feature IDs displayed when the "SF IDs" textbox is left empty.
DEFAULT_SF_IDS = (55, 14, 5, 4, 52, 57, 40, 9)

# Load net.
# NOTE: torch.load unpickles the checkpoint -- only ever load a trusted file here.
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for imagenet pretrained model
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
# Inference only: switch off training-mode behaviour (e.g. dropout, batch-norm
# statistics) so every request sees the same deterministic network.
net.eval()

transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std']))),
])


def match(query_feat, pos_feat, LoweRatioTh=0.9):
    """Return the indices of super-features that match themselves across images.

    A pair (query q, positive p) survives when (1) q and p are reciprocal
    nearest neighbours, (2) it passes Lowe's ratio test (best distance divided
    by second-best distance <= ``LoweRatioTh``), and (3) q == p, i.e. the same
    super-feature ID was matched in both images.

    Args:
        query_feat: 2-D tensor of descriptors for the first image, one row per SF.
        pos_feat: 2-D tensor of descriptors for the second image, one row per SF.
        LoweRatioTh: maximum allowed best/second-best distance ratio.

    Returns:
        1-D tensor of super-feature indices that matched themselves.
    """
    dist = torch.cdist(query_feat, pos_feat)   # (num_query, num_pos)
    best1 = torch.argmin(dist, dim=1)          # nearest positive for each query
    best2 = torch.argmin(dist, dim=0)          # nearest query for each positive
    arange = torch.arange(best2.size(0), device=dist.device)
    reciprocal = best1[best2] == arange

    # Lowe ratio test: mask out the best match of every column, then take the
    # second-best *distance* (not its index) as the denominator.
    dist2 = dist.clone()
    dist2[best2, arange] = float('inf')
    second_best = torch.min(dist2, dim=0).values
    ratio1to2 = dist[best2, arange] / second_best
    valid = torch.logical_and(reciprocal, ratio1to2 <= LoweRatioTh)

    pindices = torch.where(valid)[0]
    qindices = best2[pindices]
    # Keep only the ones with the same index in both images.
    return pindices[pindices == qindices]


col = plt.get_cmap('tab10')
# tab10 colour indices for the 1st, 2nd, ... displayed super-feature. Index 0 is
# not used (the first SF gets 9) and neither is 7 (grey in tab10). The sequence
# wraps around so any number of SFs can be drawn without an IndexError.
_SF_PALETTE_IDX = (9, 1, 2, 3, 4, 5, 6, 8, 9)


def _sf_color_bgr(j):
    """Return the 0-255 BGR colour used for the j-th displayed super-feature."""
    idx = _SF_PALETTE_IDX[j % len(_SF_PALETTE_IDX)]
    rgb = 255 * np.array(colors.to_rgba(col.colors[idx]))[:3]
    return rgb[::-1]


def _parse_sf_ids(sf_ids, num_sf):
    """Turn the textbox value into a list of super-feature IDs.

    "" or None gives DEFAULT_SF_IDS. "rX" gives X random IDs in [0, num_sf).
    Anything else is read as a comma-separated list of integers; empty items
    are ignored.
    """
    sf_ids = (sf_ids or '').strip()
    if not sf_ids:
        return list(DEFAULT_SF_IDS)
    if sf_ids.lower().startswith('r'):
        return np.random.randint(num_sf, size=int(sf_ids[1:])).tolist()
    return [int(tok) for tok in sf_ids.split(',') if tok.strip()]


def _extract_superfeatures(img, scale):
    """Run the network on one PIL image at a single scale.

    Returns the first element of each of the first three outputs of
    ``net.get_superfeatures``: features, attention maps and strengths.
    """
    tensor = transform(img).unsqueeze(0).to(device)
    output = net.get_superfeatures(tensor, scales=[scale])
    return output[0][0], output[1][0], output[2][0]


def _binarize_attention(attn, threshold):
    """Rescale an attention map to 0-255 and return a uint8 mask (255 above threshold, else 0)."""
    heat = np.asarray(attn, dtype=np.float32)
    peak = heat.max()
    if peak <= 0:
        # Avoid 0/0 -> NaN; an all-zero map simply highlights nothing.
        return np.zeros(heat.shape, dtype=np.uint8)
    heat = np.uint8(heat / peak * 255.0)
    return np.where(heat > threshold, 255, 0).astype(np.uint8)


def _overlay_masks(img_bgr, masks, size):
    """Paint each binary mask (resized to ``size`` = (width, height)) onto a copy of img_bgr."""
    out = img_bgr.copy()
    for j, mask in enumerate(masks):
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        out[mask == 255] = _sf_color_bgr(j)  # vectorised; later SFs overwrite earlier ones
    return out


def _make_figure(img_bgr):
    """Wrap a BGR image in a standalone matplotlib Figure with axes hidden.

    A plain Figure (not pyplot) is used so that figures do not accumulate in
    pyplot's global state across requests.
    """
    fig = Figure()
    ax = fig.add_subplot(111)
    ax.imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    ax.axis('off')
    fig.tight_layout()
    return fig


def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='', only_matching=True):
    """Overlay the binarised attention maps of selected super-features on two images.

    Args:
        im1, im2: PIL images; they are converted to RGB here.
        scale_id: index into ``scales`` selecting the single inference scale.
        threshold: each attention map is rescaled to 0-255, and pixels strictly
            above this value are coloured.
        sf_ids: "" for the default IDs, "rX" for X random IDs, or a
            comma-separated list of IDs.
        only_matching: currently unused (see NOTE below).

    Returns:
        (fig1, fig2, ids): two matplotlib figures with the overlays, and the
        comma-separated IDs that were drawn.

    Raises:
        ValueError: if ``sf_ids`` cannot be parsed or contains an ID outside
            the range of available super-features.
    """
    im1 = im1.convert('RGB')  # guarantees 3 channels for ToTensor/Normalize and BGR flip
    im2 = im2.convert('RGB')
    print('im1:', im1.size)
    print('im2:', im2.size)
    scale = scales[int(scale_id)]  # slider values may arrive as floats

    with torch.no_grad():
        feats1, attns1, strenghts1 = _extract_superfeatures(im1, scale)
        feats2, attns2, strenghts2 = _extract_superfeatures(im2, scale)
        ind_match = match(F.normalize(feats1, dim=1), F.normalize(feats2, dim=1))
    print('ind', ind_match)
    print('ind.shape', ind_match.shape)
    print(feats1.shape, feats2.shape)
    print(attns1.shape, attns2.shape)
    print(strenghts1.shape, strenghts2.shape)

    # NOTE(review): the "Show only matching SFs" checkbox (`only_matching`) and
    # `ind_match` are not used to filter the displayed SFs yet.
    # TODO: decide whether sf_idx_ should be restricted to ind_match when set.
    num_sf = min(attns1.shape[1], attns2.shape[1])  # dim 1 indexes super-features
    sf_idx_ = _parse_sf_ids(sf_ids, num_sf)  # a list, so it can be iterated twice
    out_of_range = [i for i in sf_idx_ if not 0 <= i < num_sf]
    if out_of_range:
        raise ValueError(f"SF IDs must be in [0, {num_sf - 1}], got {out_of_range}")

    masks1 = [_binarize_attention(attns1[0, i].numpy(), threshold) for i in sf_idx_]
    masks2 = [_binarize_attention(attns2[0, i].numpy(), threshold) for i in sf_idx_]

    img1rsz = _overlay_masks(np.array(im1)[:, :, ::-1], masks1, im1.size)
    img2rsz = _overlay_masks(np.array(im2)[:, :, ::-1], masks2, im2.size)
    print('img1rsz:', img1rsz.shape)
    print('img2rsz:', img2rsz.shape)

    return _make_figure(img1rsz), _make_figure(img2rsz), ','.join(map(str, sf_idx_))


# GRADIO APP
title = "Visualizing Super-features"
# NOTE(review): these strings appear to have lost their original HTML/link
# markup; only the visible text is kept so the literals are valid Python.
description = "This is a visualization demo for the ICLR 2022 paper Learning Super-Features for Image Retrieval"
article = "Original Github Repo"
css = ".input_image, .input_image {height: 600px !important; width: 600px !important;} "

iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(type="pil", label="First Image"),
        gr.inputs.Image(type="pil", label="Second Image"),
        gr.inputs.Slider(minimum=0, maximum=6, step=1, default=2, label="Scale"),
        gr.inputs.Slider(minimum=1, maximum=255, step=25, default=100, label="Binarization Threshold"),
        gr.inputs.Textbox(lines=1, default="", label="SF IDs to show (comma separated numbers from 0-255; typing 'rX' will return X random SFs", optional=True),
        gr.inputs.Checkbox(default=True, label="Show only matching SFs", optional=False),
    ],
    outputs=[
        gr.outputs.Image(type="plot", label="First Image SFs"),
        gr.outputs.Image(type="plot", label="Second Image SFs"),
        gr.outputs.Textbox(label="SFs"),
    ],
    title=title,
    theme='peach',
    layout="horizontal",
    description=description,
    article=article,
    css=css,
    examples=[
        ["chateau_1.png", "chateau_2.png", 2, 100, '55,14,5,4,52,57,40,9', True],
        ["anafi1.jpeg", "anafi2.jpeg", 4, 50, '99,100,142,213,236', True],
        ["areopoli1.jpeg", "areopoli2.jpeg", 4, 50, '72,44,142,213,236', True],
    ],
)
iface.launch(enable_queue=True)