huzey committed
Commit 55debdb · 1 Parent(s): 02c9399

add inspect playground

Files changed (2)
  1. app.py +425 -19
  2. fps_cluster.py +78 -0
app.py CHANGED
@@ -2,6 +2,7 @@
 # %%
 import copy
 from datetime import datetime
+import io
 import math
 import pickle
 from functools import partial
 
@@ -168,8 +169,6 @@ def compute_ncut(
     logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
 
     if only_eigvecs:
-        eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
-        eigvecs = eigvecs.detach().numpy()
         return None, logging_str, eigvecs
 
     start = time.time()
 
@@ -285,10 +284,15 @@ def to_pil_images(images, target_size=512, resize=True, force_size=False):
     res = int(size * multiplier)
     if force_size:
         res = target_size
-    pil_images = [
-        Image.fromarray((image * 255).cpu().numpy().astype(np.uint8))
-        for image in images
-    ]
+
+    pil_images = []
+    for image in images:
+        if isinstance(image, torch.Tensor):
+            image = image.cpu().numpy()
+        if image.dtype == np.float32 or image.dtype == np.float64:
+            image = (image * 255).astype(np.uint8)
+        pil_images.append(Image.fromarray(image))
+
     if resize:
         pil_images = [
             image.resize((res, res), Image.Resampling.NEAREST)
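For reference, the rewritten loop above is tolerant of mixed inputs, which the removed list comprehension was not (it assumed torch float tensors and would rescale or `.cpu()` everything). A minimal standalone sketch of the new behavior, with hypothetical images that are not part of the commit:

```python
import numpy as np
import torch
from PIL import Image

float_img = torch.rand(32, 32, 3)            # float tensor in [0, 1], e.g. model output
uint8_img = np.zeros((32, 32, 3), np.uint8)  # already display-ready

for image in [float_img, uint8_img]:
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()             # move to host memory first
    if image.dtype in (np.float32, np.float64):
        image = (image * 255).astype(np.uint8)  # rescale floats only
    pil = Image.fromarray(image)                # both paths yield a uint8 PIL image
```

The key design change: uint8 inputs pass through unscaled, so they are no longer multiplied by 255 a second time.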
 
@@ -865,6 +869,7 @@ def ncut_run(
     # alignedcut
     if not directed:
         only_eigvecs = kwargs.get("only_eigvecs", False)
+        return_eigvec_and_rgb = kwargs.get("return_eigvec_and_rgb", False)
 
         rgb, _logging_str, eigvecs = compute_ncut(
             features,
 
@@ -886,9 +891,20 @@
             only_eigvecs=only_eigvecs,
         )
 
+
         if only_eigvecs:
+            eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
+            eigvecs = eigvecs.detach().numpy()
+            logging_str += _logging_str
             return eigvecs, logging_str
 
+        if return_eigvec_and_rgb:
+            eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
+            eigvecs = eigvecs.detach().numpy()
+            rgb = rgb.cpu().numpy()
+            logging_str += _logging_str
+            return eigvecs, rgb, logging_str
+
     if directed:
         head_index_text = kwargs.get("head_index_text", None)
         n_heads = features.shape[-2]  # (batch, h, w, n_heads, d)
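The reshape in both new return paths, `features.shape[:-1] + (num_eig,)`, folds the flat per-pixel eigenvector rows back onto the image grid. A small illustration with assumed sizes (30 images, 32×32 latent grid, 768-d features), not taken from the commit:

```python
import torch

features = torch.zeros(30, 32, 32, 768)       # (batch, h, w, d) backbone features
num_eig = 100
eigvecs = torch.zeros(30 * 32 * 32, num_eig)  # NCUT output: one row per latent pixel
eigvecs = eigvecs.reshape(features.shape[:-1] + (num_eig,))
print(eigvecs.shape)  # torch.Size([30, 32, 32, 100]) -> indexable as eigvecs[i_image, y, x]
```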
 
@@ -1232,6 +1248,7 @@ def run_fn(
     advanced=False,
     directed=False,
     only_eigvecs=False,
+    return_eigvec_and_rgb=False,
 ):
     # print(node_type2, head_index_text, make_symmetric)
     progress=gr.Progress()
 
@@ -1373,6 +1390,7 @@ def run_fn(
         "head_index_text": head_index_text,
         "make_symmetric": make_symmetric,
         "only_eigvecs": only_eigvecs,
+        "return_eigvec_and_rgb": return_eigvec_and_rgb,
     }
     # print(kwargs)
 
@@ -1599,7 +1617,7 @@ def load_dataset_images(is_advanced, dataset_name, num_images=10,
     is_advanced = "Basic"
 
     if is_advanced == "Basic":
-        gr.Info(f"Loaded images from EgoExo")
+        gr.Info(f"Loaded images from EgoExo", duration=5)
         return default_images
     try:
         progress(0.5, desc="Downloading Dataset")
 
@@ -1644,7 +1662,7 @@ def load_dataset_images(is_advanced, dataset_name, num_images=10,
             image_idx.extend(idx.tolist())
     if not is_filter:
         if is_random:
-            if num_images < len(dataset):
+            if num_images <= len(dataset):
                 image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
             else:
                 gr.Warning(f"Dataset has less than {num_images} images.")
 
@@ -1653,7 +1671,7 @@ def load_dataset_images(is_advanced, dataset_name, num_images=10,
         image_idx = list(range(num_images))
     key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
     images = [dataset[i][key] for i in image_idx]
-    gr.Info(f"Loaded {len(images)} images from {dataset_name}")
+    gr.Info(f"Loaded {len(images)} images from {dataset_name}", duration=5)
     del dataset
 
     if dataset_name in CENTER_CROP_DATASETS:
 
@@ -2060,7 +2078,7 @@ def make_output_images_section(markdown=True, button=True):
         add_rotate_flip_buttons(output_gallery)
     return output_gallery
 
-def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=True):
+def make_parameters_section(is_lisa=False, model_ratio=True, ncut_parameter_dropdown=True, tsne_parameter_dropdown=True):
     gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
     from ncut_pytorch.backbone import list_models, get_demo_model_names
     model_names = list_models()
 
@@ -2089,7 +2107,7 @@ def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=
     positive_prompt.visible = False
     negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
     negative_prompt.visible = False
-    node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
+    node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type")
     num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for smaller clusters')
 
     def change_layer_slider(model_name):
 
@@ -2125,7 +2143,7 @@ def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=
                 gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
     model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
 
-    with gr.Accordion("Advanced Parameters: NCUT", open=False, visible=parameter_dropdown):
+    with gr.Accordion("Advanced Parameters: NCUT", open=False, visible=ncut_parameter_dropdown):
         gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
         affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for sharper segmentation")
         num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
 
@@ -2136,7 +2154,7 @@ def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=
         ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
         ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=True, elem_id="ncut_indirect_connection", info="Add indirect connection to the sub-sampled graph")
         ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
-    with gr.Accordion("Advanced Parameters: Visualization", open=False, visible=parameter_dropdown):
+    with gr.Accordion("Advanced Parameters: Visualization", open=False, visible=tsne_parameter_dropdown):
         # embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
         embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
         # embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
 
@@ -2168,8 +2186,396 @@ demo = gr.Blocks(
 )
 with demo:
 
+    with gr.Tab('Hierarchical (dev)'):
+        eigvecs = gr.State(np.array([]))
+        tsne3d_rgb = gr.State(np.array([]))
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                # gr.Markdown("### Step 1: Load Images")
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100, markdown=False)
+                submit_button.value = "🔴 RUN NCUT"
+                num_images_slider.value = 100
+
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+            with gr.Column(scale=5, min_width=200):
+                # gr.Markdown("### Step 2a: Run Backbone and NCUT")
+                # with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
+                output_gallery = gr.Gallery(format='png', value=[], label="NCUT spectral-tSNE", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+                def add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb):
+                    with gr.Row():
+                        rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
+                        rotate_button.click(sequence_rotate_rgb_gallery, inputs=[output_gallery], outputs=[output_gallery])
+                        def rotate_state(arr):
+                            rotation_matrix = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)
+                            return arr @ rotation_matrix
+                        rotate_button.click(rotate_state, inputs=[tsne3d_rgb], outputs=[tsne3d_rgb])
+                        flip_button = gr.Button("🔃 Flip", elem_id="flip_button", variant='secondary')
+                        flip_button.click(flip_rgb_gallery, inputs=[output_gallery], outputs=[output_gallery])
+                        def flip_state(arr):
+                            return 1 - arr
+                        flip_button.click(flip_state, inputs=[tsne3d_rgb], outputs=[tsne3d_rgb])
+                        return rotate_button, flip_button
+                add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb)
+
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section(ncut_parameter_dropdown=True, tsne_parameter_dropdown=True)
+                # submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
+                logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+
+                def __run_fn(*args, **kwargs):
+                    eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
+                    rgb_gallery = to_pil_images(rgb)
+                    return eigvecs, rgb, rgb_gallery, logging_str
+
+                submit_button.click(
+                    partial(__run_fn, n_ret=2, return_eigvec_and_rgb=True),
+                    inputs=[
+                        input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                        positive_prompt, negative_prompt,
+                        false_placeholder, no_prompt, no_prompt, no_prompt,
+                        affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                        embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                        perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
+                    ],
+                    outputs=[eigvecs, tsne3d_rgb, output_gallery, logging_text],
+                )
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('---')
+                gr.Markdown('<h3 style="text-align: center;">Help</h3>')
+                gr.Markdown('---')
+                with gr.Accordion("Instructions", open=True):
+                    gr.Markdown("""
+                    1. Load Dataset (left).
+                    2. Choose parameters (middle).
+                    3. 🔴 RUN NCUT.
+                    4. 🔴 RUN FPS+Cluster.
+                    5. Interact and Inspect (scroll down).
+                    """)
+                with gr.Accordion("Methods: NCUT spectral-TSNE", open=False):
+                    gr.Markdown("### <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_ncut_works/' target='_blank'>Documentation: How NCUT works</a>")
+                    gr.Markdown("""
+                    1. Run Backbone, feature extraction for each image.
+                    2. Vectorize latent-pixels, concatenate all the images.
+                    3. Run NCUT, on one big graph of all the images.
+                    4. Run spectral-tSNE on the NCUT eigenvectors.
+                    5. Plot the 3D spectral-tSNE as RGB.
+                    """)
+                with gr.Accordion("Methods: Hierarchical Structure", open=False):
+                    gr.Markdown("""
+                    1. Farthest Point Sampling (FPS) on the eigenvectors.
+                    2. spectral-tSNE (2D) on the FPS sampled points.
+                    3. Hierarchical clustering on the FPS sampled points.
+                    """)
+                gr.Markdown('---')
+
+                run_hierarchical_button = gr.Button("🔴 RUN FPS+Cluster", elem_id="run_hierarchical", variant='primary')
+                with gr.Accordion("Hierarchical Structure Parameters:", open=True):
+                    num_sample_fps_slider = gr.Slider(1, 5000, step=1, label="FPS: num_sample", value=1000, elem_id="num_sample_fps")
+                    tsne_perplexity_slider = gr.Slider(1, 1000, step=1, label="t-SNE: perplexity", value=500, elem_id="perplexity_tsne")
+                    fps_hc_seed_slider = gr.Slider(0, 1000, step=1, label="Seed", value=0, elem_id="fps_hc_seed")
+                tsne_plot = gr.Image(label="spectral-tSNE tree", elem_id="tsne_plot", interactive=False, format='png')
+
+                tsne_2d_points = gr.State(np.array([]))
+                edges = gr.State(np.array([]))
+                fps_eigvecs = gr.State(np.array([]))
+                fps_indices = gr.State(np.array([]))
+                fps_tsne_rgb = gr.State(np.array([]))
+
+                def plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, k, highlight_idx=None, highlight_connections=False):
+                    # Plot the t-SNE points
+                    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
+                    ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne3d_rgb)
+                    # draw the edges
+                    for i_edge in range(k, len(edges)):
+                        edge = edges[i_edge]
+                        ax.plot(tsne_embed[edge, 0], tsne_embed[edge, 1], 'k-', lw=1, alpha=0.7)
+                    # highlight the selected node
+                    if highlight_idx is not None:
+                        if highlight_connections:
+                            from fps_cluster import find_connected_component
+                            _edges = edges[k:, :]
+                            connected_nodes = find_connected_component(_edges, highlight_idx)
+                            ax.scatter(tsne_embed[connected_nodes, 0], tsne_embed[connected_nodes, 1], s=50, c=fps_tsne3d_rgb[connected_nodes], marker='D', edgecolor='deeppink', linewidth=1)
+                        # ax.scatter(tsne_embed[highlight_idx, 0], tsne_embed[highlight_idx, 1], s=300, c='r', marker='x')
+                        ax.scatter(tsne_embed[highlight_idx, 0], tsne_embed[highlight_idx, 1], s=200, c='cyan', marker='o', edgecolor='black', linewidth=1)
+                    ax.set_xticks([])
+                    ax.set_yticks([])
+                    ax.axis('off')
+                    ax.set_xlim(tsne_embed[:, 0].min()*1.1, tsne_embed[:, 0].max()*1.1)
+                    ax.set_ylim(tsne_embed[:, 1].min()*1.1, tsne_embed[:, 1].max()*1.1)
+
+                    # Remove the white space around the plot
+                    fig.tight_layout(pad=0)
+
+                    # Save the plot to an in-memory buffer
+                    buf = io.BytesIO()
+                    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
+                    buf.seek(0)
+
+                    # Load the image into a NumPy array
+                    image = np.array(Image.open(buf))
+
+                    # Close the buffer and plot
+                    buf.close()
+                    plt.close(fig)
+
+                    pil_image = Image.fromarray(image)
+                    return pil_image
+
+                def run_fps_tsne_hierarchical(eigvecs, num_sample_fps, perplexity_tsne, tsne3d_rgb, seed=0):
+                    if len(eigvecs) == 0:
+                        gr.Warning("Please run NCUT first.")
+                        return
+                    eigvecs = torch.tensor(eigvecs)
+                    eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
+                    gr.Info("Running FPS, t-SNE, and Hierarchical Clustering...", 3)
+                    from ncut_pytorch.ncut_pytorch import farthest_point_sampling
+                    from sklearn.manifold import TSNE
+                    from fps_cluster import build_tree
+
+                    torch.manual_seed(seed)
+                    np.random.seed(seed)
+
+                    fps_idx = farthest_point_sampling(eigvecs, num_sample_fps)
+                    fps_eigvecs = eigvecs[fps_idx]
+                    fps_eigvecs = fps_eigvecs.numpy()
+
+                    tsne3d_rgb = tsne3d_rgb.reshape(-1, 3)
+                    fps_tsne3d_rgb = tsne3d_rgb[fps_idx]
+
+                    np.random.seed(seed)
+                    tsne_embed = TSNE(
+                        n_components=2,
+                        perplexity=perplexity_tsne,
+                        metric='cosine',
+                        random_state=seed,
+                    ).fit_transform(fps_eigvecs)
+
+                    edges = build_tree(tsne_embed)
+
+                    # Plot the t-SNE points
+                    pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
+
+                    return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image
+
+                run_hierarchical_button.click(
+                    run_fps_tsne_hierarchical,
+                    inputs=[eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
+                    outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot],
+                )
+        gr.Markdown('---')
+        gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
+        gr.Markdown('---')
+        with gr.Row():
+            from gradio_image_prompter import ImagePrompter
+            with gr.Column(scale=5, min_width=200) as tsne_select:
+                tsne_prompt_image = ImagePrompter(show_label=True, elem_id="tsne_prompt_image", interactive=False, label="spectral-tSNE tree")
+                # copy plot to tsne_prompt_image on change
+                # tsne_plot.change(fn=lambda x: gr.update(value={'image': x}, interactive=True),
+                #                  inputs=[tsne_plot], outputs=[tsne_prompt_image])
+            with gr.Column(scale=5, min_width=200) as image_select:
+                image_plot = ImagePrompter(show_label=True, elem_id="image_plot", interactive=False, label="NCUT spectral-tSNE")
+                image_slider = gr.Slider(0, 100, step=1, label="Image Index", value=0, elem_id="image_slider", interactive=True)
+                def update_image_prompt(image_slider, output_gallery):
+                    if len(output_gallery) == 0:
+                        return gr.update(value=None, interactive=False)
+                    image_idx = int(image_slider)
+                    image = output_gallery[image_idx][0]
+                    return gr.update(value={'image': image}, interactive=True)
+                image_slider.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
+                output_gallery.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
+                output_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1, interactive=True), inputs=[output_gallery], outputs=[image_slider])
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('<h3 style="text-align: center;">Help</h3>')
+                with gr.Accordion("Instructions", open=True):
+                    gr.Markdown("""
+                    1. Click one dot on the left-side image.
+                        - Only the last clicked dot will be used
+                        - Eraser is at top-right corner
+                        - Use the right-side Radio to switch tree/image
+                    2. Choose a granularity (right-side).
+                    3. 🔴 RUN Inspection.
+                    4. Output will be shown below.
+                    """)
+                with gr.Accordion("Outputs", open=True):
+                    gr.Markdown("""
+                    1. spectral-tSNE tree: ◆ marks the connected component of the selected point.
+                    2. Cluster Heatmap: max cosine similarity to any point in the connected component.
+                    """)
+            with gr.Column(scale=5, min_width=200):
+                prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
+                granularity_slider = gr.Slider(1, 1000, step=1, label="Cluster Granularity", value=100, elem_id="granularity")
+                num_sample_fps_slider.change(fn=lambda x: gr.update(maximum=x, interactive=True), inputs=[num_sample_fps_slider], outputs=[granularity_slider])
+                def update_tsne_plot_change_granularity(granularity, tsne_embed, edges, fps_tsne_rgb, tsne_prompt_image=None):
+                    # Plot the t-SNE points
+                    pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne_rgb, granularity)
+                    if tsne_prompt_image is None:
+                        return gr.update(value={'image': pil_image}, interactive=True)
+                    return gr.update(value={'image': pil_image, 'points': tsne_prompt_image['points']}, interactive=True)
+                granularity_slider.change(update_tsne_plot_change_granularity,
+                                          inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb, tsne_prompt_image],
+                                          outputs=[tsne_prompt_image])
+                tsne_plot.change(update_tsne_plot_change_granularity,
+                                 inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
+                                 outputs=[tsne_prompt_image])
+                run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
+                inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+                # output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
+
+        image_select.visible = False
+        tsne_select.visible = True
+        prompt_radio.change(fn=lambda x: gr.update(visible=x=="Tree"), inputs=prompt_radio, outputs=[tsne_select])
+        prompt_radio.change(fn=lambda x: gr.update(visible=x=="Image"), inputs=prompt_radio, outputs=[image_select])
+
+        def make_one_output_row(i_row=1):
+            with gr.Row() as inspect_output_row:
+                with gr.Column(scale=5, min_width=200):
+                    output_tree_image = gr.Image(label=f"spectral-tSNE tree [row#{i_row}]", elem_id="output_image", interactive=False)
+                    text_block = gr.Textbox("", label="Logging", elem_id=f"logging_{i_row}", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=2, show_label=False)
+                with gr.Column(scale=10, min_width=200):
+                    heatmap_gallery = gr.Gallery(format='png', value=[], label=f"Cluster Heatmap [row#{i_row}]", show_label=True, elem_id="heatmap", columns=[6], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+            return inspect_output_row, output_tree_image, heatmap_gallery, text_block
+
+        gr.Markdown('---')
+        MAX_ROWS = 100
+        current_output_row = gr.State(MAX_ROWS-1)
+        inspect_output_rows, output_tree_images, heatmap_galleries, text_blocks = [], [], [], []
+        for i_row in range(MAX_ROWS, 0, -1):
+            inspect_output_row, output_tree_image, heatmap_gallery, text_block = make_one_output_row(i_row)
+            inspect_output_row.visible = False
+            inspect_output_rows.append(inspect_output_row)
+            output_tree_images.append(output_tree_image)
+            heatmap_galleries.append(heatmap_gallery)
+            text_blocks.append(text_block)
+
+
+        def relative_xy_last_positive(prompts):
+            image = prompts['image']
+            points = np.asarray(prompts['points'])
+            if points.shape[0] == 0:
+                return [], []
+            is_point = points[:, 5] == 4.0
+            points = points[is_point]
+            is_positive = points[:, 2] == 1.0
+            if is_positive.sum() == 0:
+                raise Exception("No blue point is selected.")
+            is_negative = points[:, 2] == 0.0
+            xy = points[:, :2].tolist()
+            if isinstance(image, str):
+                image = Image.open(image)
+                image = np.array(image)
+            h, w = image.shape[:2]
+            new_xy = [(x/w, y/h) for x, y in xy]
+
+            last_positive_idx = np.where(is_positive)[0][-1]
+            x, y = new_xy[last_positive_idx]
+            return x, y
+
+        def find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed):
+            x, y = relative_xy_last_positive(tsne_prompt)
+            x_vmax = tsne2d_embed[:, 0].max() * 1.1
+            x_vmin = tsne2d_embed[:, 0].min() * 1.1
+            y_vmax = tsne2d_embed[:, 1].max() * 1.1
+            y_vmin = tsne2d_embed[:, 1].min() * 1.1
+            x = x * (x_vmax - x_vmin) + x_vmin
+            y = 1 - y
+            y = y * (y_vmax - y_vmin) + y_vmin
+            dist = np.linalg.norm(tsne2d_embed - np.array([x, y]), axis=1)
+            closest_idx = np.argmin(dist)
+            return closest_idx
+
+        def find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs):
+            x, y = relative_xy_last_positive(image_prompt)
+            _eigvec = eigvecs[i_image]
+            h, w = _eigvec.shape[:2]
+            x = int(x * w)
+            y = int(y * h)
+            eigvec = _eigvec[y, x]
+            dist = np.linalg.norm(fps_eigvecs - eigvec, axis=1)
+            closest_idx = np.argmin(dist)
+            return closest_idx
+
+        def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
+            try:
+                if prompt_radio == "Tree":
+                    return find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed)
+                if prompt_radio == "Image":
+                    return find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs)
+            except:
+                raise gr.Error("""No blue point is selected. <br/>Please left-click on the image to select a blue point. <br/>After reloading the image (e.g., change granularity), please use the eraser to remove the previous point, then click on the image to select a blue point.""")
+
+        def run_inspection(tsne_prompt, image_prompt, prompt_radio, output_slot, tsne2d_embed, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity, eigvecs, i_image, tsne3d_rgb, input_gallery, max_rows=MAX_ROWS):
+            if len(tsne2d_embed) == 0:
+                raise gr.Error("Please run FPS+Cluster first.")
+            closest_idx = find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs)
+            closest_rgb = fps_tsne_rgb[closest_idx]
+            closest_rgb = (closest_rgb * 255).astype(np.uint8)
+
+            from fps_cluster import find_connected_component
+            connected_idxs = find_connected_component(edges[granularity:], closest_idx)
+
+            logging_text = f"Clicked: idx={closest_idx}, RGB: {closest_rgb.tolist()}\n"
+            logging_text += f"Granularity: k={granularity}, Connected: n={len(connected_idxs)}"
+
+            output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
+
+            # draw heatmap for the connected components
+            connected_eigvecs = fps_eigvecs[connected_idxs]
+            left = torch.tensor(eigvecs).float()  # B H W num_eig
+            right = torch.tensor(connected_eigvecs).float()
+            left = F.normalize(left, p=2, dim=-1)
+            right = F.normalize(right, p=2, dim=-1)
+            similarity = left @ right.T
+            similarity = similarity.max(dim=-1).values  # B H W
+            hot_map = matplotlib.cm.get_cmap('hot')
+            heatmap = hot_map(similarity)[..., :3]  # B H W 3
+            heatmap_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
+            # overlay input images on the heatmap
+            input_images = [x[0] for x in input_gallery]
+            if isinstance(input_images[0], str):
+                input_images = [Image.open(x) for x in input_images]
+            for i, img in enumerate(input_images):
+                _img = img.resize((256, 256)).convert('RGB')
+                _heatmap = heatmap_images[i].resize((256, 256)).convert('RGB')
+                blend = np.array(_img) * 0.5 + np.array(_heatmap) * 0.5
+                blend = Image.fromarray(blend.astype(np.uint8))
+                heatmap_images[i] = blend
+
+            # tree_label = f"spectral-tSNE tree [row#{max_rows-output_slot}] k={granularity} idx={closest_idx} n={len(connected_idxs)}"
+            tree_label = f"spectral-tSNE tree [row#{max_rows-output_slot}]"
+            heatmap_label = f"Cluster Heatmap [row#{max_rows-output_slot}] k={granularity} idx={closest_idx} n={len(connected_idxs)}"
+            # update the output slots
+            output_rows = [gr.update() for _ in range(max_rows)]
+            output_tsne_plots = [gr.update() for _ in range(max_rows)]
+            output_heatmaps = [gr.update() for _ in range(max_rows)]
+            output_texts = [gr.update() for _ in range(max_rows)]
+            output_rows[output_slot] = gr.update(visible=True)
+            output_tsne_plots[output_slot] = gr.update(value=output_tsne_plot, label=tree_label)
+            output_heatmaps[output_slot] = gr.update(value=heatmap_images, label=heatmap_label)
+            output_texts[output_slot] = gr.update(value=logging_text)
+            gr.Info(f"Output in [row#{max_rows-output_slot}]", 3)
+            logging_text += f"\nOutput: [row#{max_rows-output_slot}]"
+            output_slot -= 1
+            if output_slot < 0:
+                output_slot = max_rows - 1
+
+            return *output_rows, *output_tsne_plots, *output_heatmaps, *output_texts, output_slot, logging_text
+
+
+        run_inspection_button.click(
+            run_inspection,
+            inputs=[tsne_prompt_image, image_plot, prompt_radio, current_output_row, tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity_slider, eigvecs, image_slider, tsne3d_rgb, input_gallery],
+            outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, inspect_logging_text],
+        )
+
     with gr.Tab('AlignedCut'):
 
         with gr.Row():
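Outside Gradio, the "RUN FPS+Cluster" step above boils down to FPS subsampling, 2D t-SNE, and `build_tree`. A hedged sketch on synthetic data: `naive_fps` is a stand-in for ncut_pytorch's `farthest_point_sampling`, the eigenvectors are random, and `fps_cluster` is the module added by this commit:

```python
import numpy as np
from sklearn.manifold import TSNE
from fps_cluster import build_tree  # added in this commit

def naive_fps(x, num_sample):
    # greedy farthest point sampling (stand-in implementation, O(n * num_sample))
    idx = [0]
    dist = np.linalg.norm(x - x[0], axis=1)
    for _ in range(num_sample - 1):
        idx.append(int(dist.argmax()))
        dist = np.minimum(dist, np.linalg.norm(x - x[idx[-1]], axis=1))
    return np.array(idx)

eigvecs = np.random.RandomState(0).randn(5000, 100).astype(np.float32)  # stand-in eigenvectors
fps_idx = naive_fps(eigvecs, 300)
tsne_embed = TSNE(n_components=2, perplexity=100, metric='cosine',
                  random_state=0).fit_transform(eigvecs[fps_idx])
edges = build_tree(tsne_embed)  # one (node, parent) edge per FPS point
# dropping the first k edges (the "granularity" slider) cuts the tree into finer clusters
```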
 
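Similarly, the cluster heatmap inside `run_inspection` is just the max cosine similarity between every latent pixel and the eigenvectors of the selected connected component; a standalone sketch with assumed shapes (not taken from the commit):

```python
import torch
import torch.nn.functional as F

eigvecs = torch.randn(30, 32, 32, 100)    # (B, H, W, num_eig), assumed sizes
connected_eigvecs = torch.randn(12, 100)  # eigenvectors of the clicked component

left = F.normalize(eigvecs, p=2, dim=-1)  # unit-normalize per latent pixel
right = F.normalize(connected_eigvecs, p=2, dim=-1)
similarity = (left @ right.T).max(dim=-1).values  # (B, H, W): best match per pixel
print(similarity.shape)  # torch.Size([30, 32, 32]), then mapped through the 'hot' colormap
```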
@@ -3448,7 +3854,7 @@ with demo:
     with gr.Row():
         with gr.Column(scale=5, min_width=200):
             gr.Markdown("### Step 1: Load Images and Run NCUT")
-            input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
+            input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100)
             # submit_button.visible = False
             num_images_slider.value = 30
             [
 
@@ -3457,7 +3863,7 @@ with demo:
             embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
             perplexity_slider, n_neighbors_slider, min_dist_slider,
             sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
-        ] = make_parameters_section(parameter_dropdown=False)
+        ] = make_parameters_section(ncut_parameter_dropdown=False)
         num_eig_slider.value = 1000
         num_eig_slider.visible = False
         logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
 
@@ -3585,7 +3991,7 @@ with demo:
         pil_images = overlaied_images
         return pil_images, (y, x)
 
-    def farthest_point_sampling(
+    def _farthest_point_sampling(
         features,
         start_feature,
         num_sample=300,
 
@@ -3635,7 +4041,7 @@ with demo:
         num_childs = min(4, masked_eigvecs.shape[0])
         assert num_childs > 0
 
-        child_idx = farthest_point_sampling(masked_eigvecs, _picked_eigvec, num_sample=num_childs+1)
+        child_idx = _farthest_point_sampling(masked_eigvecs, _picked_eigvec, num_sample=num_childs+1)
         child_idx = np.sort(child_idx)[:-1]
 
         # convert child_idx to flat_idx
 
@@ -3718,7 +4124,7 @@ with demo:
     with gr.Row():
         with gr.Column(scale=5, min_width=200):
             gr.Markdown("### Step 1: Load Images")
-            input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=10)
+            input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100)
             submit_button.visible = False
             num_images_slider.value = 30
 
@@ -3735,7 +4141,7 @@ with demo:
             embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
             perplexity_slider, n_neighbors_slider, min_dist_slider,
             sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
-        ] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False)
+        ] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False)
         num_eig_slider.value = 1024
         num_eig_slider.visible = False
         submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
fps_cluster.py ADDED
@@ -0,0 +1,78 @@
+# %%
+import numpy as np
+import torch
+
+
+def build_tree(all_dots):
+    num_sample = all_dots.shape[0]
+    center = all_dots.mean(axis=0)
+    distances_to_center = np.linalg.norm(all_dots - center, axis=1)
+    start_idx = np.argmin(distances_to_center)
+    indices = [start_idx]
+    distances = [114514,]  # arbitrary large sentinel so the root's "distance" dwarfs all real ones
+    A = all_dots[:, None] - all_dots[None, :]
+    A = (A ** 2).sum(-1)
+    A = np.sqrt(A)
+    A = torch.tensor(A)
+    for i in range(num_sample - 1):
+        _A = A[indices]
+        min_dist = _A.min(dim=0).values
+        next_idx = torch.argmax(min_dist).item()
+        distance = min_dist[next_idx].item()
+        indices.append(next_idx)
+        distances.append(distance)
+    indices = np.array(indices)
+    distances = np.array(distances)
+
+    levels = np.log2(distances[1] / distances)
+    levels = np.floor(levels).astype(int) + 1
+    levels[0] = 0
+
+    n_levels = levels.max() + 1
+    pi_indices = [indices[0],]
+    for i_level in range(1, n_levels):
+        current_level_indices = levels == i_level
+        prev_level_indices = levels < i_level
+        current_level_indices = indices[current_level_indices]
+        prev_level_indices = indices[prev_level_indices]
+        _A = A[prev_level_indices][:, current_level_indices]
+        _pi = _A.min(dim=0).indices
+        pi = prev_level_indices[_pi]
+        if isinstance(pi, np.int64) or isinstance(pi, int):
+            pi = [pi,]
+        if isinstance(pi, np.ndarray):
+            pi = pi.tolist()
+        pi_indices.extend(pi)
+    pi_indices = np.array(pi_indices)
+
+    edges = np.stack([indices, pi_indices], axis=1)
+    return edges
+
+
+def find_connected_component(edges, start_node):
+    # Dictionary to store adjacency list
+    adjacency_list = {}
+    for edge in edges:
+        # Unpack edge
+        a, b = edge
+        # Add the connection for both nodes
+        if a in adjacency_list:
+            adjacency_list[a].append(b)
+        else:
+            adjacency_list[a] = [b]
+        if b in adjacency_list:
+            adjacency_list[b].append(a)
+        else:
+            adjacency_list[b] = [a]
+
+    # Use BFS to find all nodes in the connected component
+    connected_component = set()
+    queue = [start_node]
+
+    while queue:
+        node = queue.pop(0)
+        if node not in connected_component:
+            connected_component.add(node)
+            queue.extend(adjacency_list.get(node, []))  # Add neighbors to the queue
+
+    return np.array(list(connected_component))
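
A hedged usage sketch tying the two helpers together the way the app does, on synthetic points standing in for the 2D spectral-tSNE embedding; `edges[k:]` is exactly what the granularity slider prunes:

```python
import numpy as np
from fps_cluster import build_tree, find_connected_component

points = np.random.RandomState(0).randn(200, 2)  # stand-in 2D embedding
edges = build_tree(points)                       # (200, 2): each node paired with a parent
k = 50                                           # granularity: drop the 50 coarsest edges
start = int(edges[60, 0])                        # any node id that survives the pruning
component = find_connected_component(edges[k:], start)
print(len(component))  # size of the sub-cluster containing the selected node
```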