Spaces: Running on Zero
Commit: add directed ncut (test)

Files changed:
- app.py: +382 −39
- directed_ncut.py: +287 −0
- requirements.txt: +1 −1
app.py
CHANGED
@@ -183,6 +183,84 @@ def compute_ncut(
     return rgb, logging_str, eigvecs
 
 
+def compute_ncut_directed(
+    features_1,
+    features_2,
+    num_eig=100,
+    num_sample_ncut=10000,
+    affinity_focal_gamma=0.3,
+    knn_ncut=10,
+    knn_tsne=10,
+    embedding_method="UMAP",
+    embedding_metric='euclidean',
+    num_sample_tsne=300,
+    perplexity=150,
+    n_neighbors=150,
+    min_dist=0.1,
+    sampling_method="QuickFPS",
+    metric="cosine",
+    indirect_connection=False,
+    make_orthogonal=False,
+    make_symmetric=False,
+    progess_start=0.4,
+):
+    print("Using directed_ncut")
+    print("features_1.shape", features_1.shape)
+    print("features_2.shape", features_2.shape)
+    from directed_ncut import nystrom_ncut
+    progress = gr.Progress()
+    logging_str = ""
+
+    num_nodes = np.prod(features_1.shape[:-2])
+    if num_nodes / 2 < num_eig:
+        # raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
+        gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.")
+        num_eig = num_nodes // 2 - 1
+        logging_str += f"Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.\n"
+
+    start = time.time()
+    progress(progess_start+0.0, desc="NCut")
+    n_features = features_1.shape[-2]
+    _features_1 = rearrange(features_1, "b h w d c -> (b h w) (d c)")
+    _features_2 = rearrange(features_2, "b h w d c -> (b h w) (d c)")
+    eigvecs, eigvals, _ = nystrom_ncut(
+        _features_1,
+        features_B=_features_2,
+        num_eig=num_eig,
+        num_sample=num_sample_ncut,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        affinity_focal_gamma=affinity_focal_gamma,
+        knn=knn_ncut,
+        sample_method=sampling_method,
+        distance=metric,
+        normalize_features=False,
+        indirect_connection=indirect_connection,
+        make_orthogonal=make_orthogonal,
+        make_symmetric=make_symmetric,
+        n_features=n_features,
+    )
+    # print(f"NCUT time: {time.time() - start:.2f}s")
+    logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
+
+    start = time.time()
+    progress(progess_start+0.01, desc="spectral-tSNE")
+    _, rgb = eigenvector_to_rgb(
+        eigvecs,
+        method=embedding_method,
+        metric=embedding_metric,
+        num_sample=num_sample_tsne,
+        perplexity=perplexity,
+        n_neighbors=n_neighbors,
+        min_distance=min_dist,
+        knn=knn_tsne,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    logging_str += f"{embedding_method} time: {time.time() - start:.2f}s\n"
+
+    rgb = rgb.reshape(features_1.shape[:3] + (3,))
+    return rgb, logging_str, eigvecs
+
+
 def dont_use_too_much_green(image_rgb):
     # make sure the foval 40% of the image is red leading
     x1, x2 = int(image_rgb.shape[1] * 0.3), int(image_rgb.shape[1] * 0.7)

@@ -592,6 +670,8 @@ def ncut_run(
     **kwargs,
 ):
     advanced = kwargs.get("advanced", False)
+    directed = kwargs.get("directed", False)
+
     progress = gr.Progress()
     progress(0.2, desc="Feature Extraction")
 

@@ -640,6 +720,11 @@ def ncut_run(
     features = extract_features(
         images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
     )
+    if directed:
+        node_type2 = kwargs.get("node_type2", None)
+        features_B = extract_features(
+            images, model, node_type=node_type2, layer=layer-1, batch_size=BATCH_SIZE
+        )
    # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
    logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
    del model

@@ -768,25 +853,59 @@ def ncut_run(
 
 
     # ailgnedcut
-    … (19 deleted lines not rendered in this commit view)
+    if not directed:
+        rgb, _logging_str, eigvecs = compute_ncut(
+            features,
+            num_eig=num_eig,
+            num_sample_ncut=num_sample_ncut,
+            affinity_focal_gamma=affinity_focal_gamma,
+            knn_ncut=knn_ncut,
+            knn_tsne=knn_tsne,
+            num_sample_tsne=num_sample_tsne,
+            embedding_method=embedding_method,
+            embedding_metric=embedding_metric,
+            perplexity=perplexity,
+            n_neighbors=n_neighbors,
+            min_dist=min_dist,
+            sampling_method=sampling_method,
+            indirect_connection=indirect_connection,
+            make_orthogonal=make_orthogonal,
+            metric=ncut_metric,
+        )
+    if directed:
+        head_index_text = kwargs.get("head_index_text", None)
+        n_heads = features.shape[-2]  # (batch, h, w, n_heads, d)
+        if head_index_text == 'all':
+            head_idx = torch.arange(n_heads)
+        else:
+            _idxs = head_index_text.split(",")
+            head_idx = torch.tensor([int(idx) for idx in _idxs])
+        features_A = features[:, :, :, head_idx, :]
+        features_B = features_B[:, :, :, head_idx, :]
+
+        rgb, _logging_str, eigvecs = compute_ncut_directed(
+            features_A,
+            features_B,
+            num_eig=num_eig,
+            num_sample_ncut=num_sample_ncut,
+            affinity_focal_gamma=affinity_focal_gamma,
+            knn_ncut=knn_ncut,
+            knn_tsne=knn_tsne,
+            num_sample_tsne=num_sample_tsne,
+            embedding_method=embedding_method,
+            embedding_metric=embedding_metric,
+            perplexity=perplexity,
+            n_neighbors=n_neighbors,
+            min_dist=min_dist,
+            sampling_method=sampling_method,
+            indirect_connection=False,
+            make_orthogonal=make_orthogonal,
+            metric=ncut_metric,
+            make_symmetric=kwargs.get("make_symmetric", None),
+        )
+
+
 
     logging_str += _logging_str
 
     if "AlignedThreeModelAttnNodes" == model_name:

@@ -858,26 +977,26 @@ def ncut_run(
 
 def _ncut_run(*args, **kwargs):
     n_ret = kwargs.pop("n_ret", 1)
-    try:
-        … (try body not rendered in this commit view)
-    except Exception as e:
-        … (except body not rendered in this commit view)
+    # try:
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+
+    #     ret = ncut_run(*args, **kwargs)
+
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+
+    #     ret = list(ret)[:n_ret] + [ret[-1]]
+    #     return ret
+    # except Exception as e:
+    #     gr.Error(str(e))
+    #     if torch.cuda.is_available():
+    #         torch.cuda.empty_cache()
+    #     return *(None for _ in range(n_ret)), "Error: " + str(e)
+
+    ret = ncut_run(*args, **kwargs)
+    ret = list(ret)[:n_ret] + [ret[-1]]
+    return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=30)

@@ -1085,12 +1204,16 @@ def run_fn(
     recursion_l1_gamma=0.5,
     recursion_l2_gamma=0.5,
     recursion_l3_gamma=0.5,
+    node_type2="k",
+    head_index_text='all',
+    make_symmetric=False,
     n_ret=1,
     plot_clusters=False,
     alignedcut_eig_norm_plot=False,
     advanced=False,
+    directed=False,
 ):
-    … (1 deleted line not rendered in this commit view)
+    print(node_type2, head_index_text, make_symmetric)
     progress=gr.Progress()
     progress(0, desc="Starting")
 

@@ -1222,6 +1345,10 @@ def run_fn(
         "plot_clusters": plot_clusters,
         "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
         "advanced": advanced,
+        "directed": directed,
+        "node_type2": node_type2,
+        "head_index_text": head_index_text,
+        "make_symmetric": make_symmetric,
     }
     # print(kwargs)
 

@@ -1379,7 +1506,7 @@ def fit_trans(rgb1, rgb2, num_layer=3, width=512, batch_size=256, lr=3e-4, fitti
     # Train the model
     trainer.fit(mlp, dataloader)
 
-    … (1 deleted line not rendered in this commit view)
+    mlp.progress(0.99, desc="Applying MLP")
     results = trainer.predict(mlp, data_loader)
     A_transformed = torch.cat(results, dim=0)
 

@@ -2734,10 +2861,226 @@ with demo:
     buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
     buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
 
 
-… (2 deleted lines not rendered in this commit view)
+
+    with gr.Tab('Directed (experimental)', visible=True) as tab_directed_ncut:
+
+        target_images = gr.State([])
+        input_images = gr.State([])
+        def add_mlp_fitting_buttons(output_gallery, mlp_gallery, target_images=target_images, input_images=input_images):
+            with gr.Row():
+                # mark_as_target_button = gr.Button("mark target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
+                # mark_as_input_button = gr.Button("mark input", elem_id=f"mark_as_input_button_{output_gallery.elem_id}", variant='secondary')
+                mark_as_target_button = gr.Button("🎯 Mark Target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary')
+                fit_to_target_button = gr.Button("🔴 [MLP] Fit", elem_id=f"fit_to_target_button_{output_gallery.elem_id}", variant='primary')
+            def mark_fn(images, text="target"):
+                if images is None:
+                    raise gr.Error("No images selected")
+                if len(images) == 0:
+                    raise gr.Error("No images selected")
+                num_images = len(images)
+                gr.Info(f"Marked {num_images} images as {text}")
+                images = [(Image.open(tup[0]), []) for tup in images]
+                return images
+            mark_as_target_button.click(partial(mark_fn, text="target"), inputs=[output_gallery], outputs=[target_images])
+            # mark_as_input_button.click(partial(mark_fn, text="input"), inputs=[output_gallery], outputs=[input_images])
+
+            with gr.Accordion("➡️ MLP Parameters", open=False):
+                num_layers_slider = gr.Slider(2, 10, step=1, label="Number of Layers", value=3, elem_id=f"num_layers_slider_{output_gallery.elem_id}")
+                width_slider = gr.Slider(128, 4096, step=128, label="Width", value=512, elem_id=f"width_slider_{output_gallery.elem_id}")
+                batch_size_slider = gr.Slider(32, 4096, step=32, label="Batch Size", value=128, elem_id=f"batch_size_slider_{output_gallery.elem_id}")
+                lr_slider = gr.Slider(1e-6, 1, step=1e-6, label="Learning Rate", value=3e-4, elem_id=f"lr_slider_{output_gallery.elem_id}")
+                fitting_steps_slider = gr.Slider(1000, 100000, step=1000, label="Fitting Steps", value=30000, elem_id=f"fitting_steps_slider_{output_gallery.elem_id}")
+                fps_sample_slider = gr.Slider(128, 50000, step=128, label="FPS Sample", value=10240, elem_id=f"fps_sample_slider_{output_gallery.elem_id}")
+                segmentation_loss_lambda_slider = gr.Slider(0, 100, step=0.01, label="Segmentation Preserving Loss Lambda", value=1, elem_id=f"segmentation_loss_lambda_slider_{output_gallery.elem_id}")
+
+            fit_to_target_button.click(
+                run_mlp_fit,
+                inputs=[output_gallery, target_images, num_layers_slider, width_slider, batch_size_slider, lr_slider, fitting_steps_slider, fps_sample_slider, segmentation_loss_lambda_slider],
+                outputs=[mlp_gallery],
+            )
+
+        def make_parameters_section_2model(model_ratio=True):
+            gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
+            from ncut_pytorch.backbone import list_models, get_demo_model_names
+            model_names = list_models()
+            model_names = sorted(model_names)
+            # only CLIP DINO MAE is implemented for q k v
+            ok_models = ["CLIP(ViT", "DiNO(", "MAE("]
+            model_names = [m for m in model_names if any(ok in m for ok in ok_models)]
+
+            def get_filtered_model_names(name):
+                return [m for m in model_names if name.lower() in m.lower()]
+            def get_default_model_name(name):
+                lst = get_filtered_model_names(name)
+                if len(lst) > 1:
+                    return lst[1]
+                return lst[0]
+
+
+            model_radio = gr.Radio(["CLIP", "DiNO", "MAE"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True, visible=model_ratio)
+            model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False)
+            model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)), inputs=model_radio, outputs=[model_dropdown])
+            layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
+            positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'")
+            positive_prompt.visible = False
+            negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'")
+            negative_prompt.visible = False
+            node_type_dropdown = gr.Dropdown(['q', 'k', 'v'],
+                label="Left-side Node Type", value="q", elem_id="node_type", info="In directed case, left-side SVD eigenvector is taken")
+            node_type_dropdown2 = gr.Dropdown(['q', 'k', 'v'],
+                label="Right-side Node Type", value="k", elem_id="node_type2")
+            head_index_text = gr.Textbox(value='all', label="Head Index", elem_id="head_index", type="text", info="which attention heads to use, comma separated, e.g. 0,1,2")
+            make_symmetric = gr.Checkbox(label="Make Symmetric", value=False, elem_id="make_symmetric", info="make the graph symmetric by A = (A + A.T) / 2")
+
+            num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for smaller clusters')
+
+            def change_layer_slider(model_name):
+                # SD2, UNET
+                if "stable" in model_name.lower() and "diffusion" in model_name.lower():
+                    from ncut_pytorch.backbone import SD_KEY_DICT
+                    default_layer = 'up_2_resnets_1_block' if 'diffusion-3' not in model_name else 'block_23'
+                    return (gr.Slider(1, 49, step=1, label="Diffusion: Timestep (Noise)", value=5, elem_id="layer", visible=True, info="Noise level, 50 is max noise"),
+                            gr.Dropdown(SD_KEY_DICT[model_name], label="Diffusion: Layer and Node", value=default_layer, elem_id="node_type", info="U-Net (v1, v2) or DiT (v3)"))
+
+                if model_name == "LISSL(xinlai/LISSL-7B-v1)":
+                    layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"]
+                    default_layer = "dec_1_block"
+                    return (gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False, info=""),
+                            gr.Dropdown(layer_names, label="LISA decoder: Layer and Node", value=default_layer, elem_id="node_type"))
+
+                layer_dict = LAYER_DICT
+                if model_name in layer_dict:
+                    value = layer_dict[model_name]
+                    return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="")
+                else:
+                    value = 12
+                    return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="")
+            model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=layer_slider)
+
+            def change_prompt_text(model_name):
+                if model_name in promptable_diffusion_models:
+                    return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=True),
+                            gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=True))
+                return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False),
+                        gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False))
+            model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])
+
+            with gr.Accordion("Advanced Parameters: NCUT", open=False):
+                gr.Markdown("<a href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Docs: How to Get Better Segmentation</a>")
+                affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
+                num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
+                # sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation")
+                sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
+                # ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
+                ncut_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
+                ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
+                ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=False, elem_id="ncut_indirect_connection", info="TODO: Indirect connection is not implemented for directed NCUT", interactive=False)
+                ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")
+            with gr.Accordion("Advanced Parameters: Visualization", open=False):
+                # embedding_method_dropdown = gr.Dropdown(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
+                embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
+                # embedding_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="t-SNE/UMAP metric", value="euclidean", elem_id="embedding_metric")
+                embedding_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="t-SNE/UMAP: metric", value="euclidean", elem_id="embedding_metric")
+                num_sample_tsne_slider = gr.Slider(100, 10000, step=100, label="t-SNE/UMAP: num_sample", value=300, elem_id="num_sample_tsne", info="Nyström approximation")
+                knn_tsne_slider = gr.Slider(1, 100, step=1, label="t-SNE/UMAP: KNN", value=10, elem_id="knn_tsne", info="Nyström approximation")
+                perplexity_slider = gr.Slider(10, 1000, step=10, label="t-SNE: perplexity", value=150, elem_id="perplexity")
+                n_neighbors_slider = gr.Slider(10, 1000, step=10, label="UMAP: n_neighbors", value=150, elem_id="n_neighbors")
+                min_dist_slider = gr.Slider(0.1, 1, step=0.1, label="UMAP: min_dist", value=0.1, elem_id="min_dist")
+            return [model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt]
+
+        def add_one_model(i_model=1):
+            with gr.Column(scale=5, min_width=200) as col:
+                gr.Markdown(f'### Output Images')
+                output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+                submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
+                add_rotate_flip_buttons(output_gallery)
+                add_download_button(output_gallery, f"ncut_embed")
+                mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_fullscreen_button=True, interactive=False)
+                add_mlp_fitting_buttons(output_gallery, mlp_gallery)
+                add_download_button(mlp_gallery, f"mlp_color_align")
+                norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+                add_download_button(norm_gallery, f"eig_norm")
+                cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
+                add_download_button(cluster_gallery, f"clusters")
+                [
+                    model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text, make_symmetric, num_eig_slider,
+                    affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                    embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                    perplexity_slider, n_neighbors_slider, min_dist_slider,
+                    sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt
+                ] = make_parameters_section_2model()
+                # logging text box
+                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
+
+                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
+
+                submit_button.click(
+                    partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True, directed=True),
+                    inputs=[
+                        input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
+                        positive_prompt, negative_prompt,
+                        false_placeholder, no_prompt, no_prompt, no_prompt,
+                        affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
+                        embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
+                        perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
+                        *[false_placeholder for _ in range(9)],
+                        node_type_dropdown2, head_index_text, make_symmetric
+                    ],
+                    outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text]
+                )
+
+                output_gallery.change(lambda x: gr.update(value=x), inputs=[output_gallery], outputs=[mlp_gallery])
+
+            return output_gallery
+
+        galleries = []
 
+        with gr.Row():
+            with gr.Column(scale=5, min_width=200):
+                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
+                submit_button.visible = False
+
+
+            for i in range(3):
+                g = add_one_model()
+                galleries.append(g)
+
+        # Create rows and buttons in a loop
+        rows = []
+        buttons = []
+
+        for i in range(4):
+            row = gr.Row(visible=False)
+            rows.append(row)
+
+            with row:
+                for j in range(4):
+                    with gr.Column(scale=5, min_width=200):
+                        g = add_one_model()
+                        galleries.append(g)
+
+            button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3)
+            buttons.append(button)
+
+            if i > 0:
+                # Reveal the current row and next button
+                buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=row)
+                buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=button)
+
+                # Hide the current button
+                buttons[i - 1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[i - 1])
+
+        # Last button only reveals the last row and hides itself
+        buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1])
+        buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1])
+
+
 
     with gr.Tab('📄About'):
         with gr.Column():
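Taken together, the directed branch of ncut_run and compute_ncut_directed above do three tensor moves: parse the head index, slice both node types (left side, e.g. q; right side, e.g. k) by head, and flatten image patches into graph nodes. The following is a minimal, self-contained sketch of that plumbing; all shapes and names below are illustrative toys, not from the commit:

import torch
from einops import rearrange

# hypothetical multi-head backbone features: (batch, h, w, n_heads, channels)
batch, h, w, n_heads, d = 2, 4, 4, 6, 64
features_q = torch.randn(batch, h, w, n_heads, d)  # "Left-side Node Type", e.g. q
features_k = torch.randn(batch, h, w, n_heads, d)  # "Right-side Node Type", e.g. k

# head selection, mirroring the head_index_text parsing ('all' or e.g. '0,1,2')
head_index_text = '0,1,2'
if head_index_text == 'all':
    head_idx = torch.arange(n_heads)
else:
    head_idx = torch.tensor([int(i) for i in head_index_text.split(',')])

features_A = features_q[:, :, :, head_idx, :]  # (2, 4, 4, 3, 64)
features_B = features_k[:, :, :, head_idx, :]

# flatten: one graph node per image patch, with the selected heads concatenated,
# as compute_ncut_directed does before calling nystrom_ncut
_features_A = rearrange(features_A, "b h w d c -> (b h w) (d c)")
_features_B = rearrange(features_B, "b h w d c -> (b h w) (d c)")
print(_features_A.shape)  # torch.Size([32, 192])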
directed_ncut.py
ADDED
@@ -0,0 +1,287 @@
# %%
import torch
import torch.nn.functional as F

def affinity_from_features(
    features,
    features_B=None,
    affinity_focal_gamma=1.0,
    distance="cosine",
    normalize_features=False,
    fill_diagonal=False,
    n_features=1,
):
    """Compute affinity matrix from input features.

    Args:
        features (torch.Tensor): input features, shape (n_samples, n_features)
        features_B (torch.Tensor, optional): if not None, compute affinity between the two feature sets
        affinity_focal_gamma (float): affinity matrix parameter, smaller values reduce the edge weights
            on weak connections, default 1.0
        distance (str): distance metric, 'cosine' (default) or 'euclidean'
        normalize_features (bool): normalize input features before computing affinity matrix,
            default False

    Returns:
        (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
    """
    # compute affinity matrix from input features
    features = features.clone()
    if features_B is not None:
        features_B = features_B.clone()

    # if features_B is not provided, compute affinity matrix on features x features
    # if features_B is provided, compute affinity matrix on features x features_B
    if features_B is not None:
        assert not fill_diagonal, "fill_diagonal should be False when features_B is provided"
    features_B = features if features_B is None else features_B

    if normalize_features:
        features = F.normalize(features, dim=-1)
        features_B = F.normalize(features_B, dim=-1)

    if distance == "cosine":
        # if not check_if_normalized(features):
        # TODO: make sure features are normalized within each head
        features = F.normalize(features, dim=-1)
        # if not check_if_normalized(features_B):
        features_B = F.normalize(features_B, dim=-1)
        A = 1 - (features @ features_B.T) / n_features
    elif distance == "euclidean":
        A = torch.cdist(features, features_B, p=2) / n_features
    else:
        raise ValueError("distance should be 'cosine' or 'euclidean'")

    if fill_diagonal:
        A[torch.arange(A.shape[0]), torch.arange(A.shape[0])] = 0

    # torch.exp makes the affinity matrix positive definite;
    # lower affinity_focal_gamma reduces the weak edge weights
    A = torch.exp(-((A / affinity_focal_gamma)))
    return A

from ncut_pytorch.ncut_pytorch import run_subgraph_sampling, propagate_knn, gram_schmidt
import logging

import torch

def ncut(
    A,
    num_eig=20,
    eig_solver="svd_lowrank",
    make_symmetric=True,
):
    """PyTorch implementation of Normalized cut without Nystrom-like approximation.

    Args:
        A (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
        num_eig (int): number of eigenvectors to return
        eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
        make_symmetric (bool): symmetrize A by (A + A.T) / 2 before normalization, default True

    Returns:
        (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
        (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
    """
    if make_symmetric:
        # make sure A is symmetric
        A = (A + A.T) / 2

    # two-sided degree normalization; A = D_r^(-1/2) A D_c^(-1/2)
    D_r = A.sum(dim=0).detach().clone()
    D_c = A.sum(dim=1).detach().clone()
    A /= torch.sqrt(D_r)[:, None]
    A /= torch.sqrt(D_c)[None, :]

    # compute eigenvectors
    if eig_solver == "svd_lowrank":  # default
        # only top q eigenvectors, fastest
        eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
    elif eig_solver == "lobpcg":
        # only top k eigenvectors, fast
        eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
    elif eig_solver == "svd":
        # all eigenvectors, slow
        eigen_vector, eigen_value, _ = torch.svd(A)
    elif eig_solver == "eigh":
        # all eigenvectors, slow
        eigen_value, eigen_vector = torch.linalg.eigh(A)
    else:
        raise ValueError(
            "eigen_solver should be 'lobpcg', 'svd_lowrank', 'svd' or 'eigh'"
        )

    # sort eigenvectors by eigenvalues, take top (descending order)
    eigen_value = eigen_value.real
    eigen_vector = eigen_vector.real

    sort_order = torch.argsort(eigen_value, descending=True)[:num_eig]
    eigen_value = eigen_value[sort_order]
    eigen_vector = eigen_vector[:, sort_order]

    if eigen_value.min() < 0:
        logging.warning(
            "negative eigenvalues detected, please make sure the affinity matrix is positive definite"
        )

    return eigen_vector, eigen_value

def nystrom_ncut(
    features,
    features_B=None,
    num_eig=100,
    num_sample=10000,
    knn=10,
    sample_method="farthest",
    distance="cosine",
    affinity_focal_gamma=1.0,
    indirect_connection=False,
    indirect_pca_dim=100,
    device=None,
    eig_solver="svd_lowrank",
    normalize_features=False,
    matmul_chunk_size=8096,
    make_orthogonal=False,
    verbose=False,
    no_propagation=False,
    make_symmetric=False,
    n_features=1,
):
    """PyTorch implementation of Faster Nystrom Normalized cut.

    Args:
        features (torch.Tensor): feature matrix, shape (n_samples, n_features)
        features_B (torch.Tensor): second feature matrix, for an asymmetric affinity matrix, shape (n_samples2, n_features)
        num_eig (int): default 100, number of top eigenvectors to return
        num_sample (int): default 10000, number of samples for Nystrom-like approximation
        knn (int): default 10, number of KNN for propagating eigenvectors from subgraph to full graph,
            smaller knn will result in more sharp eigenvectors
        sample_method (str): sample method, 'farthest' (default) or 'random'
            'farthest' is recommended for better approximation
        distance (str): distance metric, 'cosine' (default) or 'euclidean'
        affinity_focal_gamma (float): affinity matrix parameter, smaller values reduce the weak edge weights,
            resulting in more sharp eigenvectors, default 1.0
        indirect_connection (bool): include indirect connection in the subgraph, default False
        indirect_pca_dim (int): default 100, PCA dimension to reduce the node dimension, only applied to
            the not sampled nodes, not applied to the sampled nodes
        device (str): device to use for computation, if None, will not change device
            a good practice is to pass features by CPU since it's usually large,
            and move subgraph affinity to GPU to speed up eigenvector computation
        eig_solver (str): eigen decompose solver, 'svd_lowrank' (default), 'lobpcg', 'svd', 'eigh'
            'svd_lowrank' is recommended for large scale graph, it's the fastest
            they correspond to torch.svd_lowrank, torch.lobpcg, torch.svd, torch.linalg.eigh
        normalize_features (bool): normalize input features before computing affinity matrix,
            default False
        matmul_chunk_size (int): chunk size for matrix multiplication
            large matrix multiplication is chunked to reduce memory usage,
            smaller chunk size will reduce memory usage but slower computation, default 8096
        make_orthogonal (bool): make eigenvectors orthogonal after propagation, default False
        verbose (bool): show progress bar when propagating eigenvectors from subgraph to full graph
        no_propagation (bool): if True, skip the eigenvector propagation step, only return the subgraph eigenvectors

    Returns:
        (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
        (torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
        (torch.Tensor): sampled_indices used by Nystrom-like approximation subgraph, shape (num_sample,)
    """

    # check if features dimension greater than num_eig
    if eig_solver in ["svd_lowrank", "lobpcg"]:
        assert features.shape[0] > (
            num_eig * 2
        ), "number of nodes should be greater than 2*num_eig"
    if eig_solver in ["svd", "eigh"]:
        assert (
            features.shape[0] > num_eig
        ), "number of nodes should be greater than num_eig"

    features = features.clone()
    if normalize_features:
        # features need to be normalized for affinity matrix computation (cosine distance)
        features = torch.nn.functional.normalize(features, dim=-1)

    sampled_indices = run_subgraph_sampling(
        features,
        num_sample=num_sample,
        sample_method=sample_method,
    )

    sampled_indices_B = run_subgraph_sampling(
        features_B,
        num_sample=num_sample,
        sample_method=sample_method,
    )

    sampled_features = features[sampled_indices]
    sampled_features_B = features_B[sampled_indices_B]
    # move subgraph to gpu to speed up
    original_device = sampled_features.device
    device = original_device if device is None else device
    sampled_features = sampled_features.to(device)
    sampled_features_B = sampled_features_B.to(device)

    # compute affinity matrix on subgraph
    A = affinity_from_features(
        sampled_features, features_B=sampled_features_B,
        affinity_focal_gamma=affinity_focal_gamma, distance=distance,
        n_features=n_features,
    )

    not_sampled = torch.tensor(
        list(set(range(features.shape[0])) - set(sampled_indices))
    )

    if len(not_sampled) == 0:
        # if sampled all nodes, no need for nyström approximation
        eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
        return eigen_vector, eigen_value, sampled_indices

    # 1) PCA to reduce the node dimension for the not sampled nodes
    # 2) compute indirect connection on the PC nodes
    if len(not_sampled) > 0 and indirect_connection:
        raise NotImplementedError("indirect_connection is not implemented yet")
        indirect_pca_dim = min(indirect_pca_dim, min(*features.shape))
        U, S, V = torch.pca_lowrank(features[not_sampled].T, q=indirect_pca_dim)
        feature_B = (features[not_sampled].T @ V).T  # project to PCA space
        feature_B = feature_B.to(device)
        B = affinity_from_features(
            sampled_features,
            feature_B,
            affinity_focal_gamma=affinity_focal_gamma,
            distance=distance,
            fill_diagonal=False,
        )
        # P is 1-hop random walk matrix
        B_row = B / B.sum(axis=1, keepdim=True)
        B_col = B / B.sum(axis=0, keepdim=True)
        P = B_row @ B_col.T
        P = (P + P.T) / 2
        # fill diagonal with 0
        P[torch.arange(P.shape[0]), torch.arange(P.shape[0])] = 0
        A = A + P

    # compute normalized cut on the subgraph
    eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver, make_symmetric=make_symmetric)
    eigen_vector = eigen_vector.to(dtype=features.dtype, device=original_device)
    eigen_value = eigen_value.to(dtype=features.dtype, device=original_device)

    if no_propagation:
        return eigen_vector, eigen_value, sampled_indices

    # propagate eigenvectors from subgraph to full graph
    eigen_vector = propagate_knn(
        eigen_vector,
        features,
        sampled_features,
        knn,
        chunk_size=matmul_chunk_size,
        device=device,
        use_tqdm=verbose,
    )

    # post-hoc orthogonalization
    if make_orthogonal:
        eigen_vector = gram_schmidt(eigen_vector)

    return eigen_vector, eigen_value, sampled_indices
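On the sampled subgraph, the new module's core computation reduces to three steps: an asymmetric affinity between the two feature sets, a degree normalization on both sides, and an SVD whose left singular vectors stand in for the NCut eigenvectors (matching the UI note that the left-side SVD eigenvector is taken). Below is a self-contained sketch of just that math on toy data, assuming cosine distance, n_features=1, and make_symmetric=False; the variable names are illustrative, not from the repo:

import torch
import torch.nn.functional as F

n, d, num_eig = 200, 32, 8
feats_A = torch.randn(n, d)  # e.g. flattened q features
feats_B = torch.randn(n, d)  # e.g. flattened k features

# cosine affinity with focal gamma, as in affinity_from_features
a = F.normalize(feats_A, dim=-1)
b = F.normalize(feats_B, dim=-1)
A = torch.exp(-(1 - a @ b.T) / 0.3)  # affinity_focal_gamma = 0.3

# two-sided degree normalization, as in ncut(); A stays asymmetric here,
# so left and right singular vectors generally differ -- the "directed" part
D_r = A.sum(dim=0)
D_c = A.sum(dim=1)
A = A / torch.sqrt(D_r)[:, None]
A = A / torch.sqrt(D_c)[None, :]

left_vecs, vals, right_vecs = torch.svd_lowrank(A, q=num_eig)
print(left_vecs.shape, vals.shape)  # torch.Size([200, 8]) torch.Size([8])

On the full graph, the module then spreads these subgraph eigenvectors to all nodes via the imported propagate_knn, the same propagation step the undirected ncut_pytorch pipeline uses.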
requirements.txt
CHANGED

@@ -20,4 +20,4 @@ lisa @ git+https://github.com/huzeyann/LISA.git@7211e99
 timm==0.9.2
 open-clip-torch==2.20.0
 pytorch_lightning==1.9.4
-ncut-pytorch>=1.
+ncut-pytorch>=1.4.1
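For reproducing the Space locally, the updated pin corresponds to (assuming the usual PyPI package name):

pip install "ncut-pytorch>=1.4.1"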