taesiri commited on
Commit
3717c61
1 Parent(s): b40ab37

update demo

Browse files
Files changed (3) hide show
  1. CHMCorr.py +1 -0
  2. app.py +41 -18
  3. visualization.py +117 -10
CHMCorr.py CHANGED
@@ -494,6 +494,7 @@ def export_visualizations_results(
494
  "chm-prediction": pfn,
495
  "chm-prediction-confidence": pr,
496
  "chm-nearest-neighbors": rfiles,
 
497
  "correspondance_map": cmaps,
498
  "masked_cos_values": MASKED_COSINE_VALUES,
499
  "src-keypoints": list_of_source_points,
 
494
  "chm-prediction": pfn,
495
  "chm-prediction-confidence": pr,
496
  "chm-nearest-neighbors": rfiles,
497
+ "chm-nearest-neighbors-all": reranked_nns,
498
  "correspondance_map": cmaps,
499
  "masked_cos_values": MASKED_COSINE_VALUES,
500
  "src-keypoints": list_of_source_points,
app.py CHANGED
@@ -13,7 +13,7 @@ from PIL import Image
13
  from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
14
  from ExtractEmbedding import QueryToEmbedding
15
  from CHMCorr import chm_classify_and_visualize
16
- from visualization import plot_from_reranker_output
17
 
18
  csv.field_size_limit(sys.maxsize)
19
 
@@ -74,7 +74,7 @@ id_to_bird_name = {
74
  }
75
 
76
 
77
- def search(query_image, draw_arcs, searcher=searcher):
78
  query_embedding = QueryToEmbedding(query_image)
79
  scores, indices, labels = searcher.search(query_embedding, k=50)
80
 
@@ -101,7 +101,7 @@ def search(query_image, draw_arcs, searcher=searcher):
101
  query_image, kNN_results, support, training_folder
102
  )
103
 
104
- fig = plot_from_reranker_output(chm_output, draw_arcs=draw_arcs)
105
 
106
  # Resize the output
107
 
@@ -117,35 +117,58 @@ def search(query_image, draw_arcs, searcher=searcher):
117
  right = (width + new_width) / 2
118
  bottom = (height + new_height) / 2
119
 
120
- viz_image = image.crop((left + 540, top + 40, right - 492, bottom - 100))
121
 
122
- return viz_image, predicted_labels
 
 
 
 
 
 
 
123
 
124
 
125
  blocks = gr.Blocks()
126
 
127
  with blocks:
128
  gr.Markdown(""" # CHM-Corr DEMO""")
129
- gr.Markdown(""" ### Parameters: N=50, k=20 - Using ResNet50 features""")
 
 
130
 
131
- # with gr.Row():
132
  input_image = gr.Image(type="filepath")
133
- with gr.Column():
134
- arcs_checkbox = gr.Checkbox(label="Draw Arcs")
135
  run_btn = gr.Button("Classify")
136
-
137
- # with gr.Column():
138
- gr.Markdown(""" ### CHM-Corr Output """)
139
- viz_plot = gr.Image(type="pil")
140
- gr.Markdown(""" ### kNN Predicted Labels """)
141
- predicted_labels = gr.Label(label="kNN Prediction")
142
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  run_btn.click(
144
  search,
145
- inputs=[input_image, arcs_checkbox],
146
- outputs=[viz_plot, predicted_labels],
147
  )
148
 
 
149
  if __name__ == "__main__":
150
  blocks.launch(
151
  debug=True,
 
13
  from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
14
  from ExtractEmbedding import QueryToEmbedding
15
  from CHMCorr import chm_classify_and_visualize
16
+ from visualization import plot_from_reranker_corrmap
17
 
18
  csv.field_size_limit(sys.maxsize)
19
 
 
74
  }
75
 
76
 
77
+ def search(query_image, searcher=searcher):
78
  query_embedding = QueryToEmbedding(query_image)
79
  scores, indices, labels = searcher.search(query_embedding, k=50)
80
 
 
101
  query_image, kNN_results, support, training_folder
102
  )
