Spaces:
Runtime error
Runtime error
ubuntu
commited on
Commit
·
32603e9
1
Parent(s):
ad86786
Initial Commit
Browse files- app.py +164 -0
- pygm_rrwm.py +179 -0
- requirements.txt +6 -0
- src/pygm_default.png +0 -0
- src/pygm_image_1.png +0 -0
- src/pygm_image_2.png +0 -0
app.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import shutil
|
3 |
+
import gradio as gr
|
4 |
+
from pygm_rrwm import pygm_rrwm
|
5 |
+
|
6 |
+
|
7 |
+
PYGM_IMG_DEFAULT_PATH = "src/pygm_default.png"
|
8 |
+
PYGM_SOLUTION_1_PATH = "src/pygm_image_1.png"
|
9 |
+
PYGM_SOLUTION_2_PATH = "src/pygm_image_2.png"
|
10 |
+
|
11 |
+
|
12 |
+
def _handle_pygm_solve(
|
13 |
+
img_1_path: str,
|
14 |
+
img_2_path: str,
|
15 |
+
kpts1_path: str,
|
16 |
+
kpts2_path: str,
|
17 |
+
):
|
18 |
+
if img_1_path is None:
|
19 |
+
raise gr.Error("Please upload file completely!")
|
20 |
+
if img_2_path is None:
|
21 |
+
raise gr.Error("Please upload file completely!")
|
22 |
+
if kpts1_path is None:
|
23 |
+
raise gr.Error("Please upload file completely!")
|
24 |
+
if kpts1_path is None:
|
25 |
+
raise gr.Error("Please upload file completely!")
|
26 |
+
|
27 |
+
start_time = time.time()
|
28 |
+
pygm_rrwm(
|
29 |
+
img1_path=img_1_path,
|
30 |
+
img2_path=img_2_path,
|
31 |
+
kpts1_path=kpts1_path,
|
32 |
+
kpts2_path=kpts2_path,
|
33 |
+
output_path="src",
|
34 |
+
filename="pygm_image"
|
35 |
+
)
|
36 |
+
solved_time = time.time() - start_time
|
37 |
+
|
38 |
+
message = "Successfully solve the TSP problem, using time ({:.3f}s).".format(solved_time)
|
39 |
+
|
40 |
+
return message, PYGM_SOLUTION_1_PATH, PYGM_SOLUTION_2_PATH
|
41 |
+
|
42 |
+
|
43 |
+
def handle_pygm_solve(
|
44 |
+
img_1_path: str,
|
45 |
+
img_2_path: str,
|
46 |
+
kpts1_path: str,
|
47 |
+
kpts2_path: str,
|
48 |
+
):
|
49 |
+
try:
|
50 |
+
message = _handle_pygm_solve(
|
51 |
+
img_1_path=img_1_path,
|
52 |
+
img_2_path=img_2_path,
|
53 |
+
kpts1_path=kpts1_path,
|
54 |
+
kpts2_path=kpts2_path,
|
55 |
+
)
|
56 |
+
return message
|
57 |
+
except Exception as e:
|
58 |
+
message = str(e)
|
59 |
+
return message, PYGM_SOLUTION_1_PATH, PYGM_SOLUTION_2_PATH
|
60 |
+
|
61 |
+
|
62 |
+
def handle_pygm_clear():
|
63 |
+
shutil.copy(
|
64 |
+
src=PYGM_IMG_DEFAULT_PATH,
|
65 |
+
dst=PYGM_SOLUTION_1_PATH
|
66 |
+
)
|
67 |
+
shutil.copy(
|
68 |
+
src=PYGM_IMG_DEFAULT_PATH,
|
69 |
+
dst=PYGM_SOLUTION_2_PATH
|
70 |
+
)
|
71 |
+
|
72 |
+
message = "successfully clear the files!"
|
73 |
+
return message, PYGM_SOLUTION_1_PATH, PYGM_SOLUTION_2_PATH
|
74 |
+
|
75 |
+
|
76 |
+
def convert_image_path_to_bytes(image_path):
|
77 |
+
with open(image_path, "rb") as f:
|
78 |
+
image_bytes = f.read()
|
79 |
+
return image_bytes
|
80 |
+
|
81 |
+
|
82 |
+
with gr.Blocks() as pygm_page:
|
83 |
+
|
84 |
+
gr.Markdown(
|
85 |
+
'''
|
86 |
+
This space displays the solution to the Graph Matching problem.
|
87 |
+
## How to use this Space?
|
88 |
+
- Upload a '.pygm' file from pygmlib .
|
89 |
+
- The images of the TSP problem and solution will be shown after you click the solve button.
|
90 |
+
- Click the 'clear' button to clear all the files.
|
91 |
+
'''
|
92 |
+
)
|
93 |
+
|
94 |
+
with gr.Row(variant="panel"):
|
95 |
+
with gr.Column(scale=2):
|
96 |
+
with gr.Row():
|
97 |
+
pygm_img_1 = gr.File(
|
98 |
+
label="Upload .png File",
|
99 |
+
file_types=[".png"],
|
100 |
+
min_width=40,
|
101 |
+
)
|
102 |
+
pygm_img_2 = gr.File(
|
103 |
+
label="Upload .png File",
|
104 |
+
file_types=[".png"],
|
105 |
+
min_width=40,
|
106 |
+
)
|
107 |
+
with gr.Row():
|
108 |
+
pygm_kpts_1 = gr.File(
|
109 |
+
label="Upload .mat File",
|
110 |
+
file_types=[".mat"],
|
111 |
+
min_width=40,
|
112 |
+
)
|
113 |
+
pygm_kpts_2 = gr.File(
|
114 |
+
label="Upload .mat File",
|
115 |
+
file_types=[".mat"],
|
116 |
+
min_width=40,
|
117 |
+
)
|
118 |
+
info = gr.Textbox(
|
119 |
+
value="",
|
120 |
+
label="Log",
|
121 |
+
scale=4,
|
122 |
+
)
|
123 |
+
with gr.Column(scale=2):
|
124 |
+
pygm_solution_1 = gr.Image(
|
125 |
+
value=PYGM_SOLUTION_1_PATH,
|
126 |
+
type="filepath",
|
127 |
+
label="Original Images"
|
128 |
+
)
|
129 |
+
pygm_solution_2 = gr.Image(
|
130 |
+
value=PYGM_SOLUTION_2_PATH,
|
131 |
+
type="filepath",
|
132 |
+
label="Graph Matching Results"
|
133 |
+
)
|
134 |
+
with gr.Row():
|
135 |
+
with gr.Column(scale=1, min_width=100):
|
136 |
+
solve_button = gr.Button(
|
137 |
+
value="Solve",
|
138 |
+
variant="primary",
|
139 |
+
scale=1
|
140 |
+
)
|
141 |
+
with gr.Column(scale=1, min_width=100):
|
142 |
+
clear_button = gr.Button(
|
143 |
+
"Clear",
|
144 |
+
variant="secondary",
|
145 |
+
scale=1
|
146 |
+
)
|
147 |
+
with gr.Column(scale=8):
|
148 |
+
pass
|
149 |
+
|
150 |
+
solve_button.click(
|
151 |
+
handle_pygm_solve,
|
152 |
+
[pygm_img_1, pygm_img_2, pygm_kpts_1, pygm_kpts_2],
|
153 |
+
outputs=[info, pygm_solution_1, pygm_solution_2]
|
154 |
+
)
|
155 |
+
|
156 |
+
clear_button.click(
|
157 |
+
handle_pygm_clear,
|
158 |
+
inputs=None,
|
159 |
+
outputs=[info, pygm_solution_1, pygm_solution_2]
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
if __name__ == "__main__":
|
164 |
+
pygm_page.launch(debug = True)
|
pygm_rrwm.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch # pytorch backend
|
3 |
+
import torchvision # CV models
|
4 |
+
import pygmtools as pygm
|
5 |
+
import matplotlib.pyplot as plt # for plotting
|
6 |
+
from matplotlib.patches import ConnectionPatch # for plotting matching result
|
7 |
+
import scipy.io as sio # for loading .mat file
|
8 |
+
import scipy.spatial as spa # for Delaunay triangulation
|
9 |
+
from sklearn.decomposition import PCA as PCAdimReduc
|
10 |
+
import itertools
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
pygm.set_backend('pytorch') # set default backend for pygmtools
|
14 |
+
|
15 |
+
|
16 |
+
##################################################################
|
17 |
+
# Utils Func #
|
18 |
+
##################################################################
|
19 |
+
|
20 |
+
def plot_image_with_graph(img, kpt, A=None):
|
21 |
+
plt.imshow(img)
|
22 |
+
plt.scatter(kpt[0], kpt[1], c='w', edgecolors='k')
|
23 |
+
if A is not None:
|
24 |
+
for idx in torch.nonzero(A, as_tuple=False):
|
25 |
+
plt.plot((kpt[0, idx[0]], kpt[0, idx[1]]), (kpt[1, idx[0]], kpt[1, idx[1]]), 'k-')
|
26 |
+
|
27 |
+
|
28 |
+
def delaunay_triangulation(kpt):
|
29 |
+
d = spa.Delaunay(kpt.numpy().transpose())
|
30 |
+
A = torch.zeros(len(kpt[0]), len(kpt[0]))
|
31 |
+
for simplex in d.simplices:
|
32 |
+
for pair in itertools.permutations(simplex, 2):
|
33 |
+
A[pair] = 1
|
34 |
+
return A
|
35 |
+
|
36 |
+
|
37 |
+
def plot_image_with_graphs(img1, img2, kpts1, kpts2, A1=None, A2=None,
|
38 |
+
title_1: str="Image 1", title_2: str="Image 2", filename="examples.png"):
|
39 |
+
plt.figure(figsize=(8, 4))
|
40 |
+
plt.subplot(1, 2, 1)
|
41 |
+
plt.title(title_1)
|
42 |
+
plot_image_with_graph(img1, kpts1, A1)
|
43 |
+
plt.subplot(1, 2, 2)
|
44 |
+
plt.title(title_2)
|
45 |
+
plot_image_with_graph(img2, kpts2, A2)
|
46 |
+
plt.savefig(filename)
|
47 |
+
|
48 |
+
|
49 |
+
def load_images(
|
50 |
+
img1_path: str,
|
51 |
+
img2_path: str,
|
52 |
+
kpts1_path: str,
|
53 |
+
kpts2_path: str,
|
54 |
+
obj_resize: tuple=(256, 256)
|
55 |
+
):
|
56 |
+
img1 = Image.open(img1_path)
|
57 |
+
img2 = Image.open(img2_path)
|
58 |
+
kpts1 = torch.tensor(sio.loadmat(kpts1_path)['pts_coord'])
|
59 |
+
kpts2 = torch.tensor(sio.loadmat(kpts2_path)['pts_coord'])
|
60 |
+
kpts1[0] = kpts1[0] * obj_resize[0] / img1.size[0]
|
61 |
+
kpts1[1] = kpts1[1] * obj_resize[1] / img1.size[1]
|
62 |
+
kpts2[0] = kpts2[0] * obj_resize[0] / img2.size[0]
|
63 |
+
kpts2[1] = kpts2[1] * obj_resize[1] / img2.size[1]
|
64 |
+
img1 = img1.resize(obj_resize, resample=Image.Resampling.BILINEAR)
|
65 |
+
img2 = img2.resize(obj_resize, resample=Image.Resampling.BILINEAR)
|
66 |
+
return img1, img2, kpts1, kpts2
|
67 |
+
|
68 |
+
|
69 |
+
##################################################################
|
70 |
+
# Process #
|
71 |
+
##################################################################
|
72 |
+
|
73 |
+
def pygm_rrwm(
|
74 |
+
img1_path: str,
|
75 |
+
img2_path: str,
|
76 |
+
kpts1_path: str,
|
77 |
+
kpts2_path: str,
|
78 |
+
obj_resize: tuple=(256, 256),
|
79 |
+
output_path: str="examples",
|
80 |
+
filename: str="example"
|
81 |
+
):
|
82 |
+
if not os.path.exists(output_path):
|
83 |
+
os.mkdir(output_path)
|
84 |
+
output_filename = os.path.join(output_path, filename) + "_{}.png"
|
85 |
+
|
86 |
+
# Load the images
|
87 |
+
img1, img2, kpts1, kpts2 = load_images(img1_path, img2_path, kpts1_path, kpts2_path, obj_resize)
|
88 |
+
plot_image_with_graphs(img1, img2, kpts1, kpts2, filename=output_filename.format(1))
|
89 |
+
|
90 |
+
# Build the graphs
|
91 |
+
A1 = delaunay_triangulation(kpts1)
|
92 |
+
A2 = delaunay_triangulation(kpts2)
|
93 |
+
A1 = ((kpts1.unsqueeze(1) - kpts1.unsqueeze(2)) ** 2).sum(dim=0) * A1
|
94 |
+
A1 = (A1 / A1.max()).to(dtype=torch.float32)
|
95 |
+
A2 = ((kpts2.unsqueeze(1) - kpts2.unsqueeze(2)) ** 2).sum(dim=0) * A2
|
96 |
+
A2 = (A2 / A2.max()).to(dtype=torch.float32)
|
97 |
+
# plot_image_with_graphs(img1, img2, kpts1, kpts2, A1, A2,
|
98 |
+
# "Image 1 with Graphs", "Image 2 with Graphs", output_filename.format(2))
|
99 |
+
|
100 |
+
# Extract node features
|
101 |
+
vgg16_cnn = torchvision.models.vgg16_bn(True)
|
102 |
+
torch_img1 = torch.from_numpy(np.array(img1, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
|
103 |
+
torch_img2 = torch.from_numpy(np.array(img2, dtype=np.float32) / 256).permute(2, 0, 1).unsqueeze(0) # shape: BxCxHxW
|
104 |
+
with torch.set_grad_enabled(False):
|
105 |
+
feat1 = vgg16_cnn.features(torch_img1)
|
106 |
+
feat2 = vgg16_cnn.features(torch_img2)
|
107 |
+
|
108 |
+
# Normalize the features
|
109 |
+
num_features = feat1.shape[1]
|
110 |
+
def l2norm(node_feat):
|
111 |
+
return torch.nn.functional.local_response_norm(
|
112 |
+
node_feat, node_feat.shape[1] * 2, alpha=node_feat.shape[1] * 2, beta=0.5, k=0)
|
113 |
+
feat1 = l2norm(feat1)
|
114 |
+
feat2 = l2norm(feat2)
|
115 |
+
|
116 |
+
# Up-sample the features to the original image size
|
117 |
+
feat1_upsample = torch.nn.functional.interpolate(feat1, (obj_resize[1], obj_resize[0]), mode='bilinear')
|
118 |
+
feat2_upsample = torch.nn.functional.interpolate(feat2, (obj_resize[1], obj_resize[0]), mode='bilinear')
|
119 |
+
|
120 |
+
# Visualize the extracted CNN feature (dimensionality reduction via principle component analysis)
|
121 |
+
pca_dim_reduc = PCAdimReduc(n_components=3, whiten=True)
|
122 |
+
feat_dim_reduc = pca_dim_reduc.fit_transform(
|
123 |
+
np.concatenate((
|
124 |
+
feat1_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy(),
|
125 |
+
feat2_upsample.permute(0, 2, 3, 1).reshape(-1, num_features).numpy()
|
126 |
+
), axis=0)
|
127 |
+
)
|
128 |
+
feat_dim_reduc = feat_dim_reduc / np.max(np.abs(feat_dim_reduc), axis=0, keepdims=True) / 2 + 0.5
|
129 |
+
feat1_dim_reduc = feat_dim_reduc[:obj_resize[0] * obj_resize[1], :]
|
130 |
+
feat2_dim_reduc = feat_dim_reduc[obj_resize[0] * obj_resize[1]:, :]
|
131 |
+
|
132 |
+
# Plot
|
133 |
+
# plt.figure(figsize=(8, 4))
|
134 |
+
# plt.subplot(1, 2, 1)
|
135 |
+
# plt.title('Image 1 with CNN features')
|
136 |
+
# plot_image_with_graph(img1, kpts1, A1)
|
137 |
+
# plt.imshow(feat1_dim_reduc.reshape(obj_resize[1], obj_resize[0], 3), alpha=0.5)
|
138 |
+
# plt.subplot(1, 2, 2)
|
139 |
+
# plt.title('Image 2 with CNN features')
|
140 |
+
# plot_image_with_graph(img2, kpts2, A2)
|
141 |
+
# plt.imshow(feat2_dim_reduc.reshape(obj_resize[1], obj_resize[0], 3), alpha=0.5)
|
142 |
+
# plt.savefig(output_filename.format(3))
|
143 |
+
|
144 |
+
# Extract node features by nearest interpolation
|
145 |
+
rounded_kpts1 = torch.round(kpts1).to(dtype=torch.long)
|
146 |
+
rounded_kpts2 = torch.round(kpts2).to(dtype=torch.long)
|
147 |
+
node1 = feat1_upsample[0, :, rounded_kpts1[1], rounded_kpts1[0]].t() # shape: NxC
|
148 |
+
node2 = feat2_upsample[0, :, rounded_kpts2[1], rounded_kpts2[0]].t() # shape: NxC
|
149 |
+
|
150 |
+
# Build affinity matrix
|
151 |
+
conn1, edge1 = pygm.utils.dense_to_sparse(A1)
|
152 |
+
conn2, edge2 = pygm.utils.dense_to_sparse(A2)
|
153 |
+
import functools
|
154 |
+
gaussian_aff = functools.partial(pygm.utils.gaussian_aff_fn, sigma=1) # set affinity function
|
155 |
+
K = pygm.utils.build_aff_mat(node1, edge1, conn1, node2, edge2, conn2, edge_aff_fn=gaussian_aff)
|
156 |
+
|
157 |
+
# Plot affinity matrix
|
158 |
+
# plt.figure(figsize=(4, 4))
|
159 |
+
# plt.title(f'Affinity Matrix (size: {K.shape[0]}$\\times${K.shape[1]})')
|
160 |
+
# plt.imshow(K.numpy(), cmap='Blues')
|
161 |
+
# plt.savefig(output_filename.format(4))
|
162 |
+
|
163 |
+
# Solve graph matching problem by RRWM solver
|
164 |
+
X = pygm.rrwm(K, kpts1.shape[1], kpts2.shape[1])
|
165 |
+
X = pygm.hungarian(X)
|
166 |
+
|
167 |
+
# Plot the matching
|
168 |
+
plt.figure(figsize=(8, 4))
|
169 |
+
plt.suptitle('Image Matching Result by RRWM')
|
170 |
+
ax1 = plt.subplot(1, 2, 1)
|
171 |
+
plot_image_with_graph(img1, kpts1, A1)
|
172 |
+
ax2 = plt.subplot(1, 2, 2)
|
173 |
+
plot_image_with_graph(img2, kpts2, A2)
|
174 |
+
for i in range(X.shape[0]):
|
175 |
+
j = torch.argmax(X[i]).item()
|
176 |
+
con = ConnectionPatch(xyA=kpts1[:, i], xyB=kpts2[:, j], coordsA="data", coordsB="data",
|
177 |
+
axesA=ax1, axesB=ax2, color="red" if i != j else "green")
|
178 |
+
plt.gca().add_artist(con)
|
179 |
+
plt.savefig(output_filename.format(2))
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
pygmtools
|
3 |
+
matplotlib
|
4 |
+
torch==2.0.0
|
5 |
+
torchvision==0.15.1
|
6 |
+
scikit-learn
|
src/pygm_default.png
ADDED
![]() |
src/pygm_image_1.png
ADDED
![]() |
src/pygm_image_2.png
ADDED
![]() |