Spaces:
Build error
Build error
| import csv | |
| import os | |
| import random | |
| import sys | |
| from itertools import product | |
| import gdown | |
| import gradio as gr | |
| import matplotlib | |
| import matplotlib.patches as patches | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| import torchvision.transforms.functional as TF | |
| from matplotlib import pyplot as plt | |
| from matplotlib.patches import ConnectionPatch | |
| from PIL import Image | |
| from torch.utils.data import DataLoader | |
| from common.evaluation import Evaluator | |
| from common.logger import AverageMeter, Logger | |
| from data import download | |
| from model import chmnet | |
| from model.base.geometry import Geometry | |
# Raise the csv module's per-field size cap as high as the platform allows.
# csv stores the limit in a C long, so sys.maxsize raises OverflowError where
# long is 32-bit (e.g. 64-bit Windows); back off by 10x until it is accepted.
_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_field_limit)
        break
    except OverflowError:
        _field_limit //= 10
# Downloading the Model
# Fetch the pretrained CHM checkpoint from Google Drive into ./pas_psi.pt.
# The md5 checksum is handed to gdown so it can verify the file (and,
# presumably, reuse an existing matching copy — confirm against gdown docs).
md5 = "6b7b4d7bad7f89600fac340d6aa7708b"
_weights_url = (
    "https://drive.google.com/u/0/uc"
    "?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download"
)
gdown.cached_download(
    path="pas_psi.pt",
    url=_weights_url,
    md5=md5,
    quiet=False,
)
# Model Initialization
# Inference configuration for the CHM demo. Keys used in this file:
#   "alpha"    -> Evaluator.initialize
#   "img_size" -> Geometry.initialize, the image transforms, and point rescaling
#   "ktype"    -> chmnet.CHMNet constructor
#   "load"     -> checkpoint path passed to torch.load (matches the download path)
# "benchmark", "bsz", "datapath" and "thres" are not read by this demo;
# presumably kept from the original training/evaluation config.
args = {
    "alpha": [0.05, 0.1],
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,
    "ktype": "psi",
    "load": "pas_psi.pt",
    "thres": "img",
}
# Build the CHM network, load the downloaded weights onto the CPU, initialise
# the Evaluator/Geometry helpers from ``args`` and switch to inference mode.
# NOTE(review): torch.load unpickles the checkpoint, so the file must come
# from a trusted source; weights_only=True is an option on newer torch.
_cpu_device = torch.device("cpu")
model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(torch.load(args["load"], map_location=_cpu_device))
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])
model.eval()
# Transforms
# ImageNet channel statistics used to normalise model inputs.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _resize_and_center_crop(size):
    """Return fresh Resize + square CenterCrop transforms for ``size`` pixels."""
    return [transforms.Resize(size), transforms.CenterCrop((size, size))]