103
 
104
+ fig, chm_output_label = plot_from_reranker_corrmap(chm_output)
105
 
106
  # Resize the output
107
 
 
117
  right = (width + new_width) / 2
118
  bottom = (height + new_height) / 2
119
 
120
+ viz_image = image.crop((left + 310, top + 60, right - 248, bottom - 80))
121
 
122
+ chm_output_labels = Counter(
123
+ [
124
+ x.split("/")[-2].replace(".", " ").replace("_", " ")
125
+ for x in chm_output["chm-nearest-neighbors-all"][:20]
126
+ ]
127
+ )
128
+
129
+ return viz_image, {l: s / 20.0 for l, s in chm_output_labels.items()}
130
 
131
 
132
  blocks = gr.Blocks()
133
 
134
  with blocks:
135
  gr.Markdown(""" # CHM-Corr DEMO""")
136
+ gr.Markdown(
137
+ """ ### Parameters: N=50, k=20 - Using ``ImageNet Pretrained ResNet50`` features"""
138
+ )
139
 
 
140
  input_image = gr.Image(type="filepath")
 
 
141
  run_btn = gr.Button("Classify")
142
+ gr.Markdown(""" ### CHM-Corr Output Visualization """)
143
+ viz_plot = gr.Image(type="pil", label="Visualization")
144
+ with gr.Row():
145
+ with gr.Column():
146
+ gr.Markdown(""" ### CHM-Corr Prediction """)
147
+ labels = gr.Label(label="Prediction")
148
+ with gr.Column():
149
+ gr.Markdown(""" ### Examples """)
150
+ examples = gr.Examples(
151
+ examples=[
152
+ ["./examples/bird.jpg"],
153
+ ["./examples/Red_Winged_Blackbird_0012_6015.jpg"],
154
+ ["./examples/Red_Winged_Blackbird_0025_5342.jpg"],
155
+ ["./examples/sample1.jpeg"],
156
+ ["./examples/sample2.jpeg"],
157
+ ["./examples/Yellow_Headed_Blackbird_0020_8549.jpg"],
158
+ ["./examples/Yellow_Headed_Blackbird_0026_8545.jpg"],
159
+ ],
160
+ inputs=[input_image],
161
+ outputs=[viz_plot, labels],
162
+ fn=search,
163
+ cache_examples=False,
164
+ )
165
  run_btn.click(
166
  search,
167
+ inputs=[input_image],
168
+ outputs=[viz_plot, labels],
169
  )
170
 
171
+
172
  if __name__ == "__main__":
173
  blocks.launch(
174
  debug=True,
visualization.py CHANGED
@@ -38,7 +38,6 @@ def arg_topK(inputarray, topK=5):
38
  return np.argsort(inputarray.T.reshape(-1))[::-1][:topK]
39
 
40
 
41
- # FOR MULTI
42
  def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
43
  """
44
  visualize chm results from a reranker output dict
@@ -261,14 +260,122 @@ def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
261
  color="black",
262
  fontsize=22,
263
  )
264
- # fig.text(
265
- # 0.8,
266
- # 0.95,
267
- # f"KNN: {reranker_output['knn-prediction']}",
268
- # ha="right",
269
- # va="bottom",
270
- # color="black",
271
- # fontsize=22,
272
- # )
273
 
274
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  return np.argsort(inputarray.T.reshape(-1))[::-1][:topK]
39
 
40
 
 
41
  def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
42
  """
43
  visualize chm results from a reranker output dict
 
260
  color="black",
261
  fontsize=22,
262
  )
 
 
 
 
 
 
 
 
 
263
 
264
  return fig
