taesiri's picture
update
f3cff0c
raw history blame
No virus
7.03 kB
import os
import random
from itertools import product
import gdown
import gradio as gr
import matplotlib
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch
from PIL import Image
from torch.utils.data import DataLoader
from common.evaluation import Evaluator
from common.logger import AverageMeter, Logger
from data import download
from model import chmnet
from model.base.geometry import Geometry
# Downloading the Model
# `gdown.cached_download` is expected to reuse an existing "pas_psi.pt" whose
# MD5 matches `md5` and to download it otherwise (confirm against the
# installed gdown version). It needs a real URL: a bare Drive file ID passed
# as `url` is treated as a plain URL and the request fails.
md5 = "6b7b4d7bad7f89600fac340d6aa7708b"
gdown.cached_download(
    url="https://drive.google.com/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6",
    path="pas_psi.pt",
    quiet=False,
    md5=md5,
)
# Model Initialization
# Inference options. In the code visible here only "alpha", "img_size",
# "ktype" and "load" are read; the other keys look like leftovers from the
# CHM evaluation options (TODO confirm whether anything else reads them).
args = {
    "alpha": [0.05, 0.1],
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,
    "ktype": "psi",
    "load": "pas_psi.pt",
    "thres": "img",
}

# Build the network, load the checkpoint onto the CPU and switch to
# inference mode.
model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(
    torch.load(args["load"], map_location=torch.device("cpu"))
)
model.eval()

# Class-level configuration of the evaluation / geometry helpers.
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])
# Transforms
# Display pipeline: the same resize + square centre crop that the model input
# gets, but without tensor conversion, so the result can be plotted directly.
chm_transform_plot = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"],) * 2),
    ]
)
# Model-input pipeline: resize + square centre crop, convert to a tensor and
# normalise with the standard ImageNet per-channel mean / std values.
chm_transform = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"],) * 2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# A Helper Function
def to_np(x):
    """Return tensor *x* as a NumPy array on the CPU, detached from autograd.

    A named function rather than a lambda bound to a name (PEP 8 E731), so it
    has a proper ``__name__`` in tracebacks and can carry a docstring.
    """
    return x.detach().cpu().numpy()
# Colors for Plotting
# One colour per keypoint of the demo's 7x7 = 49-point grid, sampled evenly
# from the "Spectral" colormap.
try:
    cmap = matplotlib.colormaps["Spectral"]
except AttributeError:
    # matplotlib < 3.5 has no `matplotlib.colormaps` registry; `cm.get_cmap`
    # is only used as a fallback because it was removed in matplotlib 3.9.
    cmap = matplotlib.cm.get_cmap("Spectral")
rgba = cmap(0.5)  # not used elsewhere in this file; kept as a module attribute
colors = [cmap(k / 49.0) for k in range(49)]
# CHM MODEL
def _stack_int_points(xs, ys, limit):
    """Pair x and y coordinate arrays into an (M, 2) int array of the first *limit* points.

    Float coordinates are truncated toward zero, matching ``int(...)``.
    """
    points = np.stack([np.asarray(xs), np.asarray(ys)], axis=1)[:limit]
    return points.astype(int)


def _plot_keypoints(ax, image, points, title, label_color):
    """Show *image* on *ax*, scatter *points* ((M, 2) pixel coords) and number each one."""
    ax.imshow(image)
    # Colour i is used for keypoint i; wrap around if there are more points
    # than palette entries instead of failing.
    ax.scatter(
        points[:, 0],
        points[:, 1],
        c=[colors[i % len(colors)] for i in range(len(points))],
    )
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    # Label only the points that exist (previously hard-coded to 49, which
    # raised IndexError for smaller point sets).
    for idx, (x, y) in enumerate(points):
        ax.text(x=x, y=y, s=str(idx), fontdict=dict(color=label_color, size=11))