# Model input: resize, square-crop, convert to a tensor and normalise.
chm_transform = transforms.Compose(
    _resize_and_center_crop(args["img_size"])
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Display variant: the same geometric steps only, so the output stays an image.
chm_transform_plot = transforms.Compose(_resize_and_center_crop(args["img_size"]))
# A Helper Function
def to_np(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and on the CPU."""
    # detach() is the supported replacement for the legacy ``.data`` attribute.
    return x.detach().to("cpu").numpy()
# Colors for Plotting
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap works on both old and new releases.
cmap = plt.get_cmap("Spectral")
rgba = cmap(0.5)
# 49 evenly spaced colormap samples, one per point of the 7x7 grid.
colors = [cmap(k / 49.0) for k in range(49)]
# CHM MODEL
def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    """Transfer keypoints from the source image to the target image and plot them.

    Runs the module-level CHM ``model`` on the image pair, maps
    ``selected_points`` into the target image with ``Geometry.transfer_kps``,
    and draws both point sets side by side, each point labelled by its index.

    Args:
        source_image: Source image given to ``chm_transform`` and
            ``display_transform`` (a PIL image when called from the Gradio UI,
            whose Image components use ``type="pil"``).
        target_image: Target image, same kind as ``source_image``.
        selected_points: Keypoints with x coordinates in row 0 and y
            coordinates in row 1, in pixels of the ``args["img_size"]`` model
            input (``generate_correspondences`` passes a (2, 49) float64 array).
        number_src_points: Number of keypoints to transfer and plot.
        chm_transform: Transform producing the normalised input tensor.
        display_transform: Transform producing the image that is displayed.

    Returns:
        matplotlib.figure.Figure: Two panels, source (left) and target (right).
        The figure is already closed in pyplot so repeated calls do not
        accumulate open figures; it can still be rendered or saved.
    """
    img_size = args["img_size"]

    # Add a batch dimension to both image tensors and to the keypoints.
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # RUN CHM ------------------------------------------------------------------
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    # VISUALIZATION
    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Compute each display image once; it is used for sizing and for imshow.
    src_display = display_transform(source_image)
    tgt_display = display_transform(target_image)

    # Source points are in model-input pixels: rescale to the display size.
    src_w, src_h = src_display.size
    src_points_converted = np.asarray(
        [
            [int(x * src_w / img_size), int(y * src_h / img_size)]
            for x, y in zip(src_points[0], src_points[1])
        ][:number_src_points]
    ).reshape(-1, 2)  # keep a 2-D (n, 2) shape even when there are no points

    # Predicted points are treated as lying in [-1, 1] and mapped to pixels.
    tgt_w, tgt_h = tgt_display.size
    tgt_points_converted = np.asarray(
        [
            [int(((x + 1) / 2.0) * tgt_w), int(((y + 1) / 2.0) * tgt_h)]
            for x, y in zip(tgt_points[0], tgt_points[1])
        ][:number_src_points]
    ).reshape(-1, 2)

    # PLOT
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    # Source image plot
    ax[0].imshow(src_display)
    ax[0].scatter(
        src_points_converted[:, 0],
        src_points_converted[:, 1],
        c="blue",
        edgecolors="white",
        s=50,
        label="Source points",
    )
    ax[0].set_title("Source Image with Selected Points")
    ax[0].set_xticks([])
    ax[0].set_yticks([])

    # Target image plot
    ax[1].imshow(tgt_display)
    ax[1].scatter(
        tgt_points_converted[:, 0],
        tgt_points_converted[:, 1],
        c="red",
        edgecolors="white",
        s=50,
        label="Target points",
    )
    ax[1].set_title("Target Image with Corresponding Points")
    ax[1].set_xticks([])
    ax[1].set_yticks([])

    # Label each point with its index so correspondences can be matched by eye.
    for i, (src, tgt) in enumerate(zip(src_points_converted, tgt_points_converted)):
        ax[0].text(*src, str(i), color="white", bbox=dict(facecolor="black", alpha=0.5))
        ax[1].text(*tgt, str(i), color="black", bbox=dict(facecolor="white", alpha=0.7))

    ax[0].legend(loc="lower right", bbox_to_anchor=(1, -0.075))
    ax[1].legend(loc="lower right", bbox_to_anchor=(1, -0.075))

    # Lay out this figure explicitly instead of pyplot's implicit "current figure".
    fig.tight_layout()
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    # Backslashes are escaped so the mathtext is unchanged without relying on
    # invalid escape sequences (SyntaxWarning since Python 3.12).
    fig.suptitle("CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights ", fontsize=16)

    # Release pyplot's reference so figures do not pile up across requests.
    plt.close(fig)
    return fig
# Wrapper
def generate_correspondences(
    sousrce_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
):
    """Build a 7x7 keypoint grid on the source image and transfer it with CHM.

    Args:
        sousrce_image: Source image. The misspelt name is kept for
            compatibility with keyword callers.
        target_image: Target image.
        min_x: Left edge of the grid, in model-input pixels.
        max_x: Right edge of the grid, in model-input pixels.
        min_y: Top edge of the grid, in model-input pixels.
        max_y: Bottom edge of the grid, in model-input pixels.

    Returns:
        The matplotlib figure produced by ``run_chm``.

    Raises:
        gr.Error: If either image is missing (an empty Gradio Image input
            arrives as ``None``), shown to the user instead of a traceback.
    """
    if sousrce_image is None or target_image is None:
        raise gr.Error("Please provide both a source image and a target image.")

    xs = np.linspace(min_x, max_x, 7)
    ys = np.linspace(min_y, max_y, 7)
    # product() yields the 49 (x, y) pairs; transpose to rows of x and y.
    new_points = np.asarray(list(product(xs, ys)), dtype=np.float64).T

    return run_chm(
        sousrce_image,
        target_image,
        selected_points=new_points,
        number_src_points=49,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )
# Gradio UI: two image inputs, four sliders for the grid extent, an output
# plot, clickable examples, and a Run button wired to generate_correspondences.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Correspondence Matching with Convolutional Hough Matching Networks
Performs keypoint transform from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid.
[Original Paper](https://arxiv.org/abs/2103.16831) - [Github Page](https://github.com/juhongm999/chm)
"""
    )

    with gr.Row():
        # Add an Image component to display the source image.
        image1 = gr.Image(
            height=240,
            width=240,
            type="pil",
            label="Source Image",
        )
        # Add an Image component to display the target image.
        image2 = gr.Image(
            height=240,
            width=240,
            type="pil",
            label="Target Image",
        )

    with gr.Row():
        # Add a Slider component to adjust the minimum x-coordinate of the grid.
        min_x = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=15,
            label="Min X",
        )
        # Add a Slider component to adjust the maximum x-coordinate of the grid.
        max_x = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=215,
            label="Max X",
        )
        # Add a Slider component to adjust the minimum y-coordinate of the grid.
        min_y = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=15,
            label="Min Y",
        )
        # Add a Slider component to adjust the maximum y-coordinate of the grid.
        max_y = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=215,
            label="Max Y",
        )

    with gr.Row():
        output_plot = gr.Plot()

    # Each example row fills: source image, target image, min_x, max_x, min_y, max_y.
    gr.Examples(
        [
            ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
            [
                "./examples/Red_Winged_Blackbird_0012_6015.jpg",
                "./examples/Red_Winged_Blackbird_0025_5342.jpg",
                17,
                223,
                17,
                223,
            ],
            [
                "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
                "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
                17,
                223,
                17,
                223,
            ],
        ],
        inputs=[
            image1,
            image2,
            min_x,
            max_x,
            min_y,
            max_y,
        ],
    )

    run_btn = gr.Button("Run")
    run_btn.click(
        generate_correspondences,
        inputs=[image1, image2, min_x, max_x, min_y, max_y],
        outputs=output_plot,
    )
# Start the server only when this file is executed as a script, so importing
# the module (e.g. by tooling or Gradio's reload mode) has no side effects.
if __name__ == "__main__":
    demo.launch()