taesiri's picture
update
1600894
import csv
import os
import random
import sys
from itertools import product
import gdown
import gradio as gr
import matplotlib
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch
from PIL import Image
from torch.utils.data import DataLoader
from common.evaluation import Evaluator
from common.logger import AverageMeter, Logger
from data import download
from model import chmnet
from model.base.geometry import Geometry
# Raise the csv module's per-field size cap as far as the platform allows.
# `csv.field_size_limit` stores the limit in a C long; on platforms where that
# is 32-bit (e.g. Windows) `sys.maxsize` overflows it and OverflowError is
# raised at import time, so fall back to the largest 32-bit signed value.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)
# Downloading the Model
# Pretrained CHM checkpoint hosted on Google Drive, saved locally as
# `pas_psi.pt` (the path later loaded via args["load"]).
# NOTE(review): gdown's `cached_download` is expected to reuse an existing
# local file whose MD5 matches `md5` instead of downloading again, and it may
# be deprecated in newer gdown releases — confirm against the pinned version.
_PAS_PSI_URL = (
    "https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6"
    "&export=download"
)
md5 = "6b7b4d7bad7f89600fac340d6aa7708b"
gdown.cached_download(
    url=_PAS_PSI_URL,
    path="pas_psi.pt",
    quiet=False,
    md5=md5,
)
# Model Initialization
# Run configuration. Only "alpha", "img_size", "ktype" and "load" are read in
# this file; the remaining keys appear to be leftovers from the training /
# benchmark scripts (TODO confirm before removing).
args = {
    "alpha": [0.05, 0.1],  # passed to Evaluator.initialize (presumably PCK thresholds)
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,  # side length of the square model input
    "ktype": "psi",  # kernel type handed to CHMNet
    "load": "pas_psi.pt",  # checkpoint fetched by the download step above
    "thres": "img",
}

# Build the network and load the checkpoint onto the CPU.
model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(torch.load(args["load"], map_location="cpu"))

# One-time class-level setup of the evaluation / geometry helpers.
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])

# Inference only: put the network in evaluation mode.
model.eval()
# Transforms
def _resize_and_center_crop(size):
    """Return fresh [Resize, CenterCrop] steps yielding a `size` x `size` image."""
    return [
        transforms.Resize(size),
        transforms.CenterCrop((size, size)),
    ]


# Model input: square resize/crop, then tensor conversion and normalisation
# with the standard ImageNet channel mean/std.
chm_transform = transforms.Compose(
    _resize_and_center_crop(args["img_size"])
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Display: the same geometric resize/crop, without tensor conversion, so the
# plotted image lines up pixel-for-pixel with what the model sees.
chm_transform_plot = transforms.Compose(_resize_and_center_crop(args["img_size"]))
# A Helper Function
def to_np(x):
    """Return tensor `x` as a NumPy array.

    Uses `.data` so the result is outside autograd, and moves the tensor to the
    CPU first so GPU tensors can be converted as well.
    """
    return x.data.cpu().numpy()
# Colors for Plotting
# `plt.get_cmap` is used instead of `matplotlib.cm.get_cmap`, which was
# deprecated in Matplotlib 3.7 and removed in 3.9 (AttributeError at import).
cmap = plt.get_cmap("Spectral")
# Colour at the middle of the map.
rgba = cmap(0.5)
# 49 evenly spaced samples of the map (presumably one per 7x7 grid point).
# NOTE(review): neither `rgba` nor `colors` is referenced in the visible code.
colors = [cmap(k / 49.0) for k in range(49)]
# CHM MODEL
def _plot_keypoints(axis, image, points, color, label, title):
    """Show `image` on `axis` with `points` (an (N, 2) array of x, y pixels) on top."""
    axis.imshow(image)
    axis.scatter(
        points[:, 0],
        points[:, 1],
        c=color,
        edgecolors="white",
        s=50,
        label=label,
    )
    axis.set_title(title)
    axis.set_xticks([])
    axis.set_yticks([])


def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    """Transfer keypoints from the source to the target image with CHM and plot them.

    Uses the module-level `model`, `Geometry` and `args`.

    Args:
        source_image: Source image, fed to both `chm_transform` and
            `display_transform` (a PIL image when called from the Gradio app).
        target_image: Target image, handled the same way.
        selected_points: Array-like of shape (2, N): row 0 holds x and row 1
            holds y, in pixels of the `args["img_size"]` model-input frame.
        number_src_points: Number of keypoints N; only the first N are drawn.
        chm_transform: Maps an image to the normalised tensor the model takes.
        display_transform: Maps an image to the image drawn in the figure;
            its `.size` is used to rescale keypoints to display pixels.

    Returns:
        A matplotlib Figure with the source (left) and target (right) images,
        each annotated with numbered keypoints.
    """
    # Model inputs, each with a leading batch dimension of 1.
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # RUN CHM: correlation between the two images, then keypoint transfer.
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Render each display image once; it is used for scaling and plotting.
    src_display = display_transform(source_image)
    tgt_display = display_transform(target_image)
    img_size = args["img_size"]

    # Source keypoints: model-input pixels -> display-image pixels.
    # reshape(-1, 2) keeps an (N, 2) shape even when there are no points.
    w, h = src_display.size
    src_points_converted = np.asarray(
        [
            [int(x * w / img_size), int(y * h / img_size)]
            for x, y in zip(src_points[0], src_points[1])
        ][:number_src_points]
    ).reshape(-1, 2)

    # Predicted target keypoints are treated as normalised to [-1, 1] and
    # mapped onto display-image pixels.
    w, h = tgt_display.size
    tgt_points_converted = np.asarray(
        [
            [int((x + 1) / 2.0 * w), int((y + 1) / 2.0 * h)]
            for x, y in zip(tgt_points[0], tgt_points[1])
        ][:number_src_points]
    ).reshape(-1, 2)

    # PLOT
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
    _plot_keypoints(
        ax[0],
        src_display,
        src_points_converted,
        "blue",
        "Source points",
        "Source Image with Selected Points",
    )
    _plot_keypoints(
        ax[1],
        tgt_display,
        tgt_points_converted,
        "red",
        "Target points",
        "Target Image with Corresponding Points",
    )

    # Number each source point and its predicted match identically.
    for i, (src, tgt) in enumerate(zip(src_points_converted, tgt_points_converted)):
        ax[0].text(*src, str(i), color="white", bbox=dict(facecolor="black", alpha=0.5))
        ax[1].text(*tgt, str(i), color="black", bbox=dict(facecolor="white", alpha=0.7))

    ax[0].legend(loc="lower right", bbox_to_anchor=(1, -0.075))
    ax[1].legend(loc="lower right", bbox_to_anchor=(1, -0.075))
    fig.tight_layout()
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    # Raw string for the mathtext part: "\i" and "\_" are invalid escapes in a
    # normal string literal (SyntaxWarning on Python 3.12+). The runtime text
    # is unchanged.
    fig.suptitle(
        "CHM Correspondences\n" r"Using $\it{pas\_psi.pt}$ Weights ", fontsize=16
    )
    # Unregister the figure from pyplot so repeated requests in this
    # long-running server don't accumulate open figures. The Figure object
    # itself remains usable by the caller for rendering.
    plt.close(fig)
    return fig
# Wrapper
def generate_correspondences(
    sousrce_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
):
    """Run CHM on an evenly spaced 7x7 grid of source keypoints.

    The grid spans [min_x, max_x] x [min_y, max_y] (endpoints included) in
    model-input pixels, ordered with x varying slowest. Returns the figure
    produced by `run_chm`.

    NOTE: the parameter name `sousrce_image` is misspelled but kept so that
    existing keyword callers keep working.
    """
    grid_side = 7
    xs = np.linspace(min_x, max_x, grid_side)
    ys = np.linspace(min_y, max_y, grid_side)
    # (49, 2) array of (x, y) pairs, transposed to the (2, 49) layout run_chm reads.
    grid = np.array(list(product(xs, ys)), dtype=np.float64)
    return run_chm(
        sousrce_image,
        target_image,
        selected_points=grid.T,
        number_src_points=grid_side * grid_side,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )
with gr.Blocks() as demo:
    # Header: title, short description and links to the paper / code.
    gr.Markdown(
        "\n"
        "# Correspondence Matching with Convolutional Hough Matching Networks\n"
        "Performs keypoint transform from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid.\n"
        "[Original Paper](https://arxiv.org/abs/2103.16831) - [Github Page](https://github.com/juhongm999/chm)\n"
    )

    with gr.Row():
        # Source and target images, delivered to the callback as PIL images.
        image1 = gr.Image(
            shape=(240, 240),
            type="pil",
            label="Source Image",
        )
        image2 = gr.Image(
            shape=(240, 240),
            type="pil",
            label="Target Image",
        )

    with gr.Row():
        # Bounds of the keypoint grid, in pixels of the 240x240 model input.
        # `value=` sets the initial position. The older `default=` keyword is
        # not a Gradio 3.x Slider argument and is ignored, which would leave
        # every slider at its minimum (1).
        min_x = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=15,
            label="Min X",
        )
        max_x = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=215,
            label="Max X",
        )
        min_y = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=15,
            label="Min Y",
        )
        max_y = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=215,
            label="Max Y",
        )

    with gr.Row():
        # Matplotlib figure returned by generate_correspondences.
        output_plot = gr.Plot(
            label="Output Plot",
        )

    gr.Examples(
        [
            ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
            [
                "./examples/Red_Winged_Blackbird_0012_6015.jpg",
                "./examples/Red_Winged_Blackbird_0025_5342.jpg",
                17,
                223,
                17,
                223,
            ],
            [
                "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
                "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
                17,
                223,
                17,
                223,
            ],
        ],
        inputs=[
            image1,
            image2,
            min_x,
            max_x,
            min_y,
            max_y,
        ],
    )

    run_btn = gr.Button("Run")
    run_btn.click(
        generate_correspondences,
        inputs=[image1, image2, min_x, max_x, min_y, max_y],
        outputs=output_plot,
    )

demo.launch(debug=True, enable_queue=False)