def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    """Transfer keypoints from the source image to the target image and plot both.

    Runs the module-level CHM ``model`` on the image pair, maps the source
    keypoints onto the target with ``Geometry.transfer_kps`` and draws the two
    images side by side with every keypoint numbered.

    Args:
        source_image: Source image. It must be accepted by both transforms, and
            its displayed version must expose ``.size`` as ``(width, height)``
            (the Gradio inputs supply PIL images).
        target_image: Target image, with the same requirements as ``source_image``.
        selected_points: Source keypoints as a ``(2, N)`` array: row 0 holds x
            and row 1 holds y, in pixels of the ``args["img_size"]`` square
            model input.
        number_src_points: Number of leading keypoints to transfer and plot.
        chm_transform: Transform producing the model's input tensor.
        display_transform: Transform producing the image that is plotted.

    Returns:
        A matplotlib ``Figure`` with a "Source" and a "Target" panel.
    """
    img_size = args["img_size"]

    # Convert to tensors with a leading batch dimension of 1.
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # RUN CHM
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Compute each displayed image once and reuse it for sizing and plotting.
    src_display = display_transform(source_image)
    tgt_display = display_transform(target_image)

    # Source keypoints are in model-input pixels: rescale to the displayed size.
    w, h = src_display.size
    src_points_converted = _stack_int_points(
        src_points[0] * w / img_size,
        src_points[1] * h / img_size,
        number_src_points,
    )
    # Predicted keypoints are treated as normalised to [-1, 1] (the same
    # assumption as before): map them to displayed-image pixels. The
    # prediction may hold more columns than requested; only the first
    # `number_src_points` are kept.
    w, h = tgt_display.size
    tgt_points_converted = _stack_int_points(
        (tgt_points[0] + 1) / 2.0 * w,
        (tgt_points[1] + 1) / 2.0 * h,
        number_src_points,
    )

    # PLOT
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
    _plot_keypoints(ax[0], src_display, src_points_converted, "Source", "red")
    _plot_keypoints(ax[1], tgt_display, tgt_points_converted, "Target", "orange")
    # Raw string for the mathtext part avoids invalid "\i" / "\_" escapes
    # (SyntaxWarning on Python 3.12+); the resulting text is unchanged.
    fig.suptitle(
        "CHM Correspondences\n" r"Using $\it{pas\_psi.pt}$ Weights ",
        fontsize=16,
    )
    # Lay out after the suptitle exists so it is taken into account.
    fig.tight_layout()
    # Detach the figure from pyplot's global figure registry; otherwise every
    # request leaks one open figure. The returned Figure can still be rendered
    # with savefig by the caller.
    plt.close(fig)
    return fig
# Wrapper
def generate_correspondences(
    sousrce_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
):
    """Build a 7x7 grid of source keypoints and transfer it with ``run_chm``.

    Args:
        sousrce_image: The source image. The misspelled name is kept so
            existing keyword callers keep working.
        target_image: The target image.
        min_x, max_x: Horizontal extent of the grid (inclusive).
        min_y, max_y: Vertical extent of the grid (inclusive).

    Returns:
        Whatever ``run_chm`` returns for the 49 grid points.
    """
    xs = np.linspace(min_x, max_x, 7)
    ys = np.linspace(min_y, max_y, 7)
    # x-major ordering: for each x, every y. The array is transposed into the
    # (2, 49) x-row / y-row layout that run_chm expects.
    grid = np.array([(x, y) for x in xs for y in ys], dtype=np.float64)
    return run_chm(
        sousrce_image,
        target_image,
        selected_points=grid.T,
        number_src_points=49,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )
# Gradio App
# Slider starting values go through `value=`: gradio 3.x components take
# their initial value from `value`, and the old `default=` keyword is not
# honoured (the sliders would start at their minimum). Sizes follow
# args["img_size"] so the UI matches the model input resolution.
main = gr.Interface(
    fn=generate_correspondences,
    inputs=[
        gr.Image(shape=(args["img_size"], args["img_size"]), type="pil"),
        gr.Image(shape=(args["img_size"], args["img_size"]), type="pil"),
        gr.Slider(minimum=1, maximum=args["img_size"], step=1, value=15, label="Min X"),
        gr.Slider(minimum=1, maximum=args["img_size"], step=1, value=215, label="Max X"),
        gr.Slider(minimum=1, maximum=args["img_size"], step=1, value=15, label="Min Y"),
        gr.Slider(minimum=1, maximum=args["img_size"], step=1, value=215, label="Max Y"),
    ],
    allow_flagging="never",
    outputs="plot",
    examples=[
        ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
        [
            "./examples/Red_Winged_Blackbird_0012_6015.jpg",
            "./examples/Red_Winged_Blackbird_0025_5342.jpg",
            17,
            223,
            17,
            223,
        ],
        [
            "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
            "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
            17,
            223,
            17,
            223,
        ],
    ],
)
# Page layout: a Markdown header above the Interface, which is rendered as a
# single "Main" tab; then start the server.
blocks = gr.Blocks()
with blocks:
    # The Markdown text is kept at column 0 so its lines are not indented
    # (indented Markdown lines would render as a code block).
    gr.Markdown(
        """
# Correspondence Matching with Convolutional Hough Matching Networks
Transfers keypoints from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid.
[Original Paper](https://arxiv.org/abs/2103.16831) - [Github Page](https://github.com/juhongm999/chm)
"""
    )
    gr.TabbedInterface([main], ["Main"])
blocks.launch(
    debug=True,
    enable_queue=False,
)