import csv
import os
import random
import sys
from itertools import product

import gdown
import gradio as gr
import matplotlib
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch
from PIL import Image
from torch.utils.data import DataLoader

from common.evaluation import Evaluator
from common.logger import AverageMeter, Logger
from data import download
from model import chmnet
from model.base.geometry import Geometry

csv.field_size_limit(sys.maxsize)

# Download the pretrained model weights (cached after the first run).
md5 = "6b7b4d7bad7f89600fac340d6aa7708b"
gdown.cached_download(
    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
    path="pas_psi.pt",
    quiet=False,
    md5=md5,
)

# Model initialization
args = {
    "alpha": [0.05, 0.1],
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,
    "ktype": "psi",
    "load": "pas_psi.pt",
    "thres": "img",
}

model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(torch.load(args["load"], map_location=torch.device("cpu")))
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])
model.eval()

# Transforms: the model expects a normalized 240x240 center crop; the plot
# transform applies the same resize/crop without normalization.
chm_transform = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

chm_transform_plot = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
    ]
)


# A helper to move a tensor to a NumPy array on the CPU.
def to_np(x):
    return x.detach().cpu().numpy()


# Colors for plotting: 49 evenly spaced samples from the Spectral colormap.
cmap = matplotlib.colormaps["Spectral"]
rgba = cmap(0.5)
colors = [cmap(k / 49.0) for k in range(49)]


# CHM model: run inference and visualize the correspondences.
def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    # Convert images to normalized tensors with a batch dimension.
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)

    # Keypoints are a (1, 2, n) tensor of (x, y) coordinates in the model's
    # 240x240 frame, together with the number of valid points.
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # Run CHM ------------------------------------------------------------
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(
            corr_matrix, keypoints, n_pts, normalized=False
        )

    # Visualization
    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Source points are in the model's 240x240 frame; rescale them to the
    # displayed image size.
    src_points_converted = []
    w, h = display_transform(source_image).size
    for x, y in zip(src_points[0], src_points[1]):
        src_points_converted.append(
            [int(x * w / args["img_size"]), int(y * h / args["img_size"])]
        )
    src_points_converted = np.asarray(src_points_converted[:number_src_points])

    # Predicted points come back normalized to [-1, 1]; map them to pixels.
    tgt_points_converted = []
    w, h = display_transform(target_image).size
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_points_converted.append(
            [int((x + 1) / 2.0 * w), int((y + 1) / 2.0 * h)]
        )
    tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])

    # The same predicted points on the model's 7x7 feature grid.
    tgt_grid = []
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_grid.append([int((x + 1) / 2.0 * 7), int((y + 1) / 2.0 * 7)])

    # Plot
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
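    # Side-by-side layout: the source image with its sampled grid on the left,
    # the target image with the transferred points on the right. Matching
    # points share an index label, so correspondences can be read off even
    # with the connecting lines (commented out below) disabled.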
    # Source image plot
    ax[0].imshow(display_transform(source_image))
    ax[0].scatter(
        src_points_converted[:, 0],
        src_points_converted[:, 1],
        c="blue",
        edgecolors="white",
        s=50,
        label="Source points",
    )
    ax[0].set_title("Source Image with Selected Points")
    ax[0].set_xticks([])
    ax[0].set_yticks([])

    # Target image plot
    ax[1].imshow(display_transform(target_image))
    ax[1].scatter(
        tgt_points_converted[:, 0],
        tgt_points_converted[:, 1],
        c="red",
        edgecolors="white",
        s=50,
        label="Target points",
    )
    ax[1].set_title("Target Image with Corresponding Points")
    ax[1].set_xticks([])
    ax[1].set_yticks([])

    # Label each point with its index so source/target pairs can be matched.
    for i, (src, tgt) in enumerate(zip(src_points_converted, tgt_points_converted)):
        ax[0].text(*src, str(i), color="white", bbox=dict(facecolor="black", alpha=0.5))
        ax[1].text(*tgt, str(i), color="black", bbox=dict(facecolor="white", alpha=0.7))

    # A colormap with 49 distinct colors ('gist_rainbow' is one choice among many).
    cmap = matplotlib.colormaps["gist_rainbow"].resampled(49)

    # Optionally draw lines between corresponding source and target points.
    # for i, (src, tgt) in enumerate(zip(src_points_converted, tgt_points_converted)):
    #     con = ConnectionPatch(
    #         xyA=tgt,
    #         xyB=src,
    #         coordsA="data",
    #         coordsB="data",
    #         axesA=ax[1],
    #         axesB=ax[0],
    #         color=cmap(i),
    #         linewidth=2,
    #     )
    #     ax[1].add_artist(con)

    # Legends
    ax[0].legend(loc="lower right", bbox_to_anchor=(1, -0.075))
    ax[1].legend(loc="lower right", bbox_to_anchor=(1, -0.075))

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    fig.suptitle("CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights", fontsize=16)
    return fig


# Wrapper: sample a 7x7 grid of source points inside the given bounds,
# then run the model on the image pair.
def generate_correspondences(
    source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
):
    A = np.linspace(min_x, max_x, 7)
    B = np.linspace(min_y, max_y, 7)
    point_list = list(product(A, B))
    new_points = np.asarray(point_list, dtype=np.float64).T
    return run_chm(
        source_image,
        target_image,
        selected_points=new_points,
        number_src_points=49,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Correspondence Matching with Convolutional Hough Matching Networks

        Performs keypoint transfer from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid.

        [Original Paper](https://arxiv.org/abs/2103.16831) - [Github Page](https://github.com/juhongm999/chm)
        """
    )
    with gr.Row():
        # Source and target image inputs.
        image1 = gr.Image(
            shape=(240, 240),
            type="pil",
            label="Source Image",
        )
        image2 = gr.Image(
            shape=(240, 240),
            type="pil",
            label="Target Image",
        )
    with gr.Row():
        # Sliders for the grid bounds, in the 240x240 model frame.
        min_x = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=15,
            label="Min X",
        )
        max_x = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=215,
            label="Max X",
        )
        min_y = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=15,
            label="Min Y",
        )
        max_y = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=215,
            label="Max Y",
        )
    with gr.Row():
        output_plot = gr.Plot(label="Output Plot")
    gr.Examples(
        [
            ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
            [
                "./examples/Red_Winged_Blackbird_0012_6015.jpg",
                "./examples/Red_Winged_Blackbird_0025_5342.jpg",
                17,
                223,
                17,
                223,
            ],
            [
                "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
                "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
                17,
                223,
                17,
                223,
            ],
        ],
        inputs=[image1, image2, min_x, max_x, min_y, max_y],
    )
    run_btn = gr.Button("Run")
    run_btn.click(
        generate_correspondences,
        inputs=[image1, image2, min_x, max_x, min_y, max_y],
        outputs=output_plot,
    )

demo.launch(debug=True)
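# A minimal sketch of calling the pipeline without the UI, assuming the example
# images bundled with this demo are present on disk; left commented out because
# demo.launch() above blocks until the server is stopped.
#
# fig = generate_correspondences(
#     Image.open("./examples/sample1.jpeg").convert("RGB"),
#     Image.open("./examples/sample2.jpeg").convert("RGB"),
#     min_x=17, max_x=223, min_y=17, max_y=223,
# )
# fig.savefig("correspondences.png")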