initial commit
Files changed:
- app.py               +177 -0
- pre-requirements.txt   +2 -0
- standard.npy           +3 -0
app.py (ADDED)
@@ -0,0 +1,177 @@
import os

# gradio for the visual demo
import gradio as gr

# optional manual installs, normally handled by requirements.txt on Spaces
#os.system("pip install scipy")
#os.system("pip install torch")
#os.system("pip install scikit-learn")
#os.system("pip install torchvision")

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import ImageDraw, Image
from typing import Tuple
from scipy.ndimage import binary_closing, binary_opening
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from random import randint

### Models
# Precomputed direction in token space used to separate foreground patches
standard_array = np.load('standard.npy')
pca_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
pca_model.eval()

### Parameters
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

smaller_edge_size = 448
interpolation_mode = transforms.InterpolationMode.BICUBIC
patch_size = pca_model.patch_size
background_threshold = 0.05

apply_opening = True   # remove isolated foreground patches from the mask
apply_closing = True   # fill small holes in the mask
device = 'cpu'

def make_transform() -> transforms.Compose:
    return transforms.Compose([
        transforms.Resize(size=smaller_edge_size, interpolation=interpolation_mode, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

def prepare_image(image: Image.Image) -> Tuple[torch.Tensor, Tuple[int, int], float]:
    transform = make_transform()
    image_tensor = transform(image)
    resize_scale = image.width / image_tensor.shape[2]

    # Crop the tensor to dimensions that are a multiple of the patch size
    height, width = image_tensor.shape[1:]  # C x H x W
    cropped_width, cropped_height = width - width % patch_size, height - height % patch_size
    image_tensor = image_tensor[:, :cropped_height, :cropped_width]

    grid_size = (cropped_height // patch_size, cropped_width // patch_size)  # (rows, cols) in patch units
    return image_tensor, grid_size, resize_scale

def make_foreground_mask(tokens,
                         grid_size: Tuple[int, int]):
    # A patch is foreground when its token's projection onto the
    # precomputed direction exceeds the threshold
    projection = tokens @ standard_array
    mask = projection >= background_threshold
    mask = mask.reshape(*grid_size)
    if apply_opening:
        mask = binary_opening(mask)
    if apply_closing:
        mask = binary_closing(mask)
    return mask.flatten()

def render_patch_pca(image: Image.Image, mask, filter_background, tokens, grid_size) -> Image.Image:
    # Visualize patch tokens by mapping their first three principal components to RGB
    pca = PCA(n_components=3)
    if filter_background:
        pca.fit(tokens[mask])
    else:
        pca.fit(tokens)
    projected_tokens = pca.transform(tokens)

    t = torch.tensor(projected_tokens)
    t_min = t.min(dim=0, keepdim=True).values
    t_max = t.max(dim=0, keepdim=True).values
    normalized_t = (t - t_min) / (t_max - t_min)

    array = (normalized_t * 255).byte().numpy()
    if filter_background:
        array[~mask] = 0
    array = array.reshape(*grid_size, 3)

    # Nearest-neighbour upsampling keeps the patch boundaries sharp
    return Image.fromarray(array).resize((image.width, image.height), Image.NEAREST)

def extract_features(img, image_tensor, filter_background, grid_size):
    with torch.inference_mode():
        image_batch = image_tensor.unsqueeze(0).to(device)
        # Patch tokens of the last layer, converted once to a (n_patches, dim) numpy array
        tokens = pca_model.get_intermediate_layers(image_batch)[0].squeeze().cpu().numpy()
    mask = make_foreground_mask(tokens, grid_size)
    img_pca = render_patch_pca(img, mask, filter_background, tokens, grid_size)
    return tokens, mask, img_pca

def compute_features(img, filter_background):
    image_tensor, grid_size, resize_scale = prepare_image(img)
    features, mask, img_pca = extract_features(img, image_tensor, filter_background, grid_size)
    return features, mask, grid_size, resize_scale, img_pca

def idx_to_source_position(idx, grid_size, resize_scale):
    # Center of patch `idx` in source-image pixel coordinates
    row = (idx // grid_size[1]) * patch_size * resize_scale + patch_size / 2
    col = (idx % grid_size[1]) * patch_size * resize_scale + patch_size / 2
    return row, col

def compute_nn(features1, features2):
    # For each row of features2, index of its nearest neighbour in features1
    knn = NearestNeighbors(n_neighbors=1)
    knn.fit(features1)
    distances, match2to1 = knn.kneighbors(features2)
    return distances.flatten(), match2to1.flatten()

def compute_matches(img1, img2, lr_check, filter_background, display_matches_threshold):
    # compute features
    features1, mask1, grid_size1, resize_scale1, img_pca1 = compute_features(img1, filter_background)
    features2, mask2, grid_size2, resize_scale2, img_pca2 = compute_features(img2, filter_background)

    # match features in both directions
    distances2to1, match2to1 = compute_nn(features1, features2)
    distances1to2, match1to2 = compute_nn(features2, features1)

    # resize the taller image so both share the same height
    if img1.size[1] > img2.size[1]:
        img1 = img1.resize(img2.size)
        resize_scale1 = resize_scale2
    else:
        img2 = img2.resize(img1.size)
        resize_scale2 = resize_scale1

    # create the Draw objects only after resizing; resize() returns a new
    # image, so Draw objects created earlier would paint onto discarded copies
    draw1 = ImageDraw.Draw(img1)
    draw2 = ImageDraw.Draw(img2)

    # display matches side by side
    offset = img1.size[0]
    merged_image = Image.new('RGB', (offset + img2.size[0], max(img1.size[1], img2.size[1])), (250, 250, 250))
    merged_image.paste(img1, (0, 0))
    merged_image.paste(img2, (offset, 0))
    draw = ImageDraw.Draw(merged_image)

    for idx2, idx1 in enumerate(match2to1):
        # optionally keep only mutual (symmetric) nearest-neighbour matches
        if lr_check and match1to2[idx1] != idx2:
            continue
        row1, col1 = idx_to_source_position(idx1, grid_size1, resize_scale1)
        row2, col2 = idx_to_source_position(idx2, grid_size2, resize_scale2)

        if filter_background and not mask1[idx1]:
            continue
        if filter_background and not mask2[idx2]:
            continue

        r = randint(0, 255)
        g = randint(0, 255)
        color = (r, g, 255 - r)

        draw1.point((col1, row1), color)
        draw2.point((col2, row2), color)

        # draw connecting lines for only a random subset of matches to keep the plot readable
        if 100 * np.random.rand() > display_matches_threshold:
            continue
        draw.line((col1, row1, col2 + offset, row2), fill=color)

    return [[img1, img_pca1], [img2, img_pca2], merged_image]

iface = gr.Interface(fn=compute_matches,
                     inputs=[
                         gr.Image(type="pil"),
                         gr.Image(type="pil"),
                         gr.Checkbox(label="Keep only symmetric matches"),
                         gr.Checkbox(label="Mask background"),
                         gr.Slider(0, 100, step=5, value=5, label="Display matches ratio", info="Choose between 0 and 100%"),
                     ],
                     outputs=[
                         gr.Gallery(
                             label="Image 1", show_label=False, elem_id="gallery1",
                             columns=[2], rows=[1], object_fit="contain", height="auto"),
                         gr.Gallery(
                             label="Image 2", show_label=False, elem_id="gallery2",
                             columns=[2], rows=[1], object_fit="contain", height="auto"),
                         gr.Image(type="pil")
                     ])
iface.launch(debug=True)
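For a quick check outside the web UI, the matching pipeline can be driven directly with PIL images. A minimal sketch, assuming two placeholder files cat1.jpg and cat2.jpg and that app.py's functions are in scope (importing app.py as written would also launch the interface, so in practice you would guard the iface.launch call or paste this at the bottom of the script):

# Offline usage sketch (hypothetical file names, not part of the commit)
from PIL import Image

img1 = Image.open("cat1.jpg").convert("RGB")
img2 = Image.open("cat2.jpg").convert("RGB")

# lr_check=True keeps only mutual nearest-neighbour matches;
# display_matches_threshold=5 draws lines for roughly 5% of the matches
gallery1, gallery2, merged = compute_matches(img1, img2, lr_check=True,
                                             filter_background=True,
                                             display_matches_threshold=5)
merged.save("matches.jpg")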
pre-requirements.txt (ADDED)
@@ -0,0 +1,2 @@
pip>=23.2
gradio_client==0.2.7
standard.npy (ADDED)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8670bf5a566f92828abc119766e8257c21a1e8f4f1c3e477be14ab6a4bb9afa2
size 6272
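The pointer file stores only the hash and size of the LFS object, but the size is already informative: 6272 bytes is a 128-byte .npy header plus 768 float64 values, and 768 is the token dimension of dinov2_vitb14, consistent with standard_array being a single projection direction as used in make_foreground_mask. A small sanity check, assuming the actual array has been fetched with git lfs pull (the expected shape is an inference, not documented in the commit):

# Inspect the LFS-tracked array (sketch)
import numpy as np

standard_array = np.load("standard.npy")
print(standard_array.shape, standard_array.dtype)  # expected: (768,) float64
# data bytes should equal the file size minus the 128-byte .npy header
assert standard_array.nbytes == 6272 - 128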