265
+
266
+
267
+ def plot_from_reranker_corrmap(reranker_output, draw_box=True):
268
+ """
269
+ visualize chm results from a reranker output dict
270
+ """
271
+
272
+ ### SET COLORS
273
+ cmap = matplotlib.cm.get_cmap("gist_rainbow")
274
+ rgba = cmap(0.5)
275
+ colors = []
276
+ for k in range(5):
277
+ colors.append(cmap(k / 5.0))
278
+
279
+ ### SET POINTS
280
+ A = np.linspace(1 + 17, 240 - 17 - 1, 7)
281
+ point_list = list(product(A, A))
282
+
283
+ fig, axes = plt.subplots(
284
+ 2,
285
+ 7,
286
+ figsize=(25, 8),
287
+ gridspec_kw={
288
+ "wspace": 0,
289
+ "hspace": 0,
290
+ "width_ratios": [1, 0.28, 1, 1, 1, 1, 1],
291
+ },
292
+ facecolor=(1, 1, 1),
293
+ )
294
+
295
+ for i in range(2):
296
+ for j in range(7):
297
+ axes[i][j].axis("off")
298
+
299
+ axes[0][0].imshow(
300
+ display_transform(Image.open(reranker_output["q"]).convert("RGB"))
301
+ )
302
+
303
+ for i in range(min(5, reranker_output["chm-prediction-confidence"])):
304
+ axes[0][2 + i].imshow(
305
+ display_transform(Image.open(reranker_output["q"]).convert("RGB"))
306
+ )
307
+
308
+ # Lower ROWs CHM Top5
309
+ for i in range(min(5, reranker_output["chm-prediction-confidence"])):
310
+ axes[1][2 + i].imshow(
311
+ display_transform(
312
+ Image.open(reranker_output["chm-nearest-neighbors"][i]).convert("RGB")
313
+ )
314
+ )
315
+
316
+ if reranker_output["chm-prediction-confidence"] < 5:
317
+ for i in range(reranker_output["chm-prediction-confidence"], 5):
318
+ axes[0][2 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
319
+ axes[1][2 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
320
+
321
+ nzm = reranker_output["non_zero_mask"]
322
+ # Go throught top 5 nearest images
323
+
324
+ # #################################################################################
325
+ if draw_box:
326
+ # SQUARAES
327
+ for NC in range(min(5, reranker_output["chm-prediction-confidence"])):
328
+ # ON SOURCE
329
+ valid_patches_source = arg_topK(
330
+ reranker_output["masked_cos_values"][NC], topK=nzm
331
+ )
332
+
333
+ # ON QUERY
334
+ target_masked_patches = arg_topK(
335
+ reranker_output["masked_cos_values"][NC], topK=nzm
336
+ )
337
+ valid_patches_target = [
338
+ reranker_output["correspondance_map"][NC][x]
339
+ for x in target_masked_patches
340
+ ]
341
+ valid_patches_target = [(x[0] * 7) + x[1] for x in valid_patches_target]
342
+
343
+ patch_colors = [c for c in colors]
344
+ overlaps = [
345
+ item
346
+ for item, count in Counter(valid_patches_target).items()
347
+ if count > 1
348
+ ]
349
+
350
+ for O in overlaps:
351
+ indices = [i for i, val in enumerate(valid_patches_target) if val == O]
352
+ for ii in indices[1:]:
353
+ patch_colors[ii] = patch_colors[indices[0]]
354
+
355
+ for i in valid_patches_source:
356
+ Psource = point_list[i]
357
+ rect = patches.Rectangle(
358
+ (Psource[0] - 16, Psource[1] - 16),
359
+ 32,
360
+ 32,
361
+ linewidth=2,
362
+ edgecolor=patch_colors[valid_patches_source.tolist().index(i)],
363
+ facecolor="none",
364
+ alpha=1,
365
+ )
366
+ axes[0][2 + NC].add_patch(rect)
367
+
368
+ for i in valid_patches_target:
369
+ Psource = point_list[i]
370
+ rect = patches.Rectangle(
371
+ (Psource[0] - 16, Psource[1] - 16),
372
+ 32,
373
+ 32,
374
+ linewidth=2,
375
+ edgecolor=patch_colors[valid_patches_target.index(i)],
376
+ facecolor="none",
377
+ alpha=1,
378
+ )
379
+ axes[1][2 + NC].add_patch(rect)
380
+
381
+ return fig, reranker_output["chm-prediction"]