huzey commited on
Commit
eac9a58
1 Parent(s): 8383165

update click

Browse files
Files changed (1) hide show
  1. app.py +54 -27
app.py CHANGED
@@ -2303,11 +2303,12 @@ with demo:
2303
 
2304
  tsne_2d_points = gr.State(np.array([]))
2305
  edges = gr.State(np.array([]))
 
2306
  fps_eigvecs = gr.State(np.array([]))
2307
  fps_indices = gr.State(np.array([]))
2308
  fps_tsne_rgb = gr.State(np.array([]))
2309
 
2310
- def plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, k, hightlight_idx=None, highlight_connections=False):
2311
  # Plot the t-SNE points
2312
  fig, ax = plt.subplots(1, 1, figsize=(6, 6))
2313
  ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne3d_rgb)
@@ -2316,15 +2317,24 @@ with demo:
2316
  max_length = lengthes[k:].max()
2317
  diag_length = np.linalg.norm(tsne_embed.max(axis=0) - tsne_embed.min(axis=0))
2318
  # draw the edges
 
 
 
 
 
 
 
 
 
 
 
 
 
2319
  for i_edge in range(k, len(edges)):
2320
  edge = edges[i_edge]
2321
- # _do = np.clip(lengthes[i_edge] / (diag_length*0.3), 0, 1)
2322
- if lengthes[i_edge] > diag_length*0.1:
2323
- _do = 1.0
2324
- else:
2325
- _do = 0.0
2326
- alpha = 0.7 * (1 - _do) + 0.0
2327
- ax.plot(tsne_embed[edge, 0], tsne_embed[edge, 1], 'k-', lw=1, alpha=alpha)
2328
  # highlight the selected node
2329
  if hightlight_idx is not None:
2330
  if highlight_connections:
@@ -2477,17 +2487,17 @@ with demo:
2477
  tsne_embed[:, 1] = (tsne_embed[:, 1] - tsne_embed[:, 1].min()) / (tsne_embed[:, 1].max() - tsne_embed[:, 1].min()) * 2 - 1
2478
 
2479
  if tree_method == 'eigvecs':
2480
- edges = build_tree(fps_eigvecs, dist='cosine')
2481
  if tree_method == 'tsne':
2482
- edges = build_tree(tsne_embed, dist='euclidean')
2483
 
2484
  # Plot the t-SNE points
2485
- pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
2486
 
2487
  # Plot the t-SNE points with image heatmaps
2488
  big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)
2489
 
2490
- return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image, gr.update(value={'image': big_pil_image, 'points': []}, interactive=True)
2491
 
2492
  gr.Markdown('---')
2493
  gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
@@ -2498,7 +2508,7 @@ with demo:
2498
  run_hierarchical_button.click(
2499
  run_fps_tsne_hierarchical,
2500
  inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider, tree_method_radio],
2501
- outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, tsne_image_plot],
2502
  )
2503
  with gr.Row():
2504
  with gr.Column(scale=5, min_width=200) as tsne_select:
@@ -2526,6 +2536,15 @@ with demo:
2526
  image_slider.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2527
  output_gallery.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2528
  output_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1, interactive=True), inputs=[output_gallery], outputs=[image_slider])
 
 
 
 
 
 
 
 
 
2529
  with gr.Column(scale=5, min_width=200) as tsne_image_select:
2530
  gr.Markdown('---')
2531
  gr.Markdown('<h3 style="text-align: center;">Please click on the image above ↑</h3>')
@@ -2550,18 +2569,18 @@ with demo:
2550
  2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
2551
  """)
2552
  with gr.Column(scale=5, min_width=200):
2553
- prompt_radio = gr.Radio(["Tree [+Image]", "Image"], label="Where to click on?", value="Tree [+Image]", elem_id="prompt_radio", show_label=True)
2554
  granularity_slider = gr.Slider(1, 1000, step=1, label="Cluster Granularity (k)", value=100, elem_id="granularity")
2555
  num_sample_fps_slider.change(fn=lambda x: gr.update(maximum=x, interactive=True), inputs=[num_sample_fps_slider], outputs=[granularity_slider])
2556
- def updaste_tsne_plot_change_granularity(granularity, tsne_embed, edges, fps_tsne_rgb, tsne_prompt_image):
2557
  # Plot the t-SNE points
2558
- pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne_rgb, granularity)
2559
  return gr.update(value=pil_image, label=f"spectral-tSNE tree [k={granularity}]")
2560
  granularity_slider.change(updaste_tsne_plot_change_granularity,
2561
- inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb, tsne_prompt_image],
2562
  outputs=[tsne_non_prompt_image])
2563
  tsne_plot.change(updaste_tsne_plot_change_granularity,
2564
- inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
2565
  outputs=[tsne_non_prompt_image])
2566
  prompt_radio.change(update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2567
  # prompt_radio.change(updaste_tsne_plot_change_granularity,
@@ -2576,7 +2595,9 @@ with demo:
2576
  tsne_image_select.visible = True
2577
  tsne_select.visible = False
2578
  image_select.visible = False
2579
- prompt_radio.change(fn=lambda x: gr.update(visible=x=="Image"), inputs=prompt_radio, outputs=[image_select])
 
 
2580
  prompt_radio.change(fn=lambda x: gr.update(visible=x=="Tree [+Image]"), inputs=prompt_radio, outputs=[tsne_image_select])
2581
 
2582
  MAX_ROWS = 20
@@ -2595,6 +2616,7 @@ with demo:
2595
  # output_row_occupy[i_row-1] = False
2596
  return output_row_occupy, gr.update(visible=False)
2597
  delete_button.click(partial(delete_a_row, i_row=i_row), output_row_occupy, outputs=[output_row_occupy, inspect_output_row])
 
2598
  return inspect_output_row, output_tree_image, heatmap_gallery, text_block
2599
 
2600
  gr.Markdown('---')
@@ -2662,21 +2684,23 @@ with demo:
2662
  closest_idx = np.argmax(sim)
2663
  return closest_idx, (_x_ratio, _y_ratio)
2664
 
2665
- def find_closest_fps_point(prompt_radio, tsne_image_prompt, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
2666
  try:
2667
  if prompt_radio == "Tree":
2668
  return find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed)
2669
- if prompt_radio == "Image":
2670
  return find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs)
 
 
2671
  if prompt_radio == "Tree [+Image]":
2672
  return find_closest_fps_point_for_tsne_tree_plot(tsne_image_prompt, tsne2d_embed)
2673
  except:
2674
  raise gr.Error("""No blue point is selected. <br/>Please left-click on the image to select a blue point. <br/>After reloading the image (e.g., change granularity), please use the eraser to remove the previous point, then click on the image to select a blue point.""")
2675
 
2676
- def run_inspection(tsne_image_prompt, tsne_prompt, image_prompt, prompt_radio, current_output_row, tsne2d_embed, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity, eigvecs, i_image, tsne3d_rgb, input_gallery, output_row_occupy, max_rows=MAX_ROWS):
2677
  if len(tsne2d_embed) == 0:
2678
  raise gr.Error("Please run FPS+Cluster first.")
2679
- closest_idx, (_x, _y) = find_closest_fps_point(prompt_radio, tsne_image_prompt, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs)
2680
  closest_rgb = fps_tsne_rgb[closest_idx]
2681
  closest_rgb = (closest_rgb * 255).astype(np.uint8)
2682
 
@@ -2686,7 +2710,7 @@ with demo:
2686
  logging_text = f"Clicked: idx={closest_idx}, xy=[{_x:.2f}, {_y:.2f}], RGB={closest_rgb}"
2687
  logging_text += f"\nGranularity: k={granularity}, Connected: n={len(connected_idxs)}"
2688
 
2689
- output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
2690
 
2691
  # draw heatmap for the connected components
2692
  ## cosine distance
@@ -2758,7 +2782,7 @@ with demo:
2758
 
2759
  run_inspection_button.click(
2760
  run_inspection,
2761
- inputs=[tsne_image_plot, tsne_prompt_image, image_plot, prompt_radio, current_output_row, tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity_slider, eigvecs, image_slider, tsne3d_rgb, input_gallery, output_row_occupy],
2762
  outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, output_row_occupy, inspect_logging_text],
2763
  )
2764
 
@@ -2767,7 +2791,7 @@ with demo:
2767
  with gr.Row():
2768
  with gr.Column(scale=5, min_width=200):
2769
  gr.Markdown("### Step 1: Load Images")
2770
- input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
2771
  submit_button.visible = False
2772
  num_images_slider.value = 30
2773
 
@@ -2877,7 +2901,10 @@ with demo:
2877
  right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
2878
  y, x = None, None
2879
  else:
2880
- right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
 
 
 
2881
  right = right[:n_eig]
2882
  left = F.normalize(left, p=2, dim=-1)
2883
  _right = F.normalize(right, p=2, dim=-1)
 
2303
 
2304
  tsne_2d_points = gr.State(np.array([]))
2305
  edges = gr.State(np.array([]))
2306
+ levels = gr.State(np.array([]))
2307
  fps_eigvecs = gr.State(np.array([]))
2308
  fps_indices = gr.State(np.array([]))
2309
  fps_tsne_rgb = gr.State(np.array([]))
2310
 
2311
+ def plot_tsne_tree(tsne_embed, edges, levels, fps_tsne3d_rgb, k, hightlight_idx=None, highlight_connections=False):
2312
  # Plot the t-SNE points
2313
  fig, ax = plt.subplots(1, 1, figsize=(6, 6))
2314
  ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne3d_rgb)
 
2317
  max_length = lengthes[k:].max()
2318
  diag_length = np.linalg.norm(tsne_embed.max(axis=0) - tsne_embed.min(axis=0))
2319
  # draw the edges
2320
+ # for i_edge in range(k, len(edges)):
2321
+ # edge = edges[i_edge]
2322
+ # # _do = np.clip(lengthes[i_edge] / (diag_length*0.3), 0, 1)
2323
+ # if lengthes[i_edge] > diag_length*0.1:
2324
+ # _do = 1.0
2325
+ # else:
2326
+ # _do = 0.0
2327
+ # alpha = 0.7 * (1 - _do) + 0.0
2328
+ # ax.plot(tsne_embed[edge, 0], tsne_embed[edge, 1], 'k-', lw=1, alpha=alpha)
2329
+ max_level = levels.max()
2330
+ line_widths = levels / max_level * 3
2331
+ # flip line width
2332
+ line_widths = 3 - line_widths
2333
  for i_edge in range(k, len(edges)):
2334
  edge = edges[i_edge]
2335
+ level = levels[i_edge]
2336
+ plt.plot(tsne_embed[edge, 0], tsne_embed[edge, 1], c='gray', lw=line_widths[level], alpha=0.5)
2337
+
 
 
 
 
2338
  # highlight the selected node
2339
  if hightlight_idx is not None:
2340
  if highlight_connections:
 
2487
  tsne_embed[:, 1] = (tsne_embed[:, 1] - tsne_embed[:, 1].min()) / (tsne_embed[:, 1].max() - tsne_embed[:, 1].min()) * 2 - 1
2488
 
2489
  if tree_method == 'eigvecs':
2490
+ edges, levels = build_tree(fps_eigvecs, dist='cosine')
2491
  if tree_method == 'tsne':
2492
+ edges, levels = build_tree(tsne_embed, dist='euclidean')
2493
 
2494
  # Plot the t-SNE points
2495
+ pil_image = plot_tsne_tree(tsne_embed, edges, levels, fps_tsne3d_rgb, 0)
2496
 
2497
  # Plot the t-SNE points with image heatmaps
2498
  big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)
2499
 
2500
+ return tsne_embed, edges, levels, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image, gr.update(value={'image': big_pil_image, 'points': []}, interactive=True)
2501
 
2502
  gr.Markdown('---')
2503
  gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
 
2508
  run_hierarchical_button.click(
2509
  run_fps_tsne_hierarchical,
2510
  inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider, tree_method_radio],
2511
+ outputs=[tsne_2d_points, edges, levels, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, tsne_image_plot],
2512
  )
2513
  with gr.Row():
2514
  with gr.Column(scale=5, min_width=200) as tsne_select:
 
2536
  image_slider.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2537
  output_gallery.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2538
  output_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1, interactive=True), inputs=[output_gallery], outputs=[image_slider])
2539
+ with gr.Column(scale=5, min_width=200) as orig_image_select:
2540
+ gr.Markdown('---')
2541
+ gr.Markdown('<h3 style="text-align: center;">Please click on the image blow ↓</h3>')
2542
+ gr.Markdown('---')
2543
+ image_plot2 = ImagePrompter(show_label=True, elem_id="image_plot", interactive=False, label="NCUT spectral-tSNE")
2544
+ image_slider2 = gr.Slider(0, 100, step=1, label="Image Index", value=0, elem_id="image_slider", interactive=True)
2545
+ image_slider2.change(fn=update_image_prompt, inputs=[image_slider2, input_gallery], outputs=[image_plot2])
2546
+ output_gallery.change(fn=update_image_prompt, inputs=[image_slider2, input_gallery], outputs=[image_plot2])
2547
+ output_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1, interactive=True), inputs=[output_gallery], outputs=[image_slider2])
2548
  with gr.Column(scale=5, min_width=200) as tsne_image_select:
2549
  gr.Markdown('---')
2550
  gr.Markdown('<h3 style="text-align: center;">Please click on the image above ↑</h3>')
 
2569
  2. Cluster Heatmap: max of N cosine similarity to N points in the connected components.
2570
  """)
2571
  with gr.Column(scale=5, min_width=200):
2572
+ prompt_radio = gr.Radio(["Tree [+Image]", "Image (NCUT)", "Image (Orig)"], label="Where to click on?", value="Tree [+Image]", elem_id="prompt_radio", show_label=True)
2573
  granularity_slider = gr.Slider(1, 1000, step=1, label="Cluster Granularity (k)", value=100, elem_id="granularity")
2574
  num_sample_fps_slider.change(fn=lambda x: gr.update(maximum=x, interactive=True), inputs=[num_sample_fps_slider], outputs=[granularity_slider])
2575
+ def updaste_tsne_plot_change_granularity(granularity, tsne_embed, edges, levels, fps_tsne_rgb, tsne_prompt_image):
2576
  # Plot the t-SNE points
2577
+ pil_image = plot_tsne_tree(tsne_embed, edges, levels, fps_tsne_rgb, granularity)
2578
  return gr.update(value=pil_image, label=f"spectral-tSNE tree [k={granularity}]")
2579
  granularity_slider.change(updaste_tsne_plot_change_granularity,
2580
+ inputs=[granularity_slider, tsne_2d_points, edges, levels, fps_tsne_rgb, tsne_prompt_image],
2581
  outputs=[tsne_non_prompt_image])
2582
  tsne_plot.change(updaste_tsne_plot_change_granularity,
2583
+ inputs=[granularity_slider, tsne_2d_points, edges, levels, fps_tsne_rgb],
2584
  outputs=[tsne_non_prompt_image])
2585
  prompt_radio.change(update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
2586
  # prompt_radio.change(updaste_tsne_plot_change_granularity,
 
2595
  tsne_image_select.visible = True
2596
  tsne_select.visible = False
2597
  image_select.visible = False
2598
+ orig_image_select.visible = False
2599
+ prompt_radio.change(fn=lambda x: gr.update(visible=x=="Image (NCUT)"), inputs=prompt_radio, outputs=[image_select])
2600
+ prompt_radio.change(fn=lambda x: gr.update(visible=x=="Image (Orig)"), inputs=prompt_radio, outputs=[orig_image_select])
2601
  prompt_radio.change(fn=lambda x: gr.update(visible=x=="Tree [+Image]"), inputs=prompt_radio, outputs=[tsne_image_select])
2602
 
2603
  MAX_ROWS = 20
 
2616
  # output_row_occupy[i_row-1] = False
2617
  return output_row_occupy, gr.update(visible=False)
2618
  delete_button.click(partial(delete_a_row, i_row=i_row), output_row_occupy, outputs=[output_row_occupy, inspect_output_row])
2619
+ delete_button.visible = False
2620
  return inspect_output_row, output_tree_image, heatmap_gallery, text_block
2621
 
2622
  gr.Markdown('---')
 
2684
  closest_idx = np.argmax(sim)
2685
  return closest_idx, (_x_ratio, _y_ratio)
2686
 
2687
+ def find_closest_fps_point(prompt_radio, tsne_image_prompt, tsne_prompt, image_prompt, image_prompt2, i_image, i_image2, tsne2d_embed, eigvecs, fps_eigvecs):
2688
  try:
2689
  if prompt_radio == "Tree":
2690
  return find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed)
2691
+ if prompt_radio == "Image (NCUT)":
2692
  return find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs)
2693
+ if prompt_radio == "Image (Orig)":
2694
+ return find_closest_fps_point_for_image_prompt(image_prompt2, i_image2, eigvecs, fps_eigvecs)
2695
  if prompt_radio == "Tree [+Image]":
2696
  return find_closest_fps_point_for_tsne_tree_plot(tsne_image_prompt, tsne2d_embed)
2697
  except:
2698
  raise gr.Error("""No blue point is selected. <br/>Please left-click on the image to select a blue point. <br/>After reloading the image (e.g., change granularity), please use the eraser to remove the previous point, then click on the image to select a blue point.""")
2699
 
2700
+ def run_inspection(tsne_image_prompt, tsne_prompt, image_prompt, image_prompt2, prompt_radio, current_output_row, tsne2d_embed, edges, levels, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity, eigvecs, i_image, i_image2, tsne3d_rgb, input_gallery, output_row_occupy, max_rows=MAX_ROWS):
2701
  if len(tsne2d_embed) == 0:
2702
  raise gr.Error("Please run FPS+Cluster first.")
2703
+ closest_idx, (_x, _y) = find_closest_fps_point(prompt_radio, tsne_image_prompt, tsne_prompt, image_prompt, image_prompt2, i_image, i_image2, tsne2d_embed, eigvecs, fps_eigvecs)
2704
  closest_rgb = fps_tsne_rgb[closest_idx]
2705
  closest_rgb = (closest_rgb * 255).astype(np.uint8)
2706
 
 
2710
  logging_text = f"Clicked: idx={closest_idx}, xy=[{_x:.2f}, {_y:.2f}], RGB={closest_rgb}"
2711
  logging_text += f"\nGranularity: k={granularity}, Connected: n={len(connected_idxs)}"
2712
 
2713
+ output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, levels, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
2714
 
2715
  # draw heatmap for the connected components
2716
  ## cosine distance
 
2782
 
2783
  run_inspection_button.click(
2784
  run_inspection,
2785
+ inputs=[tsne_image_plot, tsne_prompt_image, image_plot, image_plot2, prompt_radio, current_output_row, tsne_2d_points, edges, levels, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity_slider, eigvecs, image_slider, image_slider2, tsne3d_rgb, input_gallery, output_row_occupy],
2786
  outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, output_row_occupy, inspect_logging_text],
2787
  )
2788
 
 
2791
  with gr.Row():
2792
  with gr.Column(scale=5, min_width=200):
2793
  gr.Markdown("### Step 1: Load Images")
2794
+ input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=30)
2795
  submit_button.visible = False
2796
  num_images_slider.value = 30
2797
 
 
2901
  right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx]
2902
  y, x = None, None
2903
  else:
2904
+ try:
2905
+ right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs)
2906
+ except:
2907
+ raise gr.Error("""No blue point is selected. <br/>Please left-click on the image to select a blue point. <br/>After reloading the image, please use the eraser (top-right) to remove the previous point, then click on the image to select a blue point.""")
2908
  right = right[:n_eig]
2909
  left = F.normalize(left, p=2, dim=-1)
2910
  _right = F.normalize(right, p=2, dim=-1)