IbrahimHasani committed on
Commit 472aaf0
1 Parent(s): de8e664

Create app.py

Files changed (1)
  1. app.py +609 -0
app.py ADDED
@@ -0,0 +1,609 @@
import cv2
import requests

from PIL import Image
import PIL
from PIL import ImageDraw

from matplotlib import pyplot as plt
import matplotlib
from matplotlib import rcParams

import os
import tempfile
from io import BytesIO
from pathlib import Path
import argparse
import random
import numpy as np
import torch
import matplotlib.cm as cm
import pandas as pd


from transformers import OwlViTProcessor, OwlViTForObjectDetection
from transformers.image_utils import ImageFeatureExtractionMixin


from SuperGluePretrainedNetwork.models.matching import Matching
from SuperGluePretrainedNetwork.models.utils import (compute_pose_error, compute_epipolar_error,
                                                     estimate_pose,
                                                     error_colormap, AverageTimer, pose_auc, read_image,
                                                     rotate_intrinsics, rotate_pose_inplane,
                                                     scale_intrinsics)

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)


# OWL-ViT model and processor for image-guided object detection.
mixin = ImageFeatureExtractionMixin()
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")


# Use GPU if available, and keep the model on the same device as its inputs.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)

def detect_and_crop2(target_image_path,
                     query_image_path,
                     model,
                     processor,
                     mixin,
                     device,
                     threshold=0.5,
                     nms_threshold=0.3,
                     visualize=True):

    # Open target image
    image = Image.open(target_image_path).convert('RGB')
    image_size = model.config.vision_config.image_size + 5
    image = mixin.resize(image, image_size)
    target_sizes = torch.Tensor([image.size[::-1]])

    # Open query image
    query_image = Image.open(query_image_path).convert('RGB')
    image_size = model.config.vision_config.image_size + 5
    query_image = mixin.resize(query_image, image_size)

    # Process input and query image
    inputs = processor(images=image, query_images=query_image, return_tensors="pt").to(device)

    # Get predictions
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)

    # Convert predictions to CPU
    img = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()

    # Post-process the predictions
    results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes)
    boxes, scores = results[0]["boxes"], results[0]["scores"]

    # If no boxes were detected, return no crops along with the original image
    if len(boxes) == 0 and visualize:
        print(f"No boxes detected for image: {target_image_path}")
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_title("Original Image")
        ax.axis("off")
        plt.show()
        return [], image

    # Filter boxes: drop boxes with negative coordinates and boxes that cover almost the whole image
    img_with_all_boxes = img.copy()
    filtered_boxes = []
    filtered_scores = []
    img_width, img_height = img.shape[1], img.shape[0]
    for box, score in zip(boxes, scores):
        x1, y1, x2, y2 = [int(i) for i in box.tolist()]
        if x1 < 0 or y1 < 0 or x2 < 0 or y2 < 0:
            continue
        if (x2 - x1) / img_width >= 0.94 and (y2 - y1) / img_height >= 0.94:
            continue
        filtered_boxes.append([x1, y1, x2, y2])
        filtered_scores.append(score)

    # Draw boxes on the original image
    draw = ImageDraw.Draw(image)
    for box in filtered_boxes:
        draw.rectangle(box, outline="red", width=3)

    cropped_images = []
    for box in filtered_boxes:
        x1, y1, x2, y2 = box
        cropped_img = img[y1:y2, x1:x2]
        if cropped_img.size != 0:
            cropped_images.append(cropped_img)

    if visualize:
        # Visualization
        if not filtered_boxes:
            fig, ax = plt.subplots(figsize=(6, 6))
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            ax.set_title("Original Image")
            ax.axis("off")
            plt.show()
        else:
            fig, axs = plt.subplots(1, len(cropped_images) + 2, figsize=(15, 5))
            axs[0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            axs[0].set_title("Original Image")
            axs[0].axis("off")

            for i, (box, score) in enumerate(zip(filtered_boxes, filtered_scores)):
                x1, y1, x2, y2 = box
                cropped_img = img[y1:y2, x1:x2]
                font = cv2.FONT_HERSHEY_SIMPLEX
                text = f"{score:.2f}"
                cv2.putText(cropped_img, text, (5, cropped_img.shape[0] - 10), font, 0.5, (255, 0, 0), 1, cv2.LINE_AA)
                axs[i + 2].imshow(cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB))
                axs[i + 2].set_title("Score: " + text)
                axs[i + 2].axis("off")
            plt.tight_layout()
            plt.show()

    return cropped_images, image  # return crops and the original image with boxes drawn

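# Illustrative usage sketch (not part of the original commit; the paths below are placeholders):
# detect_and_crop2 returns the filtered crops (NumPy arrays) plus the resized target image with
# the surviving boxes drawn on it.
# crops, annotated = detect_and_crop2(target_image_path="samples/target.jpg",
#                                     query_image_path="samples/query.jpg",
#                                     model=model, processor=processor,
#                                     mixin=mixin, device=device, visualize=False)
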
def save_array_to_temp_image(arr):
    # Convert the array to an image
    img = Image.fromarray(arr)

    # Create a temporary file for the image
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir=tempfile.gettempdir())
    temp_file_name = temp_file.name
    temp_file.close()  # We close it because we're not writing to it directly; PIL handles the writing

    # Save the image to the temp file
    img.save(temp_file_name)

    return temp_file_name

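# Illustrative sketch (not in the original commit): SuperGlue's read_image() expects a file path,
# so crops produced by detect_and_crop2 (NumPy arrays) are written to temp files before matching.
# crop_path = save_array_to_temp_image(crops[0])  # "crops" is the hypothetical list from above
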
'''
def process_resize(w: int, h: int, resize_dims: list) -> tuple:
    if len(resize_dims) == 1 and resize_dims[0] > -1:
        scale = resize_dims[0] / max(h, w)
        w_new, h_new = int(round(w * scale)), int(round(h * scale))
        return w_new, h_new
    return w, h
'''

def plot_image_pair(imgs, dpi=100, size=6, pad=.5):
    n = len(imgs)
    assert n == 2, 'number of images must be two'
    figsize = (size*n, size*3/4) if size is not None else None
    _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255)
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        for spine in ax[i].spines.values():  # remove frame
            spine.set_visible(False)
    plt.tight_layout(pad=pad)

def plot_keypoints(kpts0, kpts1, color='w', ps=2):
    ax = plt.gcf().axes
    ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
    ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)

def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4):
    fig = plt.gcf()
    ax = fig.axes
    fig.canvas.draw()

    transFigure = fig.transFigure.inverted()
    fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0))
    fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1))

    fig.lines = [matplotlib.lines.Line2D(
        (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1,
        transform=fig.transFigure, c=color[i], linewidth=lw)
        for i in range(len(kpts0))]
    ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
    ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)

def unified_matching_plot2(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
                           color, text, path=None, show_keypoints=False,
                           fast_viz=False, opencv_display=False,
                           opencv_title='matches', small_text=[]):

    # Set the background color for the plot
    plt.figure(facecolor='#eeeeee')
    plot_image_pair([image0, image1])

    # Elegant points and lines for matches
    if show_keypoints:
        plot_keypoints(kpts0, kpts1, color='k', ps=4)
        plot_keypoints(kpts0, kpts1, color='w', ps=2)
    plot_matches(mkpts0, mkpts1, color, lw=1)

    fig = plt.gcf()

    # Add text
    fig.text(
        0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes,
        fontsize=10, va='bottom', ha='left', color='#333333', fontweight='bold', fontname='Helvetica',
        bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle="round,pad=0.3"))

    fig.text(
        0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
        fontsize=15, va='top', ha='left', color='#333333', fontweight='bold', fontname='Helvetica',
        bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle="round,pad=0.3"))

    # Optional: remove axis for a cleaner look
    plt.axis('off')

    # Convert the figure to an OpenCV image
    buf = BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Close the figure to free memory
    plt.close(fig)

    return img

def create_image_pyramid2(image_path, longest_side, scales=[0.25, 0.5, 1.0]):
    original_image = cv2.imread(image_path)
    oh, ow, _ = original_image.shape

    # Determine the output size based on the longest side
    if oh > ow:
        output_height = longest_side
        output_width = int((ow / oh) * longest_side)
    else:
        output_width = longest_side
        output_height = int((oh / ow) * longest_side)
    output_size = (output_width, output_height)

    pyramid = []

    for scale in scales:
        # Resize based on the scale factor
        resized = cv2.resize(original_image, None, fx=scale, fy=scale)
        rh, rw, _ = resized.shape

        if scale < 1.0:  # downsampling
            # Calculate the amount of padding required
            dy_top = max((output_size[1] - rh) // 2, 0)
            dy_bottom = output_size[1] - rh - dy_top
            dx_left = max((output_size[0] - rw) // 2, 0)
            dx_right = output_size[0] - rw - dx_left

            # Create padded image
            padded = cv2.copyMakeBorder(resized, dy_top, dy_bottom, dx_left, dx_right, cv2.BORDER_CONSTANT, value=[255, 255, 255])
            pyramid.append(padded)
        elif scale > 1.0:  # upsampling
            # Crop the image to fit the desired output size
            dy = (rh - output_size[1]) // 2
            dx = (rw - output_size[0]) // 2
            cropped = resized[dy:dy+output_size[1], dx:dx+output_size[0]]
            pyramid.append(cropped)
        else:  # scale == 1.0
            pyramid.append(resized)

    return pyramid

# Example usage
# pyramid = create_image_pyramid2('path_to_image.jpg', 800)

def image_matching(query_img, target_img, image_dims=[640*2], scale_factors=[0.33, 0.66, 1], visualize=True, k_thresh=None, m_thresh=None, write=False):

    image1, inp1, scales1 = read_image(target_img, device, [640*2], 0, True)
    query_pyramid = create_image_pyramid2(query_img, image_dims[0], scale_factors)

    all_valid = []
    all_inliers = []
    all_return_imgs = []
    max_matches_img = None
    max_matches = -1

    for idx, query_level in enumerate(query_pyramid):
        temp_file_path = "temp_level_{}.png".format(idx)
        cv2.imwrite(temp_file_path, query_level)

        image0, inp0, scales0 = read_image(temp_file_path, device, [640*2], 0, True)

        if image0 is None or image1 is None:
            print('Problem reading image pair: {} {}'.format(query_img, target_img))
        else:
            # Matching
            pred = matching({'image0': inp0, 'image1': inp1})
            pred = {k: v[0] for k, v in pred.items()}
            kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
            matches, conf = pred['matches0'], pred['matching_scores0']

            valid = matches > -1
            mkpts0 = kpts0[valid]
            mkpts1 = kpts1[matches[valid]]
            mconf = conf[valid]
            # color = cm.jet(mconf)[:len(mkpts0)]  # Ensure consistent size
            color = cm.jet(mconf.detach().cpu().numpy())[:len(mkpts0)]

            all_valid.append(np.sum(valid.tolist()))

            # Convert torch tensors to numpy arrays.
            mkpts0_np = mkpts0.cpu().numpy()
            mkpts1_np = mkpts1.cpu().numpy()

            try:
                # Use RANSAC to find the homography matrix.
                H, inliers = cv2.findHomography(mkpts0_np, mkpts1_np, cv2.RANSAC, 5.0)
            except cv2.error:
                H = 0
                inliers = 0
                print("Not enough points for homography")
            # Convert inliers from shape (N, 1) to shape (N,) and count them.
            num_inliers = np.sum(inliers) if inliers is not None else 0

            all_inliers.append(num_inliers)

            # Visualization
            text = [
                'Engagify Image Matching',
                'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
                'Scaling Factor: {}'.format(scale_factors[idx]),
                'Matches: {}'.format(len(mkpts0)),
                'Inliers: {}'.format(num_inliers),
            ]

            k_thresh = matching.superpoint.config['keypoint_threshold']
            m_thresh = matching.superglue.config['match_threshold']

            small_text = [
                'Keypoint Threshold: {:.4f}'.format(k_thresh),
                'Match Threshold: {:.2f}'.format(m_thresh),
            ]

            visualized_img = None  # To store the visualized image

            if visualize:
                ret_img = unified_matching_plot2(
                    image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, text, 'Test_Level_{}'.format(idx), True, False, True, 'Matches_Level_{}'.format(idx), small_text)
                all_return_imgs.append(ret_img)
                # Storing image with most matches
                # if len(mkpts0) > max_matches:
                #     max_matches = len(mkpts0)
                #     max_matches_img = 'Matches_Level_{}'.format(idx)

    avg_valid = np.sum(all_valid) / len(scale_factors)
    avg_inliers = np.sum(all_inliers) / len(scale_factors)

    # Convert the image with the most matches to base64 encoded format
    # with open(max_matches_img, "rb") as image_file:
    #     encoded_string = base64.b64encode(image_file.read()).decode()

    return {'valid': all_valid, 'inliers': all_inliers, 'visualized_image': all_return_imgs}  # , encoded_string

# Usage:
# results = image_matching('Samples/Poster/poster_event_small_22.jpg', 'Samples/Images/16.jpeg', visualize=True)
# print(results)

def image_matching_no_pyramid(query_img, target_img, visualize=True, write=False):

    image1, inp1, scales1 = read_image(target_img, device, [640*2], 0, True)
    image0, inp0, scales0 = read_image(query_img, device, [640*2], 0, True)

    if image0 is None or image1 is None:
        print('Problem reading image pair: {} {}'.format(query_img, target_img))
        return None

    # Matching
    pred = matching({'image0': inp0, 'image1': inp1})
    pred = {k: v[0] for k, v in pred.items()}
    kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
    matches, conf = pred['matches0'], pred['matching_scores0']

    valid = matches > -1
    mkpts0 = kpts0[valid]
    mkpts1 = kpts1[matches[valid]]
    mconf = conf[valid]
    # color = cm.jet(mconf)[:len(mkpts0)]  # Ensure consistent size
    color = cm.jet(mconf.detach().cpu().numpy())[:len(mkpts0)]

    valid_count = np.sum(valid.tolist())

    # Convert torch tensors to numpy arrays.
    mkpts0_np = mkpts0.cpu().numpy()
    mkpts1_np = mkpts1.cpu().numpy()

    try:
        # Use RANSAC to find the homography matrix.
        H, inliers = cv2.findHomography(mkpts0_np, mkpts1_np, cv2.RANSAC, 5.0)
    except cv2.error:
        H = 0
        inliers = 0
        print("Not enough points for homography")

    # Convert inliers from shape (N, 1) to shape (N,) and count them.
    num_inliers = np.sum(inliers) if inliers is not None else 0

    # Visualization
    text = [
        'Engagify Image Matching',
        'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
        'Matches: {}'.format(len(mkpts0)),
        'Inliers: {}'.format(num_inliers),
    ]

    k_thresh = matching.superpoint.config['keypoint_threshold']
    m_thresh = matching.superglue.config['match_threshold']

    small_text = [
        'Keypoint Threshold: {:.4f}'.format(k_thresh),
        'Match Threshold: {:.2f}'.format(m_thresh),
    ]

    visualized_img = None  # To store the visualized image

    if visualize:
        visualized_img = unified_matching_plot2(
            image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, text, 'Test_Match', True, False, True, 'Matches', small_text)

    return {
        'valid': [valid_count],
        'inliers': [num_inliers],
        'visualized_image': [visualized_img]
    }

# Usage:
# results = image_matching_no_pyramid('Samples/Poster/poster_event_small_22.jpg', 'Samples/Images/16.jpeg', visualize=True)

# Load the SuperPoint and SuperGlue models.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Running inference on device "{}"'.format(device))
config = {
    'superpoint': {
        'nms_radius': 4,
        'keypoint_threshold': 0.005,
        'max_keypoints': 1024
    },
    'superglue': {
        'weights': 'outdoor',
        'sinkhorn_iterations': 20,
        'match_threshold': 0.2,
    }
}
matching = Matching(config).eval().to(device)

def stitch_images(images):
    """Stitches a list of images vertically."""
    if not images:
        # Return a placeholder image if the images list is empty
        return Image.new('RGB', (100, 100), color='gray')

    max_width = max([img.width for img in images])
    total_height = sum(img.height for img in images)

    composite = Image.new('RGB', (max_width, total_height))

    y_offset = 0
    for img in images:
        composite.paste(img, (0, y_offset))
        y_offset += img.height

    return composite

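# Illustrative sketch (not in the original commit): stitch_images is used below to stack the
# per-crop and per-scale match visualizations into single PIL images for the Gradio outputs.
# combined = stitch_images([Image.fromarray(arr) for arr in visual_arrays])  # "visual_arrays" is hypothetical
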
def check_object_in_image3(query_image, target_image, threshold=50, scale_factor=[0.33, 0.66, 1]):
    decision_on = []
    # Convert cv2 images to PIL images and add them to a list
    images_to_return = []

    cropped_images, bbox_image = detect_and_crop2(target_image_path=target_image,
                                                  query_image_path=query_image,
                                                  model=model,
                                                  processor=processor,
                                                  mixin=mixin,
                                                  device=device,
                                                  visualize=False)

    temp_files = [save_array_to_temp_image(i) for i in cropped_images]
    crop_results = [image_matching_no_pyramid(query_image, i, visualize=True) for i in temp_files]

    cropped_visuals = []
    cropped_inliers = []
    for result in crop_results:
        # Add visualized images to the temporary list
        for img in result['visualized_image']:
            cropped_visuals.append(Image.fromarray(img))
        for inliers_ in result['inliers']:
            cropped_inliers.append(inliers_)
    # Stitch the cropped visuals into one image
    images_to_return.append(stitch_images(cropped_visuals))

    pyramid_results = image_matching(query_image, target_image, visualize=True, scale_factors=scale_factor)

    pyramid_visuals = [Image.fromarray(img) for img in pyramid_results['visualized_image']]
    # Stitch the pyramid visuals into one image
    images_to_return.append(stitch_images(pyramid_visuals))

    # Check inliers and decide whether the object is present; record which test(s) passed
    print(cropped_inliers)
    found_by_detection = any(value > threshold for value in cropped_inliers)
    if found_by_detection:
        decision_on.append('Object Detection')
    found_by_pyramid = any(value > threshold for value in pyramid_results["inliers"])
    if found_by_pyramid:
        decision_on.append('Pyramid Max Point')
    is_present = found_by_detection or found_by_pyramid
    if not is_present:
        decision_on.append("Neither, It Failed All Tests")

    # Return results as a dictionary
    return {
        'is_present': is_present,
        'images': images_to_return,
        'scale factors': scale_factor,
        'object detection inliers': cropped_inliers,
        'pyramid_inliers': pyramid_results["inliers"],
        'bbox_image': bbox_image,
        'decision_on': decision_on,
    }

# Example call:
# result = check_object_in_image3('Samples/Poster/poster_event_small.jpg', 'Samples/Images/True_Image_3423234.jpeg', 50)
# Accessing the results:
# print(result['is_present'])  # prints True/False
# print(result['images'])      # a list of 2 stitched images


import gradio as gr

def gradio_interface(query_image_path, target_image_path, threshold):
    result = check_object_in_image3(query_image_path, target_image_path, threshold)
    # Unpack the result dictionary into the individual Gradio outputs
    return result['bbox_image'], result['images'][0], result['object detection inliers'], result['scale factors'], result['pyramid_inliers'], result['images'][1], str(result['is_present']), result['decision_on']


# Define the Gradio interface
interface = gr.Interface(
    fn=gradio_interface,  # function to be called on button press
    inputs=[
        gr.components.Image(label="Query Image (Drop the Image you want to detect here)", type="filepath"),
        gr.components.Image(label="Target Image (Drop the Image you'd like to search here)", type="filepath"),
        gr.components.Slider(minimum=0, maximum=200, value=50, step=5, label="Enter the Inlier Threshold"),
    ],
    outputs=[
        gr.components.Image(label="Filtered Regions of Interest (Candidates)"),
        gr.components.Image(label="Cropped Visuals from Image-Guided Object Detection"),
        gr.components.Text(label="Inliers detected for Image-Guided Object Detection"),
        gr.components.Text(label="Scale Factors Used for Pyramid (Results below, In Order)"),
        gr.components.Text(label="Inliers detected for Pyramid Search (In Order)"),
        gr.components.Image(label="Pyramid Visuals"),
        gr.components.Textbox(label="Object Present?"),
        gr.components.Textbox(label="Decision Taken Based On?"),
    ],
    theme=gr.themes.Monochrome(),
    title="Engajify's Image-Specific Image Recognition + Matching Tool",
    description="[Author: Ibrahim Hasani]\n"
                "This tool leverages Transformer, deep learning, and traditional computer vision techniques to determine whether a specified object "
                "(given by the query image) is present within a target image.\n"
                "1. Image-Guided Object Detection to find potential regions of interest (OWL-ViT, Google).\n"
                "2. Pyramid Search that looks at the target image across several scales; results provide "
                "visual representations of the matching process and a final verdict on the object's presence.\n"
                "3. SuperPoint (MagicLeap) + SuperGlue + homography to extract inliers, which are thresholded for decision making."
)

interface.launch()