Spaces: Running on Zero
add inspect playground
- app.py +425 -19
- fps_cluster.py +78 -0
app.py
CHANGED
@@ -2,6 +2,7 @@
 # %%
 import copy
 from datetime import datetime
+import io
 import math
 import pickle
 from functools import partial
@@ -168,8 +169,6 @@ def compute_ncut(
     logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
 
     if only_eigvecs:
-        eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
-        eigvecs = eigvecs.detach().numpy()
         return None, logging_str, eigvecs
 
     start = time.time()
@@ -285,10 +284,15 @@ def to_pil_images(images, target_size=512, resize=True, force_size=False):
     res = int(size * multiplier)
     if force_size:
         res = target_size
-
-
-
-
+
+    pil_images = []
+    for image in images:
+        if isinstance(image, torch.Tensor):
+            image = image.cpu().numpy()
+        if image.dtype == np.float32 or image.dtype == np.float64:
+            image = (image * 255).astype(np.uint8)
+        pil_images.append(Image.fromarray(image))
+
     if resize:
         pil_images = [
             image.resize((res, res), Image.Resampling.NEAREST)
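Note: the rewritten branch accepts float tensors as well as uint8 arrays. A minimal standalone sketch of the same conversion, assuming (H, W, 3) inputs with floats scaled to [0, 1] (the helper name to_uint8_pil is hypothetical, not from the app):

import numpy as np
import torch
from PIL import Image

def to_uint8_pil(image):
    # Tensors are moved to CPU NumPy; float images are rescaled to uint8.
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    if image.dtype == np.float32 or image.dtype == np.float64:
        image = (image * 255).astype(np.uint8)
    return Image.fromarray(image)

pil = to_uint8_pil(torch.rand(64, 64, 3))  # random float image -> PIL RGB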
@@ -865,6 +869,7 @@ def ncut_run(
     # ailgnedcut
     if not directed:
         only_eigvecs = kwargs.get("only_eigvecs", False)
+        return_eigvec_and_rgb = kwargs.get("return_eigvec_and_rgb", False)
 
         rgb, _logging_str, eigvecs = compute_ncut(
             features,
@@ -886,9 +891,20 @@ def ncut_run(
             only_eigvecs=only_eigvecs,
         )
 
+
         if only_eigvecs:
+            eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
+            eigvecs = eigvecs.detach().numpy()
+            logging_str += _logging_str
             return eigvecs, logging_str
 
+        if return_eigvec_and_rgb:
+            eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,))
+            eigvecs = eigvecs.detach().numpy()
+            rgb = rgb.cpu().numpy()
+            logging_str += _logging_str
+            return eigvecs, rgb, logging_str
+
     if directed:
         head_index_text = kwargs.get("head_index_text", None)
         n_heads = features.shape[-2]  # (batch, h, w, n_heads, d)
@@ -1232,6 +1248,7 @@ def run_fn(
     advanced=False,
     directed=False,
     only_eigvecs=False,
+    return_eigvec_and_rgb=False,
 ):
     # print(node_type2, head_index_text, make_symmetric)
     progress=gr.Progress()
@@ -1373,6 +1390,7 @@ def run_fn(
         "head_index_text": head_index_text,
         "make_symmetric": make_symmetric,
         "only_eigvecs": only_eigvecs,
+        "return_eigvec_and_rgb": return_eigvec_and_rgb,
     }
     # print(kwargs)
 
@@ -1599,7 +1617,7 @@ def load_dataset_images(is_advanced, dataset_name, num_images=10,
         is_advanced = "Basic"
 
     if is_advanced == "Basic":
-        gr.Info(f"Loaded images from EgoExo")
+        gr.Info(f"Loaded images from EgoExo", duration=5)
         return default_images
     try:
         progress(0.5, desc="Downloading Dataset")
@@ -1644,7 +1662,7 @@ def load_dataset_images(is_advanced, dataset_name, num_images=10,
             image_idx.extend(idx.tolist())
     if not is_filter:
         if is_random:
-            if num_images
+            if num_images <= len(dataset):
                 image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
             else:
                 gr.Warning(f"Dataset has less than {num_images} images.")
@@ -1653,7 +1671,7 @@ def load_dataset_images(is_advanced, dataset_name, num_images=10,
        image_idx = list(range(num_images))
    key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
    images = [dataset[i][key] for i in image_idx]
-    gr.Info(f"Loaded {len(images)} images from {dataset_name}")
+    gr.Info(f"Loaded {len(images)} images from {dataset_name}", duration=5)
    del dataset
 
    if dataset_name in CENTER_CROP_DATASETS:
@@ -2060,7 +2078,7 @@ def make_output_images_section(markdown=True, button=True):
     add_rotate_flip_buttons(output_gallery)
     return output_gallery
 
-def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=True):
+def make_parameters_section(is_lisa=False, model_ratio=True, ncut_parameter_dropdown=True, tsne_parameter_dropdown=True):
    gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
    from ncut_pytorch.backbone import list_models, get_demo_model_names
    model_names = list_models()
@@ -2089,7 +2107,7 @@ def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=
        positive_prompt.visible = False
        negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
        negative_prompt.visible = False
-       node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type"
+       node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type")
        num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for smaller clusters')
 
        def change_layer_slider(model_name):
@@ -2125,7 +2143,7 @@ def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=
                            gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
        model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
 
-       with gr.Accordion("Advanced Parameters: NCUT", open=False, visible=parameter_dropdown):
+       with gr.Accordion("Advanced Parameters: NCUT", open=False, visible=ncut_parameter_dropdown):
            gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
            affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
            num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
@@ -2136,7 +2154,7 @@ def make_parameters_section(is_lisa=False, model_ratio=True, parameter_dropdown=
            ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
            ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=True, elem_id="ncut_indirect_connection", info="Add indirect connection to the sub-sampled graph")
            ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
-       with gr.Accordion("Advanced Parameters: Visualization", open=False, visible=parameter_dropdown):
+       with gr.Accordion("Advanced Parameters: Visualization", open=False, visible=tsne_parameter_dropdown):
            # embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
            embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
            # embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
@@ -2168,8 +2186,396 @@ demo = gr.Blocks(
 )
 with demo:
 
+    with gr.Tab('Hierarchical (dev)'):
+        eigvecs = gr.State(np.array([]))
+        tsne3d_rgb = gr.State(np.array([]))
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                # gr.Markdown("### Step 1: Load Images")
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100, markdown=False)
+                submit_button.value = "🔴 RUN NCUT"
+                num_images_slider.value = 100
+
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+            with gr.Column(scale=5, min_width=200):
+                # gr.Markdown("### Step 2a: Run Backbone and NCUT")
+                # with gr.Accordion(label="Backbone Parameters", visible=True, open=False):
+                output_gallery = gr.Gallery(format='png', value=[], label="NCUT spectral-tSNE", show_label=True, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+                def add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb):
+                    with gr.Row():
+                        rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
+                        rotate_button.click(sequence_rotate_rgb_gallery, inputs=[output_gallery], outputs=[output_gallery])
+                        def rotate_state(arr):
+                            rotation_matrix = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)
+                            return arr @ rotation_matrix
+                        rotate_button.click(rotate_state, inputs=[tsne3d_rgb], outputs=[tsne3d_rgb])
+                        flip_button = gr.Button("🔃 Flip", elem_id="flip_button", variant='secondary')
+                        flip_button.click(flip_rgb_gallery, inputs=[output_gallery], outputs=[output_gallery])
+                        def flip_state(arr):
+                            return 1 - arr
+                        flip_button.click(flip_state, inputs=[tsne3d_rgb], outputs=[tsne3d_rgb])
+                        return rotate_button, flip_button
+                add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb)
+
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section(ncut_parameter_dropdown=True, tsne_parameter_dropdown=True)
+                # submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
+                logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+
+                def __run_fn(*args, **kwargs):
+                    eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
+                    rgb_gallery = to_pil_images(rgb)
+                    return eigvecs, rgb, rgb_gallery, logging_str
+
+                submit_button.click(
+                    partial(__run_fn, n_ret=2, return_eigvec_and_rgb=True),
+                    inputs=[
+                        input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                        positive_prompt, negative_prompt,
+                        false_placeholder, no_prompt, no_prompt, no_prompt,
+                        affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                        embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                        perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown
+                    ],
+                    outputs=[eigvecs, tsne3d_rgb, output_gallery, logging_text],
+                )
+
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('---')
+                gr.Markdown('<h3 style="text-align: center;">Help</h3>')
+                gr.Markdown('---')
+                with gr.Accordion("Instructions", open=True):
+                    gr.Markdown("""
+                    1. Load Dataset (left).
+                    2. Choose parameters (middle).
+                    3. 🔴 RUN NCUT.
+                    4. 🔴 RUN FPS+Cluster.
+                    5. Interact and Inspect (scroll down).
+                    """)
+                with gr.Accordion("Methods: NCUT spectral-TSNE", open=False):
+                    gr.Markdown("### <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_ncut_works/' target='_blank'>Documentation: How NCUT works</a>")
+                    gr.Markdown("""
+                    1. Run Backbone, feature extraction for each image.
+                    2. Vectorize latent-pixels, concatenate all the images.
+                    3. Run NCUT, on one big graph of all the images.
+                    4. Run spectral-tSNE on the NCUT eigenvectors.
+                    5. Plot the 3D spectral-tSNE as RGB.
+                    """)
+                with gr.Accordion("Methods: Hierarchical Structure", open=False):
+                    gr.Markdown("""
+                    1. Farthest Point Sampling (FPS) on the eigenvectors.
+                    2. spectral-tSNE (2D) on the FPS sampled points.
+                    3. Hierarchical clustering on the FPS sampled points.
+                    """)
+                gr.Markdown('---')
+
+        run_hierarchical_button = gr.Button("🔴 RUN FPS+Cluster", elem_id="run_hierarchical", variant='primary')
+        with gr.Accordion("Hierarchical Structure Parameters:", open=True):
+            num_sample_fps_slider = gr.Slider(1, 5000, step=1, label="FPS: num_sample", value=1000, elem_id="num_sample_fps")
+            tsne_perplexity_slider = gr.Slider(1, 1000, step=1, label="t-SNE: perplexity", value=500, elem_id="perplexity_tsne")
+            fps_hc_seed_slider = gr.Slider(0, 1000, step=1, label="Seed", value=0, elem_id="fps_hc_seed")
+        tsne_plot = gr.Image(label="spectral-tSNE tree", elem_id="tsne_plot", interactive=False, format='png')
+
+        tsne_2d_points = gr.State(np.array([]))
+        edges = gr.State(np.array([]))
+        fps_eigvecs = gr.State(np.array([]))
+        fps_indices = gr.State(np.array([]))
+        fps_tsne_rgb = gr.State(np.array([]))
+
+        def plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, k, hightlight_idx=None, highlight_connections=False):
+            # Plot the t-SNE points
+            fig, ax = plt.subplots(1, 1, figsize=(6, 6))
+            ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne3d_rgb)
+            # draw the edges
+            for i_edge in range(k, len(edges)):
+                edge = edges[i_edge]
+                ax.plot(tsne_embed[edge, 0], tsne_embed[edge, 1], 'k-', lw=1, alpha=0.7)
+            # highlight the selected node
+            if hightlight_idx is not None:
+                if highlight_connections:
+                    from fps_cluster import find_connected_component
+                    _edges = edges[k:, :]
+                    connected_nodes = find_connected_component(_edges, hightlight_idx)
+                    ax.scatter(tsne_embed[connected_nodes, 0], tsne_embed[connected_nodes, 1], s=50, c=fps_tsne3d_rgb[connected_nodes], marker='D', edgecolor='deeppink', linewidth=1)
+                # ax.scatter(tsne_embed[hightlight_idx, 0], tsne_embed[hightlight_idx, 1], s=300, c='r', marker='x')
+                ax.scatter(tsne_embed[hightlight_idx, 0], tsne_embed[hightlight_idx, 1], s=200, c='cyan', marker='o', edgecolor='black', linewidth=1)
+            ax.set_xticks([])
+            ax.set_yticks([])
+            ax.axis('off')
+            ax.set_xlim(tsne_embed[:, 0].min()*1.1, tsne_embed[:, 0].max()*1.1)
+            ax.set_ylim(tsne_embed[:, 1].min()*1.1, tsne_embed[:, 1].max()*1.1)
+
+            # Remove the white space around the plot
+            fig.tight_layout(pad=0)
+
+            # Save the plot to an in-memory buffer
+            buf = io.BytesIO()
+            plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
+            buf.seek(0)
+
+            # Load the image into a NumPy array
+            image = np.array(Image.open(buf))
+
+            # Close the buffer and plot
+            buf.close()
+            plt.close(fig)
+
+            pil_image = Image.fromarray(image)
+            return pil_image
+
+        def run_fps_tsne_hierarchical(eigvecs, num_sample_fps, perplexity_tsne, tsne3d_rgb, seed=0):
+            if len(eigvecs) == 0:
+                gr.Warning("Please run NCUT first.")
+                return
+            eigvecs = torch.tensor(eigvecs)
+            eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])
+            gr.Info("Running FPS, t-SNE, and Hierarchical Clustering...", 3)
+            from ncut_pytorch.ncut_pytorch import farthest_point_sampling
+            from sklearn.manifold import TSNE
+            from fps_cluster import build_tree
+
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+
+            fps_idx = farthest_point_sampling(eigvecs, num_sample_fps)
+            fps_eigvecs = eigvecs[fps_idx]
+            fps_eigvecs = fps_eigvecs.numpy()
+
+            tsne3d_rgb = tsne3d_rgb.reshape(-1, 3)
+            fps_tsne3d_rgb = tsne3d_rgb[fps_idx]
+
+            np.random.seed(seed)
+            tsne_embed = TSNE(
+                n_components=2,
+                perplexity=perplexity_tsne,
+                metric='cosine',
+                random_state=seed,
+            ).fit_transform(fps_eigvecs)
+
+            edges = build_tree(tsne_embed)
+
+            # Plot the t-SNE points
+            pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
+
+            return tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image
+
+        run_hierarchical_button.click(
+            run_fps_tsne_hierarchical,
+            inputs=[eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider],
+            outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot],
+        )
+        gr.Markdown('---')
+        gr.Markdown('<h3 style="text-align: center;">↓ interactively inspect the hierarchical structure</h3>')
+        gr.Markdown('---')
+        with gr.Row():
+            from gradio_image_prompter import ImagePrompter
+            with gr.Column(scale=5, min_width=200) as tsne_select:
+                tsne_prompt_image = ImagePrompter(show_label=True, elem_id="tsne_prompt_image", interactive=False, label="spectral-tSNE tree")
+                # copy plot to tsne_prompt_image on change
+                # tsne_plot.change(fn=lambda x: gr.update(value={'image': x}, interactive=True),
+                #                  inputs=[tsne_plot], outputs=[tsne_prompt_image])
+            with gr.Column(scale=5, min_width=200) as image_select:
+                image_plot = ImagePrompter(show_label=True, elem_id="image_plot", interactive=False, label="NCUT spectral-tSNE")
+                image_slider = gr.Slider(0, 100, step=1, label="Image Index", value=0, elem_id="image_slider", interactive=True)
+                def update_image_prompt(image_slider, output_gallery):
+                    if len(output_gallery) == 0:
+                        return gr.update(value=None, interactive=False)
+                    image_idx = int(image_slider)
+                    image = output_gallery[image_idx][0]
+                    return gr.update(value={'image': image}, interactive=True)
+                image_slider.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
+                output_gallery.change(fn=update_image_prompt, inputs=[image_slider, output_gallery], outputs=[image_plot])
+                output_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1, interactive=True), inputs=[output_gallery], outputs=[image_slider])
+            with gr.Column(scale=5, min_width=200):
+                gr.Markdown('<h3 style="text-align: center;">Help</h3>')
+                with gr.Accordion("Instructions", open=True):
+                    gr.Markdown("""
+                    1. Click one dot on the left-side image.
+                        - Only the last clicked dot will be used
+                        - Eraser is at top-right corner
+                        - Use the right-side Radio to switch tree/image
+                    2. Choose a granularity (right-side).
+                    3. 🔴 RUN Inspection.
+                    4. Output will be shown below.
+                    """)
+                with gr.Accordion("Outputs", open=True):
+                    gr.Markdown("""
+                    1. spectral-tSNE tree: ◆ means connected components to the selected point.
+                    2. Cluster Heatmap: max cosine similarity to any points in the connected components.
+                    """)
+            with gr.Column(scale=5, min_width=200):
+                prompt_radio = gr.Radio(["Tree", "Image"], label="Where to click on?", value="Tree", elem_id="prompt_radio", show_label=True)
+                granularity_slider = gr.Slider(1, 1000, step=1, label="Cluster Granularity", value=100, elem_id="granularity")
+                num_sample_fps_slider.change(fn=lambda x: gr.update(maximum=x, interactive=True), inputs=[num_sample_fps_slider], outputs=[granularity_slider])
+                def updaste_tsne_plot_change_granularity(granularity, tsne_embed, edges, fps_tsne_rgb, tsne_prompt_image):
+                    # Plot the t-SNE points
+                    pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne_rgb, granularity)
+                    if tsne_prompt_image is None:
+                        return gr.update(value={'image': pil_image}, interactive=True)
+                    return gr.update(value={'image': pil_image, 'points': tsne_prompt_image['points']}, interactive=True)
+                granularity_slider.change(updaste_tsne_plot_change_granularity,
+                                          inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb, tsne_prompt_image],
+                                          outputs=[tsne_prompt_image])
+                tsne_plot.change(updaste_tsne_plot_change_granularity,
+                                 inputs=[granularity_slider, tsne_2d_points, edges, fps_tsne_rgb],
+                                 outputs=[tsne_prompt_image])
+                run_inspection_button = gr.Button("🔴 RUN Inspection", elem_id="run_inspection", variant='primary')
+                inspect_logging_text = gr.Textbox("Logging information", lines=3, label="Logging", elem_id="inspect_logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
+                # output_slot_radio = gr.Radio([1, 2, 3], label="Output Row", value=1, elem_id="output_slot", show_label=True)
+
+        image_select.visible = False
+        tsne_select.visible = True
+        prompt_radio.change(fn=lambda x: gr.update(visible=x=="Tree"), inputs=prompt_radio, outputs=[tsne_select])
+        prompt_radio.change(fn=lambda x: gr.update(visible=x=="Image"), inputs=prompt_radio, outputs=[image_select])
+
+        def make_one_output_row(i_row=1):
+            with gr.Row() as inspect_output_row:
+                with gr.Column(scale=5, min_width=200):
+                    output_tree_image = gr.Image(label=f"spectral-tSNE tree [row#{i_row}]", elem_id="output_image", interactive=False)
+                    text_block = gr.Textbox("", label="Logging", elem_id=f"logging_{i_row}", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=2, show_label=False)
+                with gr.Column(scale=10, min_width=200):
+                    heatmap_gallery = gr.Gallery(format='png', value=[], label=f"Cluster Heatmap [row#{i_row}]", show_label=True, elem_id="heatmap", columns=[6], rows=[1], object_fit="contain", height="auto", show_share_button=True, interactive=False)
+            return inspect_output_row, output_tree_image, heatmap_gallery, text_block
+
+        gr.Markdown('---')
+        MAX_ROWS = 100
+        current_output_row = gr.State(MAX_ROWS-1)
+        inspect_output_rows, output_tree_images, heatmap_galleries, text_blocks = [], [], [], []
+        for i_row in range(MAX_ROWS, 0, -1):
+            inspect_output_row, output_tree_image, heatmap_gallery, text_block = make_one_output_row(i_row)
+            inspect_output_row.visible = False
+            inspect_output_rows.append(inspect_output_row)
+            output_tree_images.append(output_tree_image)
+            heatmap_galleries.append(heatmap_gallery)
+            text_blocks.append(text_block)
+
+
+        def relative_xy_last_positive(prompts):
+            image = prompts['image']
+            points = np.asarray(prompts['points'])
+            if points.shape[0] == 0:
+                return [], []
+            is_point = points[:, 5] == 4.0
+            points = points[is_point]
+            is_positive = points[:, 2] == 1.0
+            if is_positive.sum() == 0:
+                raise Exception("No blue point is selected.")
+            is_negative = points[:, 2] == 0.0
+            xy = points[:, :2].tolist()
+            if isinstance(image, str):
+                image = Image.open(image)
+                image = np.array(image)
+            h, w = image.shape[:2]
+            new_xy = [(x/w, y/h) for x, y in xy]
+
+            last_positive_idx = np.where(is_positive)[0][-1]
+            x, y = new_xy[last_positive_idx]
+            return x, y
+
+        def find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed):
+            x, y = relative_xy_last_positive(tsne_prompt)
+            x_vmax = tsne2d_embed[:, 0].max() * 1.1
+            x_vmin = tsne2d_embed[:, 0].min() * 1.1
+            y_vmax = tsne2d_embed[:, 1].max() * 1.1
+            y_vmin = tsne2d_embed[:, 1].min() * 1.1
+            x = x * (x_vmax - x_vmin) + x_vmin
+            y = 1 - y
+            y = y * (y_vmax - y_vmin) + y_vmin
+            dist = np.linalg.norm(tsne2d_embed - np.array([x, y]), axis=1)
+            closest_idx = np.argmin(dist)
+            return closest_idx
+
+        def find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs):
+            x, y = relative_xy_last_positive(image_prompt)
+            _eigvec = eigvecs[i_image]
+            h, w = _eigvec.shape[:2]
+            x = int(x * w)
+            y = int(y * h)
+            eigvec = _eigvec[y, x]
+            dist = np.linalg.norm(fps_eigvecs - eigvec, axis=1)
+            closest_idx = np.argmin(dist)
+            return closest_idx
+
+        def find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
+            try:
+                if prompt_radio == "Tree":
+                    return find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed)
+                if prompt_radio == "Image":
+                    return find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs)
+            except:
+                raise gr.Error("""No blue point is selected. <br/>Please left-click on the image to select a blue point. <br/>After reloading the image (e.g., change granularity), please use the eraser to remove the previous point, then click on the image to select a blue point.""")
+
+        def run_inspection(tsne_prompt, image_prompt, prompt_radio, output_slot, tsne2d_embed, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity, eigvecs, i_image, tsne3d_rgb, input_gallery, max_rows=MAX_ROWS):
+            if len(tsne2d_embed) == 0:
+                raise gr.Error("Please run FPS+Cluster first.")
+            closest_idx = find_closest_fps_point(prompt_radio, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs)
+            closest_rgb = fps_tsne_rgb[closest_idx]
+            closest_rgb = (closest_rgb * 255).astype(np.uint8)
+
+            from fps_cluster import find_connected_component
+            connected_idxs = find_connected_component(edges[granularity:], closest_idx)
+
+            logging_text = f"Clicked: idx={closest_idx}, RGB: {closest_rgb.tolist()}\n"
+            logging_text += f"Granularity: k={granularity}, Connected: n={len(connected_idxs)}"
+
+            output_tsne_plot = plot_tsne_tree(tsne2d_embed, edges, fps_tsne_rgb, granularity, closest_idx, highlight_connections=True)
+
+            # draw heatmap for the connected components
+            connected_eigvecs = fps_eigvecs[connected_idxs]
+            left = torch.tensor(eigvecs).float()  # B H W 3
+            right = torch.tensor(connected_eigvecs).float()
+            left = F.normalize(left, p=2, dim=-1)
+            right = F.normalize(right, p=2, dim=-1)
+            similarity = left @ right.T
+            similarity = similarity.max(dim=-1).values  # B H W
+            hot_map = matplotlib.cm.get_cmap('hot')
+            heatmap = hot_map(similarity)[..., :3]  # B H W 3
+            heatmap_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True)
+            # overlay input images on the heatmap
+            input_images = [x[0] for x in input_gallery]
+            if isinstance(input_images[0], str):
+                input_images = [Image.open(x) for x in input_images]
+            for i, img in enumerate(input_images):
+                _img = img.resize((256, 256)).convert('RGB')
+                _heatmap = heatmap_images[i].resize((256, 256)).convert('RGB')
+                blend = np.array(_img) * 0.5 + np.array(_heatmap) * 0.5
+                blend = Image.fromarray(blend.astype(np.uint8))
+                heatmap_images[i] = blend
+
+            # tree_label = f"spectral-tSNE tree [row#{max_rows-output_slot}] k={granularity} idx={closest_idx} n={len(connected_idxs)}"
+            tree_label = f"spectral-tSNE tree [row#{max_rows-output_slot}]"
+            heatmap_label = f"Cluster Heatmap [row#{max_rows-output_slot}] k={granularity} idx={closest_idx} n={len(connected_idxs)}"
+            # update the output slots
+            output_rows = [gr.update() for _ in range(max_rows)]
+            output_tsne_plots = [gr.update() for _ in range(max_rows)]
+            output_heatmaps = [gr.update() for _ in range(max_rows)]
+            output_texts = [gr.update() for _ in range(max_rows)]
+            output_rows[output_slot] = gr.update(visible=True)
+            output_tsne_plots[output_slot] = gr.update(value=output_tsne_plot, label=tree_label)
+            output_heatmaps[output_slot] = gr.update(value=heatmap_images, label=heatmap_label)
+            output_texts[output_slot] = gr.update(value=logging_text)
+            gr.Info(f"Output in [row#{max_rows-output_slot}]", 3)
+            logging_text += f"\nOutput: [row#{max_rows-output_slot}]"
+            output_slot -= 1
+            if output_slot < 0:
+                output_slot = max_rows - 1
+
+            return *output_rows, *output_tsne_plots, *output_heatmaps, *output_texts, output_slot, logging_text
+
+
+        run_inspection_button.click(
+            run_inspection,
+            inputs=[tsne_prompt_image, image_plot, prompt_radio, current_output_row, tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity_slider, eigvecs, image_slider, tsne3d_rgb, input_gallery],
+            outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, inspect_logging_text],
+        )
+
     with gr.Tab('AlignedCut'):
 
         with gr.Row():
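The inspection handler above scores every latent pixel by its maximum cosine similarity to the selected connected component. A self-contained sketch of that heatmap math, with shapes assumed as (B, H, W, D) eigenvectors and (N, D) cluster points (the function name cluster_heatmap is hypothetical):

import torch
import torch.nn.functional as F

def cluster_heatmap(eigvecs, cluster_eigvecs):
    # Normalize, then take each pixel's best match inside the cluster.
    left = F.normalize(eigvecs, p=2, dim=-1)            # (B, H, W, D)
    right = F.normalize(cluster_eigvecs, p=2, dim=-1)   # (N, D)
    similarity = left @ right.T                         # (B, H, W, N)
    return similarity.max(dim=-1).values                # (B, H, W), in [-1, 1]

heat = cluster_heatmap(torch.randn(2, 16, 16, 8), torch.randn(5, 8))
print(heat.shape)  # torch.Size([2, 16, 16])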
@@ -3448,7 +3854,7 @@ with demo:
         with gr.Row():
             with gr.Column(scale=5, min_width=200):
                 gr.Markdown("### Step 1: Load Images and Run NCUT")
-                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100)
                 # submit_button.visible = False
                 num_images_slider.value = 30
                 [
|
|
3457 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
3458 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
3459 |
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
3460 |
-
] = make_parameters_section(
|
3461 |
num_eig_slider.value = 1000
|
3462 |
num_eig_slider.visible = False
|
3463 |
logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
|
@@ -3585,7 +3991,7 @@ with demo:
|
|
3585 |
pil_images = overlaied_images
|
3586 |
return pil_images, (y, x)
|
3587 |
|
3588 |
-
def
|
3589 |
features,
|
3590 |
start_feature,
|
3591 |
num_sample=300,
|
@@ -3635,7 +4041,7 @@ with demo:
|
|
3635 |
num_childs = min(4, masked_eigvecs.shape[0])
|
3636 |
assert num_childs > 0
|
3637 |
|
3638 |
-
child_idx =
|
3639 |
child_idx = np.sort(child_idx)[:-1]
|
3640 |
|
3641 |
# convert child_idx to flat_idx
|
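_farthest_point_sampling is used here to pick a handful of well-spread child nodes in eigenvector space. A minimal sketch of greedy farthest point sampling under that assumed contract (returns indices into features, seeded by start_feature; the helper name fps is hypothetical):

import numpy as np

def fps(features, start_feature, num_sample=4):
    # Greedily pick the point farthest from everything selected so far.
    selected = [start_feature]
    idxs = []
    for _ in range(num_sample):
        dists = np.linalg.norm(features[:, None] - np.array(selected)[None], axis=-1)
        farthest = int(np.argmax(dists.min(axis=1)))
        idxs.append(farthest)
        selected.append(features[farthest])
    return np.array(idxs)

print(fps(np.random.rand(100, 8), np.random.rand(8), num_sample=5))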
@@ -3718,7 +4124,7 @@ with demo:
         with gr.Row():
             with gr.Column(scale=5, min_width=200):
                 gr.Markdown("### Step 1: Load Images")
-                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(n_example_images=100)
                 submit_button.visible = False
                 num_images_slider.value = 30
 
|
|
3735 |
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
|
3736 |
perplexity_slider, n_neighbors_slider, min_dist_slider,
|
3737 |
sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
|
3738 |
-
] = make_parameters_section(
|
3739 |
num_eig_slider.value = 1024
|
3740 |
num_eig_slider.visible = False
|
3741 |
submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary')
|
fps_cluster.py
ADDED
@@ -0,0 +1,78 @@
+# %%
+import numpy as np
+import torch
+
+
+def build_tree(all_dots):
+    num_sample = all_dots.shape[0]
+    center = all_dots.mean(axis=0)
+    distances_to_center = np.linalg.norm(all_dots - center, axis=1)
+    start_idx = np.argmin(distances_to_center)
+    indices = [start_idx]
+    distances = [114514,]
+    A = all_dots[:, None] - all_dots[None, :]
+    A = (A ** 2).sum(-1)
+    A = np.sqrt(A)
+    A = torch.tensor(A)
+    for i in range(num_sample - 1):
+        _A = A[indices]
+        min_dist = _A.min(dim=0).values
+        next_idx = torch.argmax(min_dist).item()
+        distance = min_dist[next_idx].item()
+        indices.append(next_idx)
+        distances.append(distance)
+    indices = np.array(indices)
+    distances = np.array(distances)
+
+    levels = np.log2(distances[1] / distances)
+    levels = np.floor(levels).astype(int) + 1
+    levels[0] = 0
+
+    n_levels = levels.max() + 1
+    pi_indices = [indices[0],]
+    for i_level in range(1, n_levels):
+        current_level_indices = levels == i_level
+        prev_level_indices = levels < i_level
+        current_level_indices = indices[current_level_indices]
+        prev_level_indices = indices[prev_level_indices]
+        _A = A[prev_level_indices][:, current_level_indices]
+        _pi = _A.min(dim=0).indices
+        pi = prev_level_indices[_pi]
+        if isinstance(pi, np.int64) or isinstance(pi, int):
+            pi = [pi,]
+        if isinstance(pi, np.ndarray):
+            pi = pi.tolist()
+        pi_indices.extend(pi)
+    pi_indices = np.array(pi_indices)
+
+    edges = np.stack([indices, pi_indices], axis=1)
+    return edges
+
+
+def find_connected_component(edges, start_node):
+    # Dictionary to store adjacency list
+    adjacency_list = {}
+    for edge in edges:
+        # Unpack edge
+        a, b = edge
+        # Add the connection for both nodes
+        if a in adjacency_list:
+            adjacency_list[a].append(b)
+        else:
+            adjacency_list[a] = [b]
+        if b in adjacency_list:
+            adjacency_list[b].append(a)
+        else:
+            adjacency_list[b] = [a]
+
+    # Use BFS to find all nodes in the connected component
+    connected_component = set()
+    queue = [start_node]
+
+    while queue:
+        node = queue.pop(0)
+        if node not in connected_component:
+            connected_component.add(node)
+            queue.extend(adjacency_list.get(node, []))  # Add neighbors to the queue
+
+    return np.array(list(connected_component))
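A toy run of the new module, assuming fps_cluster.py is importable from the working directory: build_tree pairs every point with a parent from a coarser FPS level, and cutting the first k edges (the coarsest links) before the component search yields finer clusters, which is exactly how the granularity slider is used above.

import numpy as np
from fps_cluster import build_tree, find_connected_component

points = np.random.RandomState(0).rand(50, 2)   # toy 2D embedding
edges = build_tree(points)                      # (50, 2) child/parent pairs
for k in (0, 10, 25):                           # larger k cuts more coarse edges
    component = find_connected_component(edges[k:], start_node=0)
    print(f"k={k}: {len(component)} nodes connected to node 0")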