"""Gradio demo: transfer a regular grid of keypoints between two images with CHMNet."""

import os
import random
from itertools import product

import gdown
import gradio as gr
import matplotlib
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch
from PIL import Image
from torch.utils.data import DataLoader

from common.evaluation import Evaluator
from common.logger import AverageMeter, Logger
from data import download
from model import chmnet
from model.base.geometry import Geometry

# The demo samples a GRID_SIZE x GRID_SIZE lattice of source keypoints.
GRID_SIZE = 7
NUM_GRID_POINTS = GRID_SIZE * GRID_SIZE

# Downloading the Model
# gdown expects a full URL in ``url``; a bare Drive file id only works through
# the ``id=`` argument of ``gdown.download``. The download is skipped when
# ``pas_psi.pt`` already exists and matches the checksum.
md5 = "6b7b4d7bad7f89600fac340d6aa7708b"
gdown.cached_download(
    url="https://drive.google.com/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6",
    path="pas_psi.pt",
    quiet=False,
    md5=md5,
)

# Model Initialization
args = {
    "alpha": [0.05, 0.1],
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,
    "ktype": "psi",
    "load": "pas_psi.pt",
    "thres": "img",
}

model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(torch.load(args["load"], map_location=torch.device("cpu")))
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])
model.eval()

# Transforms
# Model input: resize + center-crop to a square, then tensor + normalization.
chm_transform = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Display: same geometry as the model input, but the result stays a PIL image.
chm_transform_plot = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
    ]
)


# A Helper Function
def to_np(x):
    """Return tensor *x* as a NumPy array after moving it to the CPU."""
    return x.data.to("cpu").numpy()


# Colors for Plotting
# ``plt.get_cmap`` is used instead of ``matplotlib.cm.get_cmap``, which has
# been deprecated and removed in recent matplotlib releases.
cmap = plt.get_cmap("Spectral")
rgba = cmap(0.5)
colors = [cmap(k / float(NUM_GRID_POINTS)) for k in range(NUM_GRID_POINTS)]


def _draw_keypoints(axis, image, points, title, label_color):
    """Show *image* on *axis* with *points* scattered and labelled by index."""
    axis.imshow(image)
    axis.scatter(points[:, 0], points[:, 1], c=colors[: len(points)])
    axis.set_title(title)
    axis.set_xticks([])
    axis.set_yticks([])
    # Label exactly the points that were drawn rather than a fixed count, so
    # requesting fewer points than the full grid does not index out of range.
    for idx, (px, py) in enumerate(points):
        axis.text(
            x=px,
            y=py,
            s=str(idx),
            fontdict=dict(color=label_color, size=11),
        )


# CHM MODEL
def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    """Transfer keypoints from the source to the target image and plot both.

    Args:
        source_image: Image accepted by ``chm_transform`` / ``display_transform``.
        target_image: Image accepted by ``chm_transform`` / ``display_transform``.
        selected_points: Source keypoints; row 0 is read as x coordinates and
            row 1 as y coordinates, in the ``args["img_size"]`` pixel frame.
        number_src_points: Number of leading keypoints to draw. At most
            ``len(colors)`` points can be colored.
        chm_transform: Transform producing the model input tensor.
        display_transform: Transform producing the image that is displayed.

    Returns:
        The matplotlib figure with source and target panels side by side.
    """
    # Convert to Tensor
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)

    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # RUN CHM
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    # VISUALIZATION
    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Apply the display transform once per image and reuse the result for
    # both the size lookup and the plot.
    src_display = display_transform(source_image)
    tgt_display = display_transform(target_image)
    img_size = args["img_size"]

    # Source points are rescaled from the model frame to the displayed image.
    src_w, src_h = src_display.size
    src_pixels = np.asarray(
        [
            [int(x * src_w / img_size), int(y * src_h / img_size)]
            for x, y in zip(src_points[0], src_points[1])
        ][:number_src_points]
    ).reshape(-1, 2)

    # Predicted points are mapped from a [-1, 1] range onto the displayed image.
    tgt_w, tgt_h = tgt_display.size
    tgt_pixels = np.asarray(
        [
            [int(((x + 1) / 2.0) * tgt_w), int(((y + 1) / 2.0) * tgt_h)]
            for x, y in zip(tgt_points[0], tgt_points[1])
        ][:number_src_points]
    ).reshape(-1, 2)

    # PLOT
    # NOTE(review): the figure is returned open for the caller to render;
    # a long-running app may want to close figures once they are consumed.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
    _draw_keypoints(ax[0], src_display, src_pixels, "Source", "red")
    _draw_keypoints(ax[1], tgt_display, tgt_pixels, "Target", "orange")

    plt.tight_layout()
    # Backslashes are escaped explicitly; the rendered mathtext is unchanged.
    fig.suptitle("CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights ", fontsize=16)
    return fig


# Wrapper
def generate_correspondences(
    sousrce_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
):
    """Build an evenly spaced keypoint grid and run CHM on it.

    Args:
        sousrce_image: Source image (parameter name kept for compatibility).
        target_image: Target image.
        min_x: Smallest x coordinate of the grid.
        max_x: Largest x coordinate of the grid.
        min_y: Smallest y coordinate of the grid.
        max_y: Largest y coordinate of the grid.

    Returns:
        The matplotlib figure produced by ``run_chm``.
    """
    xs = np.linspace(min_x, max_x, GRID_SIZE)
    ys = np.linspace(min_y, max_y, GRID_SIZE)
    # Transpose so that row 0 holds x values and row 1 holds y values.
    new_points = np.asarray(list(product(xs, ys)), dtype=np.float64).T

    return run_chm(
        sousrce_image,
        target_image,
        selected_points=new_points,
        number_src_points=new_points.shape[1],
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )


# Gradio App
# Slider start positions are passed as ``value``; the legacy ``default``
# keyword is not a parameter of the ``gr.Slider`` component.
main = gr.Interface(
    fn=generate_correspondences,
    inputs=[
        gr.Image(shape=(240, 240), type="pil"),
        gr.Image(shape=(240, 240), type="pil"),
        gr.Slider(minimum=1, maximum=240, step=1, value=15, label="Min X"),
        gr.Slider(minimum=1, maximum=240, step=1, value=215, label="Max X"),
        gr.Slider(minimum=1, maximum=240, step=1, value=15, label="Min Y"),
        gr.Slider(minimum=1, maximum=240, step=1, value=215, label="Max Y"),
    ],
    allow_flagging="never",
    outputs="plot",
    examples=[
        ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
        [
            "./examples/Red_Winged_Blackbird_0012_6015.jpg",
            "./examples/Red_Winged_Blackbird_0025_5342.jpg",
            17,
            223,
            17,
            223,
        ],
        [
            "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
            "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
            17,
            223,
            17,
            223,
        ],
    ],
)

# Built from adjacent literals so no source indentation leaks into the
# Markdown (indented lines could otherwise render as a code block).
APP_DESCRIPTION = (
    "# Correspondence Matching with Convolutional Hough Matching Networks\n"
    "Performs keypoint transform from a 7x7 grid on the source image to the "
    "target image. Use the sliders to adjust the grid.\n"
    "[Original Paper](https://arxiv.org/abs/2103.16831) - "
    "[Github Page](https://github.com/juhongm999/chm)\n"
)

blocks = gr.Blocks()
with blocks:
    gr.Markdown(APP_DESCRIPTION)
    gr.TabbedInterface([main], ["Main"])

blocks.launch(
    debug=True,
    enable_queue=False,
)