|
import csv
import sys
from itertools import product

import gdown
import gradio as gr
import matplotlib
import numpy as np
import torch
import torchvision.transforms as transforms
from matplotlib import pyplot as plt

from common.evaluation import Evaluator
from model import chmnet
from model.base.geometry import Geometry

# Raise the CSV field size limit to the platform maximum so large
# benchmark annotation fields parse without errors.
csv.field_size_limit(sys.maxsize)
|
|
|
|
|
|
|
# Fetch the pretrained CHM weights (pas_psi.pt) from Google Drive and
# verify the download against its md5 checksum.
md5 = "6b7b4d7bad7f89600fac340d6aa7708b"

gdown.cached_download(
    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
    path="pas_psi.pt",
    quiet=False,
    md5=md5,
)
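# cached_download reuses the file at ``path`` on later runs as long as the
# md5 still matches, so the checkpoint is only fetched once.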
|
|
|
|
|
# Demo configuration: PF-PASCAL benchmark settings, 240x240 inputs, the
# psi kernel variant, and the checkpoint downloaded above.
args = {
    "alpha": [0.05, 0.1],
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,
    "ktype": "psi",
    "load": "pas_psi.pt",
    "thres": "img",
}
|
|
|
# Build the model, load the checkpoint on CPU, and switch to eval mode.
model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(torch.load(args["load"], map_location=torch.device("cpu")))
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])
model.eval()
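
# Smoke-test sketch (hypothetical; not run at import time):
#   with torch.no_grad():
#       corr = model(torch.randn(1, 3, 240, 240), torch.randn(1, 3, 240, 240))
# mirrors the call run_chm makes on real image pairs below.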
|
|
|
|
|
|
|
# Network input transform: resize, square center-crop, ImageNet normalization.
chm_transform = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Same geometry without tensor conversion, used only for display.
chm_transform_plot = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
    ]
)
|
|
|
|
|
# One color per keypoint: sample the Spectral colormap at 49 evenly spaced
# positions, one for each cell of the 7x7 source grid.
cmap = matplotlib.cm.get_cmap("Spectral")
colors = [cmap(k / 49.0) for k in range(49)]
|
|
|
|
|
|
|
def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    """Transfer keypoints from the source image to the target image with CHM
    and return a side-by-side matplotlib figure of the correspondences."""
    # Batch of one: the network expects (B, C, H, W) tensors.
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)

    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor([number_src_points])

    # Forward pass: correlation matrix, then keypoint transfer.
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Source keypoints are in model-input coordinates (0..img_size); rescale
    # them to the displayed image size.
    src_points_converted = []
    w, h = display_transform(source_image).size
    for x, y in zip(src_points[0], src_points[1]):
        src_points_converted.append(
            [int(x * w / args["img_size"]), int(y * h / args["img_size"])]
        )
    src_points_converted = np.asarray(src_points_converted[:number_src_points])

    # Predicted keypoints come back normalized to [-1, 1]; map them to pixels.
    tgt_points_converted = []
    w, h = display_transform(target_image).size
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_points_converted.append(
            [int(((x + 1) / 2.0) * w), int(((y + 1) / 2.0) * h)]
        )
    tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])

    # Plot source and target side by side, with matching colors per keypoint.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    ax[0].imshow(display_transform(source_image))
    ax[0].scatter(
        src_points_converted[:, 0],
        src_points_converted[:, 1],
        c=colors[:number_src_points],
    )
    ax[0].set_title("Source")
    ax[0].set_xticks([])
    ax[0].set_yticks([])

    ax[1].imshow(display_transform(target_image))
    ax[1].scatter(
        tgt_points_converted[:, 0],
        tgt_points_converted[:, 1],
        c=colors[:number_src_points],
    )
    ax[1].set_title("Target")
    ax[1].set_xticks([])
    ax[1].set_yticks([])

    # Label each keypoint with its grid index so pairs are easy to match up.
    for tl in range(number_src_points):
        ax[0].text(
            x=src_points_converted[tl][0],
            y=src_points_converted[tl][1],
            s=str(tl),
            fontdict=dict(color="red", size=11),
        )
        ax[1].text(
            x=tgt_points_converted[tl][0],
            y=tgt_points_converted[tl][1],
            s=str(tl),
            fontdict=dict(color="orange", size=11),
        )

    plt.tight_layout()
    fig.suptitle("CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights", fontsize=16)
    return fig
|
|
|
|
|
|
|
def generate_correspondences(
    source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
):
    """Lay a 7x7 grid of keypoints inside the given bounding box and run CHM."""
    A = np.linspace(min_x, max_x, 7)
    B = np.linspace(min_y, max_y, 7)
    point_list = list(product(A, B))
    new_points = np.asarray(point_list, dtype=np.float64).T  # shape (2, 49)
    return run_chm(
        source_image,
        target_image,
        selected_points=new_points,
        number_src_points=49,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )
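
# Usage sketch (paths taken from the examples below; not executed here):
#   from PIL import Image
#   src = Image.open("./examples/sample1.jpeg").convert("RGB")
#   tgt = Image.open("./examples/sample2.jpeg").convert("RGB")
#   fig = generate_correspondences(src, tgt, min_x=17, max_x=223, min_y=17, max_y=223)
#   fig.savefig("correspondences.png")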
|
|
|
|
|
|
|
# Gradio interface: two images plus sliders bounding the source keypoint grid.
main = gr.Interface(
    fn=generate_correspondences,
    inputs=[
        gr.Image(shape=(240, 240), type="pil"),
        gr.Image(shape=(240, 240), type="pil"),
        gr.Slider(minimum=1, maximum=240, step=1, value=15, label="Min X"),
        gr.Slider(minimum=1, maximum=240, step=1, value=215, label="Max X"),
        gr.Slider(minimum=1, maximum=240, step=1, value=15, label="Min Y"),
        gr.Slider(minimum=1, maximum=240, step=1, value=215, label="Max Y"),
    ],
    outputs="plot",
    allow_flagging="never",
    examples=[
        ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
        [
            "./examples/Red_Winged_Blackbird_0012_6015.jpg",
            "./examples/Red_Winged_Blackbird_0025_5342.jpg",
            17,
            223,
            17,
            223,
        ],
        [
            "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
            "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
            17,
            223,
            17,
            223,
        ],
    ],
)
|
|
|
|
|
blocks = gr.Blocks()

with blocks:
    gr.Markdown(
        """
# Correspondence Matching with Convolutional Hough Matching Networks

Performs keypoint transfer from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid.

[Original Paper](https://arxiv.org/abs/2103.16831) - [GitHub Page](https://github.com/juhongm999/chm)
"""
    )

    gr.TabbedInterface([main], ["Main"])
|
|
|
|
|
blocks.launch(
    debug=True,
    enable_queue=False,
)
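
# debug=True keeps launch() blocking and prints tracebacks to the console;
# enable_queue=False serves each request directly instead of queueing it.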
|
|