"""Gradio demo for keypoint transfer with Convolutional Hough Matching Networks.

On import this script downloads the ``pas_psi.pt`` checkpoint, builds the CHM
model on CPU, and launches a Gradio interface that transfers a 7x7 grid of
source keypoints onto a target image and plots both point sets side by side.
"""
import os
import random
from itertools import product

import gdown
import gradio as gr
import matplotlib
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch
from PIL import Image
from torch.utils.data import DataLoader

from common.evaluation import Evaluator
from common.logger import AverageMeter
from common.logger import Logger
from data import download
from model import chmnet
from model.base.geometry import Geometry

# Downloading the Model
gdown.download(id="1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6", output="pas_psi.pt", quiet=False)

# Model Initialization
args = {
    "alpha": [0.05, 0.1],
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,
    "ktype": "psi",
    "load": "pas_psi.pt",
    "thres": "img",
}

model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(torch.load(args["load"], map_location=torch.device("cpu")))
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])
model.eval()

# Transforms
# Model input: resize, centre-crop to a square, convert to tensor, normalise.
chm_transform = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Display: same resize/crop as above, but the image is left un-normalised
# so it can be passed straight to ``imshow``.
chm_transform_plot = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
    ]
)


# A Helper Function
def to_np(x):
    """Return the data of tensor ``x`` as a NumPy array after moving it to CPU."""
    return x.data.to("cpu").numpy()


# Colors for Plotting
# ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
# 3.9; ``plt.get_cmap`` returns the same colormap on old and new releases.
cmap = plt.get_cmap("Spectral")
rgba = cmap(0.5)
colors = [cmap(k / 49.0) for k in range(49)]


def _label_points(axis, points, color):
    """Write each point's index next to it on ``axis`` in the given text colour."""
    for idx, (x, y) in enumerate(points):
        axis.text(x=x, y=y, s=str(idx), fontdict=dict(color=color, size=11))


# CHM MODEL
def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    """Transfer keypoints from the source to the target image and plot them.

    Args:
        source_image: Image accepted by both ``chm_transform`` and
            ``display_transform``.
        target_image: Image accepted by both transforms.
        selected_points: Array-like whose first row holds the x coordinates
            and whose second row holds the y coordinates of the source
            keypoints, expressed in the ``args["img_size"]`` frame.
        number_src_points: Number of leading keypoints to draw and label.
        chm_transform: Transform producing the model input tensor.
        display_transform: Transform producing the image that is plotted;
            its result must expose ``.size`` as ``(width, height)``.

    Returns:
        The matplotlib figure with the source panel on the left and the
        target panel on the right.
    """
    # Convert to Tensor
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)

    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # RUN CHM ------------------------------------------------------------------------
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    # VISUALIZATION
    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Apply the display transform once per image and reuse the result for
    # both the size lookup and the plot.
    src_display = display_transform(source_image)
    tgt_display = display_transform(target_image)

    # Source keypoints are rescaled from the ``img_size`` frame to the
    # displayed image size.
    w, h = src_display.size
    src_points_converted = [
        [int(x * w / args["img_size"]), int(y * h / args["img_size"])]
        for x, y in zip(src_points[0], src_points[1])
    ]
    # ``reshape(-1, 2)`` keeps the array two-dimensional even when no points
    # are requested, so the column indexing below stays valid.
    src_points_converted = np.asarray(
        src_points_converted[:number_src_points]
    ).reshape(-1, 2)

    # The conversion below treats the predicted coordinates as lying in
    # [-1, 1] and maps them linearly onto the displayed image size.
    w, h = tgt_display.size
    tgt_points_converted = [
        [int(((x + 1) / 2.0) * w), int(((y + 1) / 2.0) * h)]
        for x, y in zip(tgt_points[0], tgt_points[1])
    ]
    tgt_points_converted = np.asarray(
        tgt_points_converted[:number_src_points]
    ).reshape(-1, 2)

    # PLOT
    # NOTE(review): figures created through ``plt.subplots`` stay registered
    # with pyplot until closed; in a long-running app consider closing them
    # once the caller has rendered the result.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    panels = (
        (ax[0], src_display, src_points_converted, "Source", "red"),
        (ax[1], tgt_display, tgt_points_converted, "Target", "orange"),
    )
    for axis, image, points, panel_title, _ in panels:
        axis.imshow(image)
        axis.scatter(points[:, 0], points[:, 1], c=colors[:number_src_points])
        axis.set_title(panel_title)
        axis.set_xticks([])
        axis.set_yticks([])

    # Label only the points that were actually kept; a fixed ``range(49)``
    # raised IndexError whenever fewer than 49 points were requested.
    for axis, _, points, _, label_color in panels:
        _label_points(axis, points, label_color)

    plt.tight_layout()
    # The mathtext part is a raw string so ``\it`` and ``\_`` are not parsed
    # as (invalid) Python escape sequences; the resulting text is unchanged.
    fig.suptitle(
        "CHM Correspondences\n" r"Using $\it{pas\_psi.pt}$ Weights ",
        fontsize=16,
    )
    return fig


# Wrapper
def generate_correspondences(
    sousrce_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
):
    """Build a 7x7 keypoint grid and return the CHM correspondence figure.

    Args:
        sousrce_image: Source image passed on to ``run_chm``.
        target_image: Target image passed on to ``run_chm``.
        min_x: First x coordinate of the grid.
        max_x: Last x coordinate of the grid.
        min_y: First y coordinate of the grid.
        max_y: Last y coordinate of the grid.

    Returns:
        The matplotlib figure produced by ``run_chm``.
    """
    xs = np.linspace(min_x, max_x, 7)
    ys = np.linspace(min_y, max_y, 7)
    # Cartesian product of the two axes, transposed to shape (2, 49):
    # row 0 holds x values, row 1 holds y values.
    new_points = np.asarray(list(product(xs, ys)), dtype=np.float64).T

    return run_chm(
        sousrce_image,
        target_image,
        selected_points=new_points,
        number_src_points=49,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )


# GRADIO APP
title = "Correspondence Matching with Convolutional Hough Matching Networks "
description = "Performs keypoint transform from a 7x7 gird on the source image to the target image. Use the sliders to adjust the grid."
article = "\n"

iface = gr.Interface(
    fn=generate_correspondences,
    inputs=[
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label="Min X"),
        gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label="Max X"),
        gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label="Min Y"),
        gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label="Max Y"),
    ],
    outputs="plot",
    enable_queue=True,
    title=title,
    description=description,
    article=article,
    examples=[["sample1.jpeg", "sample2.jpeg", 15, 215, 15, 215]],
)

iface.launch()