ysalaun committed on
Commit
6cf6f4e
·
1 Parent(s): a05f9a0

initial commit

Files changed (3)
  1. app.py +177 -0
  2. pre-requirements.txt +2 -0
  3. standard.npy +3 -0
app.py ADDED
@@ -0,0 +1,177 @@
+ import os
+
+ # gradio for visual demo
+ import gradio as gr
+
+ # heavy dependencies (scipy, torch, scikit-learn, torchvision) are assumed to be pre-installed
+ #os.system("pip install scipy")
+ #os.system("pip install torch")
+ #os.system("pip install scikit-learn")
+ #os.system("pip install torchvision")
+
+ import numpy as np
+ import torch
+ import torchvision.transforms as transforms
+ from PIL import ImageDraw, Image
+ from typing import Tuple
+ from scipy.ndimage import binary_closing, binary_opening
+ from sklearn.decomposition import PCA
+ from sklearn.neighbors import NearestNeighbors
+ from random import randint
+
+ ### Models
+ # reference direction used to separate foreground from background patch tokens
+ standard_array = np.load('standard.npy')
+ pca_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
+ pca_model.eval()
+
+ ### Parameters
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+ smaller_edge_size = 448
+ interpolation_mode = transforms.InterpolationMode.BICUBIC
+ patch_size = pca_model.patch_size
+ background_threshold = 0.05
+
+ apply_opening = True
+ apply_closing = True
+ device = 'cpu'
+
+ def make_transform() -> transforms.Compose:
+     return transforms.Compose([
+         transforms.Resize(size=smaller_edge_size, interpolation=interpolation_mode, antialias=True),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+     ])
+
+ def prepare_image(image: Image.Image) -> Tuple[torch.Tensor, Tuple[int, int], float]:
+     transform = make_transform()
+     image_tensor = transform(image)
+     resize_scale = image.width / image_tensor.shape[2]
+
+     # Crop image to dimensions that are a multiple of the patch size
+     height, width = image_tensor.shape[1:]  # C x H x W
+     cropped_width, cropped_height = width - width % patch_size, height - height % patch_size
+     image_tensor = image_tensor[:, :cropped_height, :cropped_width]
+
+     grid_size = (cropped_height // patch_size, cropped_width // patch_size)  # number of patches along (height, width)
+     return image_tensor, grid_size, resize_scale
+
+ def make_foreground_mask(tokens,
+                          grid_size: Tuple[int, int]):
+     # project each patch token onto the stored reference direction;
+     # patches above the threshold are treated as foreground
+     projection = tokens @ standard_array
+     mask = projection >= background_threshold
+     mask = mask.reshape(*grid_size)
+     if apply_opening:
+         mask = binary_opening(mask)
+     if apply_closing:
+         mask = binary_closing(mask)
+     return mask.flatten()
+
+ def render_patch_pca(image: Image.Image, mask, filter_background, tokens, grid_size) -> Image.Image:
+     pca = PCA(n_components=3)
+     if filter_background:
+         pca.fit(tokens[mask])
+     else:
+         pca.fit(tokens)
+     projected_tokens = pca.transform(tokens)
+
+     t = torch.tensor(projected_tokens)
+     t_min = t.min(dim=0, keepdim=True).values
+     t_max = t.max(dim=0, keepdim=True).values
+     normalized_t = (t - t_min) / (t_max - t_min)
+
+     array = (normalized_t * 255).byte().numpy()
+     if filter_background:
+         array[~mask] = 0
+     array = array.reshape(*grid_size, 3)
+
+     return Image.fromarray(array).resize((image.width, image.height), Image.NEAREST)
+
+ def extract_features(img, image_tensor, filter_background, grid_size):
+     with torch.inference_mode():
+         image_batch = image_tensor.unsqueeze(0).to(device)
+         # patch tokens from the last intermediate layer, converted to numpy for scipy/sklearn
+         tokens = pca_model.get_intermediate_layers(image_batch)[0].squeeze().cpu().numpy()
+     mask = make_foreground_mask(tokens, grid_size)
+     img_pca = render_patch_pca(img, mask, filter_background, tokens, grid_size)
+     return tokens, mask, img_pca
+
+ def compute_features(img, filter_background):
+     image_tensor, grid_size, resize_scale = prepare_image(img)
+     features, mask, img_pca = extract_features(img, image_tensor, filter_background, grid_size)
+     return features, mask, grid_size, resize_scale, img_pca
+
+ def idx_to_source_position(idx, grid_size, resize_scale):
+     # map a flat patch index back to (row, col) pixel coordinates
+     row = (idx // grid_size[1]) * pca_model.patch_size * resize_scale + pca_model.patch_size / 2
+     col = (idx % grid_size[1]) * pca_model.patch_size * resize_scale + pca_model.patch_size / 2
+     return row, col
+
+ def compute_nn(features1, features2):
+     knn = NearestNeighbors(n_neighbors=1)
+     knn.fit(features1)
+     distances, match2to1 = knn.kneighbors(features2)
+     match2to1 = match2to1.flatten()  # one index into features1 per feature of features2
+     return distances, match2to1
+
+ def compute_matches(img1, img2, lr_check, filter_background, display_matches_threshold):
+     # compute features
+     features1, mask1, grid_size1, resize_scale1, img_pca1 = compute_features(img1, filter_background)
+     features2, mask2, grid_size2, resize_scale2, img_pca2 = compute_features(img2, filter_background)
+
+     # match features
+     distances2to1, match2to1 = compute_nn(features1, features2)
+     distances1to2, match1to2 = compute_nn(features2, features1)
+
+     # display matches: bring both images to the same height before drawing
+     if img1.size[1] > img2.size[1]:
+         img1 = img1.resize(img2.size)
+         resize_scale1 = resize_scale2
+     else:
+         img2 = img2.resize(img1.size)
+         resize_scale2 = resize_scale1
+
+     # draw on the (possibly resized) images that are actually returned
+     draw1 = ImageDraw.Draw(img1)
+     draw2 = ImageDraw.Draw(img2)
+
+     offset = img1.size[0]
+     merged_image = Image.new('RGB', (offset + img2.size[0], max(img1.size[1], img2.size[1])), (250, 250, 250))
+     merged_image.paste(img1, (0, 0))
+     merged_image.paste(img2, (offset, 0))
+     draw = ImageDraw.Draw(merged_image)
+
+     for idx2, idx1 in enumerate(match2to1):
+         if lr_check and match1to2[idx1] != idx2: continue
+         row1, col1 = idx_to_source_position(idx1, grid_size1, resize_scale1)
+         row2, col2 = idx_to_source_position(idx2, grid_size2, resize_scale2)
+
+         if filter_background and not mask1[idx1]: continue
+         if filter_background and not mask2[idx2]: continue
+
+         r = randint(0, 255)
+         g = randint(0, 255)
+         color = (r, g, 255 - r)
+
+         draw1.point((col1, row1), color)
+         draw2.point((col2, row2), color)
+
+         # only draw a line for a random subset of matches to keep the figure readable
+         if 100 * np.random.rand() > display_matches_threshold: continue
+         draw.line((col1, row1, col2 + offset, row2), fill=color)
+
+     return [[img1, img_pca1], [img2, img_pca2], merged_image]
+
+ iface = gr.Interface(fn=compute_matches,
+                      inputs=[
+                          gr.Image(type="pil"),
+                          gr.Image(type="pil"),
+                          gr.Checkbox(label="Keep only symmetric matches"),
+                          gr.Checkbox(label="Mask background"),
+                          gr.Slider(0, 100, step=5, value=5, label="Display matches ratio", info="Choose between 0 and 100%"),
+                      ],
+                      outputs=[
+                          gr.Gallery(
+                              label="Image 1", show_label=False, elem_id="gallery1",
+                              columns=[2], rows=[1], object_fit="contain", height="auto"),
+                          gr.Gallery(
+                              label="Image 2", show_label=False, elem_id="gallery2",
+                              columns=[2], rows=[1], object_fit="contain", height="auto"),
+                          gr.Image(type="pil")
+                      ])
+ iface.launch(debug=True)
pre-requirements.txt ADDED
@@ -0,0 +1,2 @@
+ pip>=23.2
+ gradio_client==0.2.7
standard.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8670bf5a566f92828abc119766e8257c21a1e8f4f1c3e477be14ab6a4bb9afa2
+ size 6272