# Author: Huzheng Yang
# %%
import copy
from datetime import datetime
import io
import math
import pickle
from functools import partial
from io import BytesIO
import json
import os
import uuid
import zipfile
import multiprocessing as mp
from gradio_image_prompter import ImagePrompter
from einops import rearrange
from matplotlib import pyplot as plt
import matplotlib
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

USE_HUGGINGFACE_ZEROGPU = os.getenv("USE_HUGGINGFACE_ZEROGPU", "False").lower() in ["true", "1", "yes"]
DOWNLOAD_ALL_MODELS_DATASETS = os.getenv("DOWNLOAD_ALL_MODELS_DATASETS", "False").lower() in ["true", "1", "yes"]

if USE_HUGGINGFACE_ZEROGPU:  # huggingface ZeroGPU, dynamic GPU allocation
    try:
        import spaces
    except Exception:
        # Best effort: fall back to the non-ZeroGPU code path if `spaces`
        # cannot be imported for any reason.
        USE_HUGGINGFACE_ZEROGPU = False

if USE_HUGGINGFACE_ZEROGPU:
    BATCH_SIZE = 1
else:  # run on local machine
    BATCH_SIZE = 1

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import time
import threading

from ncut_pytorch.backbone import extract_features, load_model
from ncut_pytorch.backbone import MODEL_DICT, LAYER_DICT, RES_DICT
from ncut_pytorch import NCUT
from ncut_pytorch import eigenvector_to_rgb

# Dataset catalogue, grouped by category. Each entry is a
# (dataset name, integer or None) tuple; the meaning of the second field is
# not established in this part of the file (presumably a class count).
DATASETS = {
    'Common': [
        ('mrm8488/ImageNet1K-val', 1000),
        ('UCSC-VLAA/Recap-COCO-30K', None),
        ('nateraw/pascal-voc-2012', None),
        ('johnowhitaker/imagenette2-320', 10),
        ('Multimodal-Fatima/CUB_train', 200),
        ('saragag/FlBirds', 7),
        ('microsoft/cats_vs_dogs', None),
        ('Robotkid2696/food_classification', 20),
        ('JapanDegitalMaterial/Places_in_Japan', None),
    ],
    'Ego': [
        ('EgoThink/EgoThink', None),
    ],
    'Face': [
        ('nielsr/CelebA-faces', None),
        ('huggan/anime-faces', None),
    ],
    'Pose': [
        ('sayakpaul/poses-controlnet-dataset', None),
        ('razdab/sign_pose_M', None),
        ('Marqo/deepfashion-multimodal', None),
        ('Fiacre/small-animal-poses-controlnet-dataset', None),
        ('junjuice0/vtuber-tachi-e', None),
    ],
    'Hand': [
        ('trashsock/hands-images', 8),
        ('dduka/guitar-chords-v3', None),
    ],
    'Satellite': [
        ('arakesh/deepglobe-2448x2448', None),
        ('tanganke/eurosat', 10),
        ('wangyi111/EuroSAT-SAR', None),
        ('efoley/sar_tile_512', None),
    ],
    'Medical': [
        ('Mahadih534/Chest_CT-Scan_images-Dataset', None),
        ('TrainingDataPro/chest-x-rays', None),
        ('hongrui/mimic_chest_xray_v_1', None),
        ('sartajbhuvaji/Brain-Tumor-Classification', 4),
        ('Falah/Alzheimer_MRI', 4),
        ('Leonardo6/path-vqa', None),
        ('Itsunori/path-vqa_jap', None),
        ('ruby-jrl/isic-2024-2', None),
        ('VRJBro/lung_cancer_dataset', 5),
        ('keremberke/blood-cell-object-detection', None)
    ],
    'Miscs': [
        ('yashvoladoddi37/kanjienglish', None),
        ('Borismile/Anime-dataset', None),
        ('jainr3/diffusiondb-pixelart', None),
        ('jlbaker361/dcgan-eval-creative_gan_256_256', None),
        ('Francesco/csgo-videogame', None),
        ('Francesco/apex-videogame', None),
        ('huggan/pokemon', None),
        ('huggan/few-shot-universe', None),
        ('huggan/flowers-102-categories', None),
        ('huggan/inat_butterflies_top10k', None),
    ]
}
CENTER_CROP_DATASETS = ["razdab/sign_pose_M"]

from datasets import load_dataset


def download_all_datasets():
    """Call `load_dataset` on every dataset listed in `DATASETS`.

    A failure on one dataset is printed and does not stop the loop.
    """
    for entries in DATASETS.values():
        for name, _ in entries:
            print(f"Downloading {name}")
            try:
                load_dataset(name, trust_remote_code=True)
            except Exception as e:
                print(f"Error downloading {name}: {e}")


def _cap_num_eig(num_nodes, num_eig):
    """Cap `num_eig` when it exceeds half the number of graph nodes.

    Returns:
        (num_eig, message): the possibly reduced eigenvector count and the
        text to append to the run log (empty when nothing changed). A Gradio
        warning is shown whenever the value is reduced.
    """
    num_nodes = int(num_nodes)  # np.prod returns a numpy scalar
    if num_nodes / 2 >= num_eig:
        return num_eig, ""
    # Never go below one eigenvector, even for degenerate tiny inputs.
    capped = max(num_nodes // 2 - 1, 1)
    gr.Warning("Number of eigenvectors should be less than half the number of nodes.\n"
               f"Setting num_eig to {capped}.")
    message = ("Number of eigenvectors should be less than half the number of nodes.\n"
               f"Setting num_eig to {capped}.\n")
    return capped, message


def compute_ncut(
    features,
    num_eig=100,
    num_sample_ncut=10000,
    affinity_focal_gamma=0.3,
    knn_ncut=10,
    knn_tsne=10,
    embedding_method="UMAP",
    embedding_metric='euclidean',
    num_sample_tsne=300,
    perplexity=150,
    n_neighbors=150,
    min_dist=0.1,
    sampling_method="QuickFPS",
    metric="cosine",
    indirect_connection=True,
    make_orthogonal=False,
    progess_start=0.4,
    only_eigvecs=False,
):
    """Run NCUT on `features` and embed the eigenvectors into RGB colors.

    `features` is flattened to (num_nodes, feature_dim) over all leading
    axes before NCUT; the RGB output is reshaped back to the leading axes
    plus a trailing axis of size 3.

    Args:
        features: tensor whose last axis is the feature dimension.
        num_eig: requested number of eigenvectors; capped (with a warning)
            when it exceeds half the number of nodes.
        progess_start: progress-bar value reported when NCUT starts.
        only_eigvecs: when True, skip the color embedding.
        (remaining arguments are forwarded to `NCUT` / `eigenvector_to_rgb`)

    Returns:
        (rgb, logging_str, eigvecs); `rgb` is None when `only_eigvecs`.
    """
    progress = gr.Progress()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    num_nodes = np.prod(features.shape[:-1])
    num_eig, logging_str = _cap_num_eig(num_nodes, num_eig)

    start = time.time()
    progress(progess_start+0.0, desc="NCut")
    eigvecs, eigvals = NCUT(
        num_eig=num_eig,
        num_sample=num_sample_ncut,
        device=device,
        affinity_focal_gamma=affinity_focal_gamma,
        knn=knn_ncut,
        sample_method=sampling_method,
        distance=metric,
        normalize_features=False,
        indirect_connection=indirect_connection,
        make_orthogonal=make_orthogonal,
    ).fit_transform(features.reshape(-1, features.shape[-1]))
    logging_str += f"NCUT time: {time.time() - start:.2f}s\n"

    if only_eigvecs:
        return None, logging_str, eigvecs

    start = time.time()
    progress(progess_start+0.01, desc="spectral-tSNE")
    _, rgb = eigenvector_to_rgb(
        eigvecs,
        method=embedding_method,
        metric=embedding_metric,
        num_sample=num_sample_tsne,
        perplexity=perplexity,
        n_neighbors=n_neighbors,
        min_distance=min_dist,
        knn=knn_tsne,
        device=device,
    )
    logging_str += f"{embedding_method} time: {time.time() - start:.2f}s\n"

    rgb = rgb.reshape(features.shape[:-1] + (3,))
    return rgb, logging_str, eigvecs


def compute_ncut_directed(
    features_1,
    features_2,
    num_eig=100,
    num_sample_ncut=10000,
    affinity_focal_gamma=0.3,
    knn_ncut=10,
    knn_tsne=10,
    embedding_method="UMAP",
    embedding_metric='euclidean',
    num_sample_tsne=300,
    perplexity=150,
    n_neighbors=150,
    min_dist=0.1,
    sampling_method="QuickFPS",
    metric="cosine",
    indirect_connection=False,
    make_orthogonal=False,
    make_symmetric=False,
    progess_start=0.4,
):
    """Run directed NCUT between two feature sets and embed the result to RGB.

    Both inputs are rearranged with the pattern "b h w d c -> (b h w) (d c)",
    so they are expected to be 5-D. The RGB output is reshaped to the first
    three axes of `features_1` plus a trailing axis of size 3.

    Returns:
        (rgb, logging_str, eigvecs)
    """
    # Imported lazily so the module loads even when the optional
    # `directed_ncut` package is not installed.
    from directed_ncut import nystrom_ncut

    progress = gr.Progress()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    num_nodes = np.prod(features_1.shape[:-2])
    num_eig, logging_str = _cap_num_eig(num_nodes, num_eig)

    start = time.time()
    progress(progess_start+0.0, desc="NCut")
    n_features = features_1.shape[-2]
    _features_1 = rearrange(features_1, "b h w d c -> (b h w) (d c)")
    _features_2 = rearrange(features_2, "b h w d c -> (b h w) (d c)")
    eigvecs, eigvals, _ = nystrom_ncut(
        _features_1,
        features_B=_features_2,
        num_eig=num_eig,
        num_sample=num_sample_ncut,
        device=device,
        affinity_focal_gamma=affinity_focal_gamma,
        knn=knn_ncut,
        sample_method=sampling_method,
        distance=metric,
        normalize_features=False,
        indirect_connection=indirect_connection,
        make_orthogonal=make_orthogonal,
        make_symmetric=make_symmetric,
        n_features=n_features,
    )
    logging_str += f"NCUT time: {time.time() - start:.2f}s\n"

    start = time.time()
    progress(progess_start+0.01, desc="spectral-tSNE")
    _, rgb = eigenvector_to_rgb(
        eigvecs,
        method=embedding_method,
        metric=embedding_metric,
        num_sample=num_sample_tsne,
        perplexity=perplexity,
        n_neighbors=n_neighbors,
        min_distance=min_dist,
        knn=knn_tsne,
        device=device,
    )
    logging_str += f"{embedding_method} time: {time.time() - start:.2f}s\n"

    rgb = rgb.reshape(features_1.shape[:3] + (3,))
    return rgb, logging_str, eigvecs


def dont_use_too_much_green(image_rgb):
    """Reorder color channels so the central region is dominated by channel 0.

    The mean of each channel is measured over the central 40% window
    (30%..70% along axes 1 and 2) and the channels are sorted by that mean,
    largest first, so the center of the images is red-leading.

    Args:
        image_rgb: 4-D torch tensor with channels on the last axis
            (indexing below fixes the rank; presumably (batch, h, w, 3)).
    """
    x1, x2 = int(image_rgb.shape[1] * 0.3), int(image_rgb.shape[1] * 0.7)
    y1, y2 = int(image_rgb.shape[2] * 0.3), int(image_rgb.shape[2] * 0.7)
    channel_means = image_rgb[:, x1:x2, y1:y2].mean((0, 1, 2))
    channel_order = channel_means.argsort(descending=True)
    return image_rgb[:, :, :, channel_order]


def to_pil_images(images, target_size=512, resize=True, force_size=False):
    """Convert a batch of image arrays/tensors to (optionally resized) PIL images.

    Args:
        images: sequence of image arrays or torch tensors; float32/float64
            data is scaled by 255 and cast to uint8.
        target_size: upper bound for the output edge length. The output size
            is the largest integer multiple of `images[0].shape[1]` that does
            not exceed it.
        resize: when False, images are returned at their native size.
        force_size: when True, resize to exactly `target_size`.

    Returns:
        List of PIL images; square `(res, res)` when resized. Empty input
        yields an empty list.
    """
    if len(images) == 0:
        return []

    size = images[0].shape[1]
    # An image larger than target_size used to give multiplier 0 and hence a
    # 0x0 resize; keep at least the native size instead.
    multiplier = max(target_size // size, 1)
    res = target_size if force_size else int(size * multiplier)

    pil_images = []
    for image in images:
        if isinstance(image, torch.Tensor):
            image = image.cpu().numpy()
        if image.dtype == np.float32 or image.dtype == np.float64:
            image = (image * 255).astype(np.uint8)
        pil_images.append(Image.fromarray(image))

    if resize:
        pil_images = [
            image.resize((res, res), Image.Resampling.NEAREST)
            for image in pil_images
        ]
    return pil_images


def pil_images_to_video(images, output_path, fps=5):
    """Write a list of PIL images to `output_path` as an mp4v-encoded video.

    The frame size is taken from the first image. Returns `output_path`.
    """
    # OpenCV is only needed for video export, so import it lazily.
    import cv2

    frames = [np.array(image) for image in images]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    height, width, _ = frames[0].shape
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        for frame in frames:
            # The frames are converted from RGB to the BGR order passed to OpenCV.
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        out.release()
    return output_path


# save up to 100 videos in disk
class VideoCache:
    """Bounded registry of video files; the oldest file is deleted when full."""

    def __init__(self, max_videos=100):
        self.max_videos = max_videos
        # Insertion-ordered mapping path -> path; the first key is the oldest.
        self.videos = {}

    def add_video(self, video_path):
        """Register `video_path`, evicting (and deleting) the oldest entries if full."""
        while self.videos and len(self.videos) >= self.max_videos:
            # dict.popitem() is LIFO and would drop the newest video, so take
            # the first (oldest) key explicitly.
            oldest_path = next(iter(self.videos))
            del self.videos[oldest_path]
            try:
                os.remove(oldest_path)
            except OSError:
                # File already gone or never written; nothing to clean up.
                pass
        self.videos[video_path] = video_path

    def get_video(self, video_path):
        """Return `video_path` if it is registered, else None."""
        return self.videos.get(video_path, None)


video_cache = VideoCache()


def get_random_path(length=10):
    """Return a random `/tmp/<name>.mp4` path with a `length`-character name."""
    import random
    import string
    name = ''.join(random.choices(string.ascii_lowercase + string.digits, k=length))
    path = f'/tmp/{name}.mp4'
    return path


default_images = ['./images/image_0.jpg', './images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg', './images/guitar_ego.jpg', './images/image_5.jpg']
default_outputs = ['./images/image-1.webp', './images/image-2.webp', './images/image-3.webp', './images/image-4.webp', './images/image-5.webp']
# default_outputs_independent = ['./images/image-6.webp', './images/image-7.webp', './images/image-8.webp', './images/image-9.webp', './images/image-10.webp']
default_outputs_independent = [] downscaled_images = ['./images/image_0_small.jpg', './images/image_1_small.jpg', './images/image_2_small.jpg', './images/image_3_small.jpg', './images/image_5_small.jpg'] downscaled_outputs = default_outputs example_items = downscaled_images[:3] + downscaled_outputs[:3] def run_alignedthreemodelattnnodes(images, model, batch_size=16): use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") if use_cuda: model = model.to(device) chunked_idxs = torch.split(torch.arange(images.shape[0]), batch_size) outputs = [] for idxs in chunked_idxs: inp = images[idxs] if use_cuda: inp = inp.to(device) out = model(inp) # normalize before save out = F.normalize(out, dim=-1) outputs.append(out.cpu().float()) outputs = torch.cat(outputs, dim=0) return outputs def _reds_colormap(image): # normed_data = image / image.max() # Normalize to [0, 1] normed_data = image colormap = matplotlib.colormaps['inferno'] # Get the Reds colormap colored_image = colormap(normed_data) # Apply colormap return (colored_image[..., :3] * 255).astype(np.uint8) # Convert to RGB # heatmap images def apply_reds_colormap(images, size): # for i_image in range(images.shape[0]): # images[i_image] -= images[i_image].min() # images[i_image] /= images[i_image].max() # normed_data = [_reds_colormap(images[i]) for i in range(images.shape[0])] # normed_data = np.stack(normed_data) normed_data = _reds_colormap(images) normed_data = torch.tensor(normed_data).float() normed_data = rearrange(normed_data, "b h w c -> b c h w") normed_data = torch.nn.functional.interpolate(normed_data, size=size, mode="nearest") normed_data = rearrange(normed_data, "b c h w -> b h w c") normed_data = normed_data.cpu().numpy().astype(np.uint8) return normed_data # Blend heatmap with the original image def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5): blended = (1 - opacity1) * image + opacity2 * heatmap return blended.astype(np.uint8) def 
segment_fg_bg(images): images = F.interpolate(images, (224, 224), mode="bilinear") # model = load_alignedthreemodel() model = load_model("CLIP(ViT-B-16/openai)") from ncut_pytorch.backbone import resample_position_embeddings pos_embed = model.model.visual.positional_embedding pos_embed = resample_position_embeddings(pos_embed, 14, 14) model.model.visual.positional_embedding = torch.nn.Parameter(pos_embed) batch_size = 4 chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size) device = 'cuda' if torch.cuda.is_available() else 'cpu' model.to(device) means = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) fg_acts, bg_acts = [], [] for chunk_idx in chunk_idxs: with torch.no_grad(): input_images = images[chunk_idx].to(device) # transform the input images input_images = (input_images - means) / stds # output = model(input_images)[:, 5] output = model(input_images)['attn'][6] # [B, H=14, W=14, C] fg_act = output[:, 6, 6].mean(0) bg_act = output[:, 0, 0].mean(0) fg_acts.append(fg_act) bg_acts.append(bg_act) fg_act = torch.stack(fg_acts, dim=0).mean(0) bg_act = torch.stack(bg_acts, dim=0).mean(0) fg_act = F.normalize(fg_act, dim=-1) bg_act = F.normalize(bg_act, dim=-1) # ref_image = default_images[0] # image = Image.open(ref_image).convert("RGB").resize((224, 224), Image.Resampling.BILINEAR) # image = torch.tensor(np.array(image)).permute(2, 0, 1).float().to(device) # image = (image / 255.0 - means) / stds # output = model(image)['attn'][6][0] # # print(output.shape) # # bg on the center # fg_act = output[5, 5] # # bg on the bottom left # bg_act = output[0, 0] # fg_act = F.normalize(fg_act, dim=-1) # bg_act = F.normalize(bg_act, dim=-1) # print(images.mean(), images.std()) fg_act, bg_act = fg_act.to(device), bg_act.to(device) chunk_idxs = torch.split(torch.arange(images.shape[0]), batch_size) heatmap_fgs, heatmap_bgs = [], [] for chunk_idx in chunk_idxs: with torch.no_grad(): 
input_images = images[chunk_idx].to(device) # transform the input images input_images = (input_images - means) / stds # output = model(input_images)[:, 5] output = model(input_images)['attn'][6] output = F.normalize(output, dim=-1) heatmap_fg = output @ fg_act[:, None] # [B, H, W, 1] heatmap_bg = output @ bg_act[:, None] # [B, H, W, 1] heatmap_fgs.append(heatmap_fg.cpu()) heatmap_bgs.append(heatmap_bg.cpu()) heatmap_fg = torch.cat(heatmap_fgs, dim=0) heatmap_bg = torch.cat(heatmap_bgs, dim=0) return heatmap_fg, heatmap_bg def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=False, clusters=50, eig_idx=None, title='cluster'): if clusters == 0: return [], [] progress = gr.Progress() progress(progess_start, desc="Finding Clusters by FPS") device = 'cuda' if torch.cuda.is_available() else 'cpu' eigvecs = eigvecs.to(device) from ncut_pytorch.ncut_pytorch import farthest_point_sampling magnitude = torch.norm(eigvecs, dim=-1) # gr.Info("Finding Clusters by FPS, no magnitude filtering") top_p_idx = torch.arange(eigvecs.shape[0]) if eig_idx is not None: top_p_idx = eig_idx # gr.Info("Finding Clusters by FPS, with magnitude filtering") # p = 0.8 # top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])] ret_magnitude = magnitude.reshape(-1, h, w) num_samples = 300 if num_samples > top_p_idx.shape[0]: num_samples = top_p_idx.shape[0] fps_idx = farthest_point_sampling(eigvecs[top_p_idx], num_samples) fps_idx = top_p_idx[fps_idx] # fps round 2 on the heatmap left = eigvecs[fps_idx, :].clone() right = eigvecs.clone() left = F.normalize(left, dim=-1) right = F.normalize(right, dim=-1) heatmap = left @ right.T heatmap = F.normalize(heatmap, dim=-1) # [300, N_pixel] PCA-> [300, 8] num_samples = clusters + 20 # 100/120 if num_samples > fps_idx.shape[0]: num_samples = fps_idx.shape[0] r2_fps_idx = farthest_point_sampling(heatmap, num_samples) fps_idx = fps_idx[r2_fps_idx] # downsample to 256x256 images = F.interpolate(images, (256, 256), 
mode="bilinear") images = images.cpu().numpy() images = images.transpose(0, 2, 3, 1) images = images * 255 images = images.astype(np.uint8) # sort the fps_idx by the mean of the heatmap fps_heatmaps = {} sort_values = [] top3_image_idx = {} top10_image_idx = {} for _, idx in enumerate(fps_idx): heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1) # def top_percentile(tensor, p=0.8, max_size=10000): # tensor = tensor.clone().flatten() # if tensor.shape[0] > max_size: # tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]] # return tensor.quantile(p) # top_p = top_percentile(heatmap, p=0.5) top_p = 0.9 heatmap = heatmap.reshape(-1, h, w) mask = (heatmap > top_p).float() # take top 3 masks only mask_sort_values = mask.mean((1, 2)) _sort_value2 = (heatmap > 0.1).float().mean((1, 2)) * 0.1 mask_sort_values += _sort_value2 mask_sort_idx = torch.argsort(mask_sort_values, descending=True) mask = mask[mask_sort_idx[:3]] sort_values.append(mask.mean().item()) # fps_heatmaps[idx.item()] = heatmap.cpu() fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu() top3_image_idx[idx.item()] = mask_sort_idx[:3] top10_image_idx[idx.item()] = mask_sort_idx[:6] # do the sorting _sort_idx = torch.tensor(sort_values).argsort(descending=True) fps_idx = fps_idx[_sort_idx] # reverse the fps_idx # fps_idx = fps_idx.flip(0) # discard the big clusters # gr.Info("Discarding the biggest 10 clusters") # fps_idx = fps_idx[10:] # gr.Info("Not discarding the biggest 10 clusters") # gr.Info("Discarding the smallest 30 out of 80 sampled clusters") if not advanced: # shuffle the fps_idx fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])] def plot_cluster_images(fps_idx_chunk, chunk_idx): fig, axs = plt.subplots(3, 5, figsize=(15, 9)) if not advanced else plt.subplots(6, 5, figsize=(15, 18)) for ax in axs.flatten(): ax.axis("off") for j, idx in enumerate(fps_idx_chunk): heatmap = fps_heatmaps[idx.item()] size = (images.shape[1], images.shape[2]) heatmap = 
apply_reds_colormap(heatmap, size) image_idxs = top3_image_idx[idx.item()] if not advanced else top10_image_idx[idx.item()] for i, image_idx in enumerate(image_idxs): _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i]) axs[i, j].imshow(_heatmap) if i == 0: axs[i, j].set_title(f"{title} {chunk_idx * 5 + j + 1}", fontsize=24) plt.tight_layout(h_pad=0.5, w_pad=0.3) filename = f"{datetime.now():%Y%m%d%H%M%S%f}_{uuid.uuid4().hex}" tmp_path = f"/tmp/{filename}.png" plt.savefig(tmp_path, bbox_inches='tight', dpi=72) img = Image.open(tmp_path).convert("RGB") os.remove(tmp_path) plt.close() return img fig_images = [] num_plots = clusters // 5 plot_step_float = (1.0 - progess_start) / num_plots fps_idx_chunks = [fps_idx[i*5:(i+1)*5] for i in range(num_plots)] # with mp.Pool(processes=mp.cpu_count()) as pool: # results = [pool.apply_async(plot_cluster_images, args=(chunk, i)) for i, chunk in enumerate(fps_idx_chunks)] # for i, result in enumerate(results): # progress(progess_start + i * plot_step_float, desc=f"Plotted {title}") # fig_images.append(result.get()) for i, chunk in enumerate(fps_idx_chunks): progress(progess_start + i * plot_step_float, desc=f"Plotted {title}") fig_images.append(plot_cluster_images(chunk, i)) return fig_images, ret_magnitude def make_cluster_plot_advanced(eigvecs, images, h=64, w=64): heatmap_fg, heatmap_bg = segment_fg_bg(images.clone()) heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w') heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w') heatmap_fg = F.interpolate(heatmap_fg, (h, w), mode="bilinear") heatmap_bg = F.interpolate(heatmap_bg, (h, w), mode="bilinear") heatmap_fg = heatmap_fg.flatten() heatmap_bg = heatmap_bg.flatten() fg_minus_bg = heatmap_fg - heatmap_bg fg_mask = fg_minus_bg > fg_minus_bg.quantile(0.8) bg_mask = fg_minus_bg < fg_minus_bg.quantile(0.2) # fg_mask = heatmap_fg > heatmap_fg.quantile(0.8) # bg_mask = heatmap_bg > heatmap_bg.quantile(0.8) other_mask = ~(fg_mask | bg_mask) fg_idx = 
torch.arange(heatmap_fg.shape[0])[fg_mask] bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask] other_idx = torch.arange(heatmap_fg.shape[0])[other_mask] fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=fg_idx, title="fg") bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=20, eig_idx=bg_idx, title="bg") other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=0, eig_idx=other_idx, title="other") cluster_images = fg_images + bg_images + other_images magitude = torch.norm(eigvecs, dim=-1) magitude = magitude.reshape(-1, h, w) # magitude = fg_minus_bg.reshape(-1, h, w) #TODO return cluster_images, magitude def ncut_run( model, images, model_name="DiNO(dino_vitb8_448)", layer=10, num_eig=100, node_type="block", affinity_focal_gamma=0.5, num_sample_ncut=10000, knn_ncut=10, embedding_method="tsne_3d", embedding_metric='euclidean', num_sample_tsne=1000, knn_tsne=10, perplexity=500, n_neighbors=500, min_dist=0.1, sampling_method="QuickFPS", ncut_metric="cosine", indirect_connection=True, make_orthogonal=False, old_school_ncut=False, recursion=False, recursion_l2_n_eigs=50, recursion_l3_n_eigs=20, recursion_metric="euclidean", recursion_l1_gamma=0.5, recursion_l2_gamma=0.5, recursion_l3_gamma=0.5, video_output=False, is_lisa=False, lisa_prompt1="", lisa_prompt2="", lisa_prompt3="", plot_clusters=False, alignedcut_eig_norm_plot=False, **kwargs, ): advanced = kwargs.get("advanced", False) directed = kwargs.get("directed", False) progress = gr.Progress() progress(0.2, desc="Feature Extraction") logging_str = "" if "AlignedThreeModelAttnNodes" == model_name: # dirty patch for the alignedcut paper resolution = (224, 224) else: resolution = RES_DICT[model_name] logging_str += f"Resolution: {resolution}\n" if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne: # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.") 
gr.Warning("Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting Perplexity to {num_sample_tsne-1}.") logging_str += f"Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting Perplexity to {num_sample_tsne-1}.\n" perplexity = num_sample_tsne - 1 n_neighbors = num_sample_tsne - 1 if torch.cuda.is_available(): torch.cuda.empty_cache() node_type = node_type.split(":")[0].strip() start = time.time() if "AlignedThreeModelAttnNodes" == model_name: # dirty patch for the alignedcut paper features = run_alignedthreemodelattnnodes(images, model, batch_size=BATCH_SIZE) elif is_lisa == True: # dirty patch for the LISA model features = [] with torch.no_grad(): model = model.cuda() images = images.cuda() lisa_prompts = [lisa_prompt1, lisa_prompt2, lisa_prompt3] for prompt in lisa_prompts: import bleach prompt = bleach.clean(prompt) prompt = prompt.strip() # print(prompt) # # copy the sting to a new string # copy_s = copy.copy(prompt) feature = model(images, input_str=prompt)[node_type][0] feature = F.normalize(feature, dim=-1) features.append(feature.cpu().float()) features = torch.stack(features) else: features = extract_features( images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE ) if directed: node_type2 = kwargs.get("node_type2", None) features_B = extract_features( images, model, node_type=node_type2, layer=layer-1, batch_size=BATCH_SIZE ) # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s") logging_str += f"Backbone time: {time.time() - start:.2f}s\n" del model progress(0.4, desc="NCut") if recursion: rgbs = [] all_eigvecs = [] recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma] inp = features progress_start = 0.4 for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]): logging_str += f"Recursion #{i+1}\n" progress_start += + 0.1 * i rgb, _logging_str, eigvecs = compute_ncut( inp, num_eig=n_eigs, num_sample_ncut=num_sample_ncut, 
affinity_focal_gamma=recursion_gammas[i], knn_ncut=knn_ncut, knn_tsne=knn_tsne, num_sample_tsne=num_sample_tsne, embedding_method=embedding_method, embedding_metric=embedding_metric, perplexity=perplexity, n_neighbors=n_neighbors, min_dist=min_dist, sampling_method=sampling_method, metric=ncut_metric if i == 0 else recursion_metric, indirect_connection=indirect_connection, make_orthogonal=make_orthogonal, progess_start=progress_start, ) logging_str += _logging_str all_eigvecs.append(eigvecs.cpu().clone()) if "AlignedThreeModelAttnNodes" == model_name: # dirty patch for the alignedcut paper start = time.time() progress(progress_start + 0.09, desc=f"Plotting Recursion {i+1}") pil_images = [] for i_image in range(rgb.shape[0]): _im = plot_one_image_36_grid(images[i_image], rgb[i_image]) pil_images.append(_im) rgbs.append(pil_images) logging_str += f"plot time: {time.time() - start:.2f}s\n" else: rgb = dont_use_too_much_green(rgb) rgbs.append(to_pil_images(rgb)) inp = eigvecs.reshape(*features.shape[:-1], -1) if recursion_metric == "cosine": inp = F.normalize(inp, dim=-1) if not advanced: return rgbs[0], rgbs[1], rgbs[2], logging_str if "AlignedThreeModelAttnNodes" == model_name: return rgbs[0], rgbs[1], rgbs[2], logging_str if advanced: cluster_plots, norm_plots = [], [] for i in range(3): eigvecs = all_eigvecs[i] # add norm plot, cluster plot start = time.time() progress_start = 0.6 progress(progress_start, desc=f"Plotting Clusters Recursion #{i+1}") h, w = features.shape[1], features.shape[2] if torch.cuda.is_available(): images = images.cuda() _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower()) cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w) logging_str += f"Recursion #{i+1} plot time: {time.time() - start:.2f}s\n" norm_images = [] vmin, vmax = eig_magnitude.min(), eig_magnitude.max() eig_magnitude = (eig_magnitude - vmin) / (vmax - vmin) eig_magnitude = eig_magnitude.cpu().numpy() 
colormap = matplotlib.colormaps['Reds'] for i_image in range(eig_magnitude.shape[0]): norm_image = colormap(eig_magnitude[i_image]) norm_images.append(torch.tensor(norm_image[..., :3])) norm_images = to_pil_images(norm_images) logging_str += f"Recursion #{i+1} Eigenvector Magnitude: [{vmin:.2f}, {vmax:.2f}]\n" gr.Info(f"Recursion #{i+1} Eigenvector Magnitude:
Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10) cluster_plots.append(cluster_images) norm_plots.append(norm_images) return *rgbs, *norm_plots, *cluster_plots, logging_str if old_school_ncut: # individual images logging_str += "Running NCut for each image independently\n" rgb = [] progress_start = 0.4 step_float = 0.6 / features.shape[0] for i_image in range(features.shape[0]): logging_str += f"Image #{i_image+1}\n" feature = features[i_image] _rgb, _logging_str, _ = compute_ncut( feature[None], num_eig=num_eig, num_sample_ncut=30000, affinity_focal_gamma=affinity_focal_gamma, knn_ncut=1, knn_tsne=10, num_sample_tsne=300, embedding_method=embedding_method, embedding_metric=embedding_metric, perplexity=perplexity, n_neighbors=n_neighbors, min_dist=min_dist, sampling_method=sampling_method, metric=ncut_metric, indirect_connection=indirect_connection, make_orthogonal=make_orthogonal, progess_start=progress_start+step_float*i_image, ) logging_str += _logging_str rgb.append(_rgb[0]) return to_pil_images(rgb), logging_str # ailgnedcut if not directed: only_eigvecs = kwargs.get("only_eigvecs", False) return_eigvec_and_rgb = kwargs.get("return_eigvec_and_rgb", False) normalize_eigvec_return = kwargs.get("normalize_eigvec_return", False) rgb, _logging_str, eigvecs = compute_ncut( features, num_eig=num_eig, num_sample_ncut=num_sample_ncut, affinity_focal_gamma=affinity_focal_gamma, knn_ncut=knn_ncut, knn_tsne=knn_tsne, num_sample_tsne=num_sample_tsne, embedding_method=embedding_method, embedding_metric=embedding_metric, perplexity=perplexity, n_neighbors=n_neighbors, min_dist=min_dist, sampling_method=sampling_method, indirect_connection=indirect_connection, make_orthogonal=make_orthogonal, metric=ncut_metric, only_eigvecs=only_eigvecs, ) if only_eigvecs: if normalize_eigvec_return: eigvecs = F.normalize(eigvecs, dim=-1) eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,)) eigvecs = eigvecs.detach().numpy() logging_str += _logging_str return eigvecs, 
logging_str if return_eigvec_and_rgb: if normalize_eigvec_return: eigvecs = F.normalize(eigvecs, dim=-1) eigvecs = eigvecs.to("cpu").reshape(features.shape[:-1] + (num_eig,)) eigvecs = eigvecs.detach().numpy() rgb = rgb.cpu().numpy() logging_str += _logging_str return eigvecs, rgb, logging_str if directed: head_index_text = kwargs.get("head_index_text", None) n_heads = features.shape[-2] # (batch, h, w, n_heads, d) if head_index_text == 'all': head_idx = torch.arange(n_heads) else: _idxs = head_index_text.split(",") head_idx = torch.tensor([int(idx) for idx in _idxs]) features_A = features[:, :, :, head_idx, :] features_B = features_B[:, :, :, head_idx, :] rgb, _logging_str, eigvecs = compute_ncut_directed( features_A, features_B, num_eig=num_eig, num_sample_ncut=num_sample_ncut, affinity_focal_gamma=affinity_focal_gamma, knn_ncut=knn_ncut, knn_tsne=knn_tsne, num_sample_tsne=num_sample_tsne, embedding_method=embedding_method, embedding_metric=embedding_metric, perplexity=perplexity, n_neighbors=n_neighbors, min_dist=min_dist, sampling_method=sampling_method, indirect_connection=False, make_orthogonal=make_orthogonal, metric=ncut_metric, make_symmetric=kwargs.get("make_symmetric", None), ) logging_str += _logging_str if "AlignedThreeModelAttnNodes" == model_name: # dirty patch for the alignedcut paper start = time.time() progress(0.6, desc="Plotting") pil_images = [] for i_image in range(rgb.shape[0]): _im = plot_one_image_36_grid(images[i_image], rgb[i_image]) pil_images.append(_im) logging_str += f"plot time: {time.time() - start:.2f}s\n" return pil_images, logging_str if is_lisa == True: # dirty patch for the LISA model galleries = [] for i_prompt in range(len(lisa_prompts)): _rgb = rgb[i_prompt] galleries.append(to_pil_images(_rgb)) return *galleries, logging_str rgb = dont_use_too_much_green(rgb) if video_output: progress(0.8, desc="Saving Video") video_path = get_random_path() video_cache.add_video(video_path) pil_images_to_video(to_pil_images(rgb), 
video_path, fps=5) return video_path, logging_str cluster_images = None if plot_clusters and kwargs.get("n_ret", 1) > 1: start = time.time() progress_start = 0.6 progress(progress_start, desc="Plotting Clusters") h, w = features.shape[1], features.shape[2] if torch.cuda.is_available(): images = images.cuda() _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower()) advanced = kwargs.get("advanced", False) if advanced: cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w) else: cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=False) logging_str += f"plot time: {time.time() - start:.2f}s\n" norm_images = None if alignedcut_eig_norm_plot and kwargs.get("n_ret", 1) > 1: norm_images = [] # eig_magnitude = torch.clamp(eig_magnitude, 0, 1) vmin, vmax = eig_magnitude.min(), eig_magnitude.max() eig_magnitude = (eig_magnitude - vmin) / (vmax - vmin) eig_magnitude = eig_magnitude.cpu().numpy() colormap = matplotlib.colormaps['Reds'] for i_image in range(eig_magnitude.shape[0]): norm_image = colormap(eig_magnitude[i_image]) # norm_image = (norm_image[..., :3] * 255).astype(np.uint8) # norm_images.append(Image.fromarray(norm_image)) norm_images.append(torch.tensor(norm_image[..., :3])) norm_images = to_pil_images(norm_images) logging_str += "Eigenvector Magnitude\n" logging_str += f"Min: {vmin:.2f}, Max: {vmax:.2f}\n" gr.Info(f"Eigenvector Magnitude:
Min: {vmin:.2f}, Max: {vmax:.2f}", duration=10) return to_pil_images(rgb), cluster_images, norm_images, logging_str def _ncut_run(*args, **kwargs): n_ret = kwargs.get("n_ret", 1) try: gr.Info("NCUT Run Started", 2) if torch.cuda.is_available(): torch.cuda.empty_cache() ret = ncut_run(*args, **kwargs) if torch.cuda.is_available(): torch.cuda.empty_cache() ret = list(ret)[:n_ret] + [ret[-1]] gr.Info("NCUT Run Finished", 2) return ret except Exception as e: gr.Error(str(e)) if torch.cuda.is_available(): torch.cuda.empty_cache() return *(None for _ in range(n_ret)), "Error: " + str(e) # ret = ncut_run(*args, **kwargs) # ret = list(ret)[:n_ret] + [ret[-1]] # return ret if USE_HUGGINGFACE_ZEROGPU: @spaces.GPU(duration=30) def quick_run(*args, **kwargs): return _ncut_run(*args, **kwargs) @spaces.GPU(duration=45) def long_run(*args, **kwargs): return _ncut_run(*args, **kwargs) @spaces.GPU(duration=60) def longer_run(*args, **kwargs): return _ncut_run(*args, **kwargs) @spaces.GPU(duration=120) def super_duper_long_run(*args, **kwargs): return _ncut_run(*args, **kwargs) def cpu_run(*args, **kwargs): return _ncut_run(*args, **kwargs) if not USE_HUGGINGFACE_ZEROGPU: def quick_run(*args, **kwargs): return _ncut_run(*args, **kwargs) def long_run(*args, **kwargs): return _ncut_run(*args, **kwargs) def longer_run(*args, **kwargs): return _ncut_run(*args, **kwargs) def super_duper_long_run(*args, **kwargs): return _ncut_run(*args, **kwargs) def cpu_run(*args, **kwargs): return _ncut_run(*args, **kwargs) def extract_video_frames(video_path, max_frames=100): from decord import VideoReader vr = VideoReader(video_path) num_frames = len(vr) if num_frames > max_frames: gr.Warning(f"Video has {num_frames} frames. Only using {max_frames} frames. 
def transform_image(image, resolution=(1024, 1024), stablediffusion=False):
    """Convert a PIL image into a normalized ``(3, H, W)`` float tensor.

    Args:
        image: PIL image; converted to RGB and resized to ``resolution``
            with Lanczos resampling.
        resolution: target size passed to ``PIL.Image.resize``.
        stablediffusion: when True the pixel values are scaled to ``[-1, 1]``;
            otherwise they are standardized with the mean/std below.

    Returns:
        A float tensor in channel-first layout.
    """
    image = image.convert('RGB').resize(resolution, Image.LANCZOS)
    # Convert to torch tensor, channel-first, values in [0, 1]
    image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
    image = image / 255
    if stablediffusion:
        image = image * 2 - 1
    else:
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        image = (image - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
    return image


def reverse_transform_image(image, stablediffusion=False):
    """Undo ``transform_image`` and clamp the result to ``[0, 1]``.

    Args:
        image: tensor whose channel axis broadcasts against a ``(3, 1, 1)``
            mean/std (e.g. ``(3, H, W)`` or ``(N, 3, H, W)``).
        stablediffusion: must match the flag used by ``transform_image``.

    Returns:
        A tensor on the same device as ``image`` with values in ``[0, 1]``.
    """
    if stablediffusion:
        image = (image + 1) / 2
    else:
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(image.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(image.device)
        image = image * std + mean
    image = torch.clamp(image, 0, 1)
    return image


def plot_one_image_36_grid(original_image, tsne_rgb_images):
    """Render one input image next to a 3x12 grid of per-layer colour maps.

    Args:
        original_image: ``(3, H, W)`` tensor in the normalized space produced
            by ``transform_image(..., stablediffusion=False)``.
        tsne_rgb_images: indexable of 36 ``(H, W, 3)`` tensors, ordered as
            12 layers for each of the rows labelled CLIP, DINO and MAE.

    Returns:
        The rendered figure as an RGB ``PIL.Image``.
    """
    # Shared helper keeps mean/std on the image's device, so a CUDA tensor no
    # longer fails when multiplied by CPU constants.
    original_image = reverse_transform_image(original_image)
    img = original_image.cpu().float().numpy().transpose(1, 2, 0)

    def convert_and_pad_image(np_array, pad_size=20):
        """Pad an ``(H, W, 3)`` uint8 array on the right and bottom.

        Args:
            np_array (numpy.ndarray): input array of shape (height, width, 3).
            pad_size (int, optional): pixels added on the right and bottom
                sides. Default is 20.

        Returns:
            PIL.Image: RGBA image with a transparent padding area.
        """
        source = Image.fromarray(np_array)
        width, height = source.size
        padded_img = Image.new('RGBA', (width + pad_size, height + pad_size), color=(255, 255, 255, 0))
        padded_img.paste(source, (0, 0))
        return padded_img

    fig = plt.figure(figsize=(20, 4))
    try:
        grid = plt.GridSpec(3, 14, hspace=0.1, wspace=0.1)
        ax1 = fig.add_subplot(grid[0:2, 0:2])
        ax1.imshow(convert_and_pad_image((img * 255).astype(np.uint8)))
        ax1.axis('off')

        model_names = ['CLIP', 'DINO', 'MAE']
        for i_model, model_name in enumerate(model_names):
            for i_layer in range(12):
                ax = fig.add_subplot(grid[i_model, i_layer + 2])
                ax.imshow(tsne_rgb_images[i_layer + 12 * i_model].cpu().float().numpy())
                ax.axis('off')
                if i_model == 0:
                    ax.set_title(f'Layer{i_layer}', fontsize=16)
                if i_layer == 0:
                    ax.text(-0.1, 0.5, model_name, va="center", ha="center",
                            fontsize=16, transform=ax.transAxes, rotation=90)
        fig.tight_layout()

        # Render into memory instead of a /tmp file so nothing is left behind
        # if saving or decoding fails.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches='tight', pad_inches=0, dpi=100)
        buffer.seek(0)
        # convert() fully decodes the PNG, so the result outlives the buffer
        return Image.open(buffer).convert("RGB")
    finally:
        plt.close(fig)
def load_alignedthreemodel():
    """Fetch the ``alignedthreeattn`` repository and build its model.

    The repository is cloned into the current working directory (up to three
    attempts), updated with ``git pull`` and added to ``sys.path``. The
    weights are then read from ``alignedthreeattn/align_weights.pth``.

    Returns:
        A ``ThreeAttnNodes`` instance constructed from the loaded weights.

    Raises:
        ImportError / FileNotFoundError: if the repository could not be
            fetched, since the import and ``torch.load`` below then fail.
    """
    import sys

    repo_dir = "alignedthreeattn"
    if repo_dir not in sys.path:
        for _ in range(3):
            # stop retrying as soon as a previous attempt (or run) succeeded
            if os.path.isdir(repo_dir):
                break
            os.system("git clone https://huggingface.co/huzey/alignedthreeattn >> /dev/null 2>&1")
        os.system("git -C alignedthreeattn pull >> /dev/null 2>&1")
        # add to path
        sys.path.append(repo_dir)

    from alignedthreeattn.alignedthreeattn_model import ThreeAttnNodes

    # NOTE(review): no map_location is given, so the checkpoint is restored to
    # whatever device it was saved from — confirm this is fine on CPU hosts.
    align_weights = torch.load("alignedthreeattn/align_weights.pth")
    model = ThreeAttnNodes(align_weights)
    return model


try:
    # pre-load the alignedthree model in case it fails to load
    load_alignedthreemodel()
except Exception as e:
    # best effort only: the app must still start, but say why preload failed
    print(f"Pre-loading alignedthreeattn failed: {e}")

promptable_diffusion_models = ["Diffusion(stabilityai/stable-diffusion-2)", "Diffusion(CompVis/stable-diffusion-v1-4)"]
promptable_segmentation_models = ["LISA(xinlai/LISA-7B-v1)"]


def run_fn(
    images,
    model_name="DiNO(dino_vitb8_448)",
    layer=10,
    num_eig=100,
    node_type="block",
    positive_prompt="",
    negative_prompt="",
    is_lisa=False,
    lisa_prompt1="",
    lisa_prompt2="",
    lisa_prompt3="",
    affinity_focal_gamma=0.5,
    num_sample_ncut=10000,
    knn_ncut=10,
    ncut_indirect_connection=True,
    ncut_make_orthogonal=False,
    embedding_method="tsne_3d",
    embedding_metric='euclidean',
    num_sample_tsne=300,
    knn_tsne=10,
    perplexity=150,
    n_neighbors=150,
    min_dist=0.1,
    sampling_method="QuickFPS",
    ncut_metric="cosine",
    old_school_ncut=False,
    max_frames=100,
    recursion=False,
    recursion_l2_n_eigs=50,
    recursion_l3_n_eigs=20,
    recursion_metric="euclidean",
    recursion_l1_gamma=0.5,
    recursion_l2_gamma=0.5,
    recursion_l3_gamma=0.5,
    node_type2="k",
    head_index_text='all',
    make_symmetric=False,
    n_ret=1,
    plot_clusters=False,
    alignedcut_eig_norm_plot=False,
    advanced=False,
    directed=False,
    only_eigvecs=False,
    return_eigvec_and_rgb=False,
    normalize_eigvec_return=False,
):
    """Prepare inputs and the backbone, then dispatch to a GPU-tier runner.

    Args:
        images: gallery value (sequence of ``(PIL.Image, caption)`` tuples) or
            a video path string; a string is decoded with
            ``extract_video_frames(images, max_frames=max_frames)``.
        model_name: key into ``RES_DICT`` / ``load_model``; the special value
            ``"AlignedThreeModelAttnNodes"`` uses ``load_alignedthreemodel``.
        layer: layer index; for Stable Diffusion models it is stored as
            ``model.timestep`` and the layer passed on becomes 1.
        n_ret: number of outputs expected before the trailing log string.
        All remaining parameters are forwarded unchanged to the runner through
        the ``kwargs`` dict built below.

    Returns:
        Whatever the selected runner returns: ``n_ret`` outputs followed by a
        log string. When no images are given, ``n_ret`` ``None`` values
        followed by ``"No images selected."``.

    Raises:
        gr.Error: re-raised with added guidance when the runner raises
            ``gr.Error`` (assumed to be a GPU quota failure).
    """
    progress = gr.Progress()
    progress(0, desc="Starting")

    # an emptied gallery may be delivered as an empty list as well as None
    if images is None or len(images) == 0:
        gr.Warning("No images selected.")
        return *(None for _ in range(n_ret)), "No images selected."

    progress(0.05, desc="Processing Images")

    video_output = False
    if isinstance(images, str):
        images = extract_video_frames(images, max_frames=max_frames)
        video_output = True

    if sampling_method == "QuickFPS":
        sampling_method = "farthest"

    # resize the images before acquiring GPU
    if "AlignedThreeModelAttnNodes" == model_name:
        # dirty patch for the alignedcut paper
        resolution = (224, 224)
    else:
        resolution = RES_DICT[model_name]
    stablediffusion = "Diffusion" in model_name
    images = [
        transform_image(tup[0], resolution=resolution, stablediffusion=stablediffusion)
        for tup in images
    ]
    images = torch.stack(images)

    progress(0.1, desc="Downloading Model")

    lisa_transformers_dir = '/tmp/lisa_transformers_v433'
    if is_lisa:
        import sys
        import importlib
        gr.Warning("LISA model is not compatible with the current version of transformers. Please contact the LISA and Llava author for update.")
        gr.Warning("This is a dirty patch for the LISA model. switch to the old version of transformers.")
        gr.Warning("Not garanteed to work.")
        # LISA and Llava is not compatible with the current version of transformers
        # please contact the author for update
        # this is a dirty patch for the LISA model

        # pre-import the SD3 pipeline
        from diffusers import StableDiffusion3Pipeline  # noqa: F401

        # unloading the current transformers
        for module in list(sys.modules.keys()):
            if "transformers" in module:
                del sys.modules[module]

        def install_transformers_version(version, target_dir):
            """Install a specific version of transformers to a target directory."""
            if not os.path.exists(target_dir):
                os.makedirs(target_dir)
            os.system(f"{sys.executable} -m pip install transformers=={version} -t {target_dir} >> /dev/null 2>&1")

        if not os.path.exists(lisa_transformers_dir):
            install_transformers_version('4.33.0', lisa_transformers_dir)

        # Add the new version path to sys.path
        sys.path.insert(0, lisa_transformers_dir)
        importlib.import_module("transformers")
    else:
        import sys
        import importlib

        # remove the LISA model from the sys.path
        if lisa_transformers_dir in sys.path:
            sys.path.remove(lisa_transformers_dir)
        importlib.import_module("transformers")

    if "AlignedThreeModelAttnNodes" == model_name:
        # dirty patch for the alignedcut paper
        model = load_alignedthreemodel()
    else:
        model = load_model(model_name)

    if directed:  # save qkv for directed, need more memory
        model.enable_save_qkv()

    if "stable" in model_name.lower() and "diffusion" in model_name.lower():
        model.timestep = layer
        layer = 1

    if model_name in promptable_diffusion_models:
        model.positive_prompt = positive_prompt
        model.negative_prompt = negative_prompt

    kwargs = {
        "model_name": model_name,
        "layer": layer,
        "num_eig": num_eig,
        "node_type": node_type,
        "affinity_focal_gamma": affinity_focal_gamma,
        "num_sample_ncut": num_sample_ncut,
        "knn_ncut": knn_ncut,
        "embedding_method": embedding_method,
        "embedding_metric": embedding_metric,
        "num_sample_tsne": num_sample_tsne,
        "knn_tsne": knn_tsne,
        "perplexity": perplexity,
        "n_neighbors": n_neighbors,
        "min_dist": min_dist,
        "sampling_method": sampling_method,
        "ncut_metric": ncut_metric,
        "indirect_connection": ncut_indirect_connection,
        "make_orthogonal": ncut_make_orthogonal,
        "old_school_ncut": old_school_ncut,
        "recursion": recursion,
        "recursion_l2_n_eigs": recursion_l2_n_eigs,
        "recursion_l3_n_eigs": recursion_l3_n_eigs,
        "recursion_metric": recursion_metric,
        "recursion_l1_gamma": recursion_l1_gamma,
        "recursion_l2_gamma": recursion_l2_gamma,
        "recursion_l3_gamma": recursion_l3_gamma,
        "video_output": video_output,
        "lisa_prompt1": lisa_prompt1,
        "lisa_prompt2": lisa_prompt2,
        "lisa_prompt3": lisa_prompt3,
        "is_lisa": is_lisa,
        "n_ret": n_ret,
        "plot_clusters": plot_clusters,
        "alignedcut_eig_norm_plot": alignedcut_eig_norm_plot,
        "advanced": advanced,
        "directed": directed,
        "node_type2": node_type2,
        "head_index_text": head_index_text,
        "make_symmetric": make_symmetric,
        "only_eigvecs": only_eigvecs,
        "return_eigvec_and_rgb": return_eigvec_and_rgb,
        "normalize_eigvec_return": normalize_eigvec_return,
    }

    # Pick the cheapest runner that is expected to finish in time. The checks
    # run from the most to the least demanding workload; the first match wins.
    num_images = len(images)
    if old_school_ncut or is_lisa or num_images >= 100 or 'diffusion' in model_name.lower():
        runner = super_duper_long_run
    elif recursion or num_images >= 50:
        runner = longer_run
    elif num_images >= 10:
        runner = long_run
    elif embedding_method == "UMAP":
        heavy = perplexity >= 250 or num_sample_tsne >= 500
        runner = longer_run if heavy else long_run
    elif embedding_method == "t-SNE":
        heavy = perplexity >= 250 or num_sample_tsne >= 500
        runner = long_run if heavy else quick_run
    else:
        runner = quick_run

    try:
        # try to acquire GPU, can fail if the user is out of GPU quota
        return runner(model, images, **kwargs)
    except gr.Error as e:
        # I assume this is a GPU quota error
        info1 = 'Running out of HuggingFace GPU Quota?\nPlease try Demo hosted at UPenn\n'
        info2 = 'Or try use the Python package that powers this app: ncut-pytorch'
        info = info1 + info2
        message = "HuggingFace:\n" + e.message + "\n\n---------\n" + "`ncut-pytorch` Developer:\n" + info
        raise gr.Error(message, duration=0) from e
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


# Custom Dataset
class TwoTensorDataset(Dataset):
    """Pairs two index-aligned tensors; item ``i`` is ``(A[i], B[i])``."""

    def __init__(self, A, B):
        self.A = A
        self.B = B

    def __len__(self):
        # length follows A; B is indexed with the same indices
        return len(self.A)

    def __getitem__(self, idx):
        return self.A[idx], self.B[idx]


# MLP model
class MLP(pl.LightningModule):
    """Fully connected network mapping 3 input values to 3 output values.

    Args:
        num_layer: number of hidden ``Linear`` + ``GELU`` stages.
        width: hidden width.
        lr: Adam learning rate.
        fitting_steps: total training steps, used to scale the progress bar.
        seg_loss_lambda: weight of the pairwise-distance term in the loss.
    """

    def __init__(self, num_layer=3, width=512, lr=3e-4, fitting_steps=10000, seg_loss_lambda=1.0):
        super().__init__()
        layers = [nn.Linear(3, width), nn.GELU()]
        for _ in range(num_layer - 1):
            layers.append(nn.Linear(width, width))
            layers.append(nn.GELU())
        layers.append(nn.Linear(width, 3))
        self.layers = nn.Sequential(*layers)
        self.mse_loss = nn.MSELoss()
        self.lr = lr
        self.fitting_steps = fitting_steps
        self.seg_loss_lambda = seg_loss_lambda
        self.progress = gr.Progress()

    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Return MSE to the target plus a weighted pairwise-distance term."""
        x, y = batch
        y_hat = self.forward(x)
        loss = self.mse_loss(y_hat, y)
        self.log("train_loss", loss)

        # add segmentation constraint: pairwise distances between inputs
        # should match pairwise distances between outputs
        bsz = x.shape[0]
        sample_size = 1000
        if bsz > sample_size:
            # pdist is quadratic in the row count, so cap the sample
            idx = torch.randperm(bsz)[:sample_size]
            x = x[idx]
            y_hat = y_hat[idx]
        old_dist = torch.pdist(x, p=2)
        new_dist = torch.pdist(y_hat, p=2)
        seg_loss = self.mse_loss(old_dist, new_dist)
        self.log("seg_loss", seg_loss)

        # out-of-place add: leaves the tensor handed to self.log untouched
        loss = loss + seg_loss * self.seg_loss_lambda

        step = self.global_step
        if step % 100 == 0:
            self.progress(step / self.fitting_steps, desc="Fitting MLP")
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x = batch[0]
        y_hat = self.forward(x)
        return y_hat

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


def fit_trans(rgb1, rgb2, num_layer=3, width=512, batch_size=256, lr=3e-4, fitting_steps=10000, fps_sample=4096, seg_loss_lambda=1.0):
    """Fit an ``MLP`` from ``rgb1`` rows to ``rgb2`` rows and apply it to ``rgb1``.

    Args:
        rgb1: tensor with rows of 3 values; the data to transform.
        rgb2: tensor with rows of 3 values; the fitting target.
        num_layer, width, lr, fitting_steps, seg_loss_lambda: passed to ``MLP``.
        batch_size: training batch size.
        fps_sample: number of points requested from ``farthest_point_sampling``
            for each of the two tensors.

    Returns:
        The transformed ``rgb1`` as one concatenated tensor.
    """
    A = rgb1.clone()
    B = rgb2.clone()

    # FPS sample on the data
    from ncut_pytorch.ncut_pytorch import farthest_point_sampling
    A_idx = farthest_point_sampling(A, fps_sample)
    B_idx = farthest_point_sampling(B, fps_sample)
    # NOTE(review): the same index set selects rows of both tensors, which
    # presumably requires rgb1 and rgb2 to have the same row count — confirm.
    A_B_idx = np.concatenate([A_idx, B_idx])
    A = A[A_B_idx]
    B = B[A_B_idx]

    from torch.utils.data import DataLoader, TensorDataset

    # Dataset and DataLoader
    dataset = TwoTensorDataset(A, B)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize model and trainer
    mlp = MLP(num_layer=num_layer, width=width, lr=lr, fitting_steps=fitting_steps, seg_loss_lambda=seg_loss_lambda)
    trainer = pl.Trainer(
        max_epochs=100000,
        # `gpus=1` fails on hosts without CUDA and is not accepted by
        # Lightning 2.x; let Lightning choose the accelerator instead.
        accelerator="auto",
        devices=1,
        max_steps=fitting_steps,
        enable_checkpointing=False,
        enable_progress_bar=False,
        gradient_clip_val=1.0,
    )

    # Inference loader over the full, unsampled input (fixed batch size)
    predict_batch_size = 256
    data_loader = DataLoader(TensorDataset(rgb1), batch_size=predict_batch_size, shuffle=False)

    # Train the model
    trainer.fit(mlp, dataloader)

    mlp.progress(0.99, desc="Applying MLP")
    results = trainer.predict(mlp, data_loader)
    A_transformed = torch.cat(results, dim=0)

    return A_transformed


if USE_HUGGINGFACE_ZEROGPU:
    @spaces.GPU(duration=60)
    def _run_mlp_fit(*args, **kwargs):
        return fit_trans(*args, **kwargs)
else:
    def _run_mlp_fit(*args, **kwargs):
        return fit_trans(*args, **kwargs)
def run_mlp_fit(input_gallery, target_gallery, num_layer=3, width=512, batch_size=256, lr=3e-4, fitting_steps=10000, fps_sample=4096, seg_loss_lambda=1.0):
    """Fit the colour mapping from the input gallery to the target gallery.

    Both galleries are sequences of ``(image, caption)`` tuples where the
    image is a PIL image or a file path. The fitted mapping is applied to the
    input gallery and the result is returned via ``to_pil_images``.

    Raises:
        gr.Error: if either gallery is ``None`` or empty.
    """
    def is_missing(gallery):
        return gallery is None or len(gallery) == 0

    if is_missing(target_gallery):
        raise gr.Error("No target images selected. \nPlease use the Mark button to select the target images.")
    if is_missing(input_gallery):
        raise gr.Error("No input images selected.")

    def gallery_to_rgb(gallery):
        # Returns (pixels flattened to rows of 3, shape of the stacked images)
        tensors = []
        for entry in gallery:
            picture = entry[0]
            if isinstance(picture, str):
                picture = Image.open(picture)
            pixels = np.array(picture.convert('RGB'))
            tensors.append(torch.tensor(pixels).float() / 255)
        stacked = torch.stack(tensors)
        return stacked.reshape(-1, 3), stacked.shape

    target_rgb, target_shape = gallery_to_rgb(target_gallery)
    input_rgb, input_shape = gallery_to_rgb(input_gallery)

    fitted = _run_mlp_fit(
        input_rgb,
        target_rgb,
        num_layer=num_layer,
        width=width,
        batch_size=batch_size,
        lr=lr,
        fitting_steps=fitting_steps,
        fps_sample=fps_sample,
        seg_loss_lambda=seg_loss_lambda,
    )
    # restore the (N, H, W, 3) layout of the input gallery
    fitted = fitted.reshape(*input_shape)
    return to_pil_images(fitted, resize=False)


def make_input_video_section():
    """Create the video-input widgets.

    Returns:
        ``(video component, run button, clear button, max-frames number)``.
    """
    video_input = gr.Video(
        value=None,
        label="Select video",
        elem_id="video-input",
        height="auto",
        show_share_button=False,
        interactive=True,
    )
    gr.Markdown('_image backbone model is used to extract features from each frame, NCUT is computed on all frames_')
    frame_limit = gr.Number(100, label="Max frames", elem_id="max_frames")
    run_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
    clear_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
    return video_input, run_button, clear_button, frame_limit
def load_dataset_images(is_advanced, dataset_name, num_images=10,
                        is_filter=False, filter_by_class_text="0,1,2",
                        is_random=False, seed=1):
    """Load images from a Hugging Face dataset, optionally filtered by class.

    Args:
        is_advanced: the value ``"Basic"`` short-circuits to ``default_images``.
        dataset_name: dataset identifier; ``"EgoExo"`` also returns
            ``default_images``.
        num_images: number of images requested, capped at the dataset size.
        is_filter: keep only images whose ``label`` is in
            ``filter_by_class_text``; the budget is split evenly per class.
        filter_by_class_text: comma-separated integer class ids.
        is_random: sample with ``np.random.RandomState(seed)`` instead of
            taking the first items.
        seed: random seed used when ``is_random`` is set.

    Returns:
        A list of images, centre-cropped to a square when the dataset is
        listed in ``CENTER_CROP_DATASETS``.

    Raises:
        gr.Error: if the dataset cannot be loaded, the class list is not a
            list of integers, or none of the requested classes exist.
    """
    progress = gr.Progress()
    progress(0, desc="Loading Images")

    if dataset_name == "EgoExo":
        is_advanced = "Basic"
    if is_advanced == "Basic":
        gr.Info(f"Loaded images from \nEgoExo", duration=5)
        return default_images

    try:
        progress(0.5, desc="Downloading Dataset")
        if 'EgoThink' in dataset_name:
            dataset = load_dataset(dataset_name, 'Activity', trust_remote_code=True)
        else:
            dataset = load_dataset(dataset_name, trust_remote_code=True)
        key = list(dataset.keys())[0]
        dataset = dataset[key]
    except Exception as e:
        raise gr.Error(f"Error loading dataset {dataset_name}: {e}")

    # UI widgets may deliver a float; indices and range() need an int
    num_images = min(int(num_images), len(dataset))

    if not filter_by_class_text or not filter_by_class_text.strip():
        is_filter = False

    if is_filter:
        progress(0.8, desc="Filtering Images")
        try:
            # tolerate spaces and stray commas such as "1, 2,"
            classes = [int(tok) for tok in filter_by_class_text.split(",") if tok.strip()]
        except ValueError:
            raise gr.Error(f"Invalid class list '{filter_by_class_text}'. Use comma-separated integers, e.g. `0,1,2`.")

        labels = np.array(dataset['label'])
        unique_labels = np.unique(labels)
        valid_classes = [i for i in classes if i in unique_labels]
        invalid_classes = [i for i in classes if i not in unique_labels]
        if len(invalid_classes) > 0:
            gr.Warning(f"Classes {invalid_classes} not found in the dataset.")
        if len(valid_classes) == 0:
            raise gr.Error(f"Classes {classes} not found in the dataset.")

        # shuffle each class
        chunk_size = num_images // len(valid_classes)
        image_idx = []
        for i in valid_classes:
            idx = np.where(labels == i)[0]
            if is_random:
                if chunk_size < len(idx):
                    idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
                else:
                    gr.Warning(f"Class {i} has less than {chunk_size} images.")
                    idx = idx[:chunk_size]
            else:
                idx = idx[:chunk_size]
            image_idx.extend(idx.tolist())
    elif is_random:
        # num_images is already capped, so sampling without replacement is safe
        image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
    else:
        image_idx = list(range(num_images))

    key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
    images = [dataset[i][key] for i in image_idx]
    gr.Info(f"Loaded {len(images)} images from {dataset_name}", duration=5)
    del dataset

    if dataset_name in CENTER_CROP_DATASETS:
        def center_crop_image(img):
            # image: PIL image; crop the largest centred square
            w, h = img.size
            min_hw = min(h, w)
            left = (w - min_hw) // 2
            top = (h - min_hw) // 2
            return img.crop((left, top, left + min_hw, top + min_hw))
        images = [center_crop_image(image) for image in images]

    return images
def load_and_append(existing_images, *args, **kwargs):
    """Load dataset images and append them to the current gallery value.

    ``*args`` / ``**kwargs`` are forwarded to ``load_dataset_images``.

    Returns:
        ``existing_images`` unchanged when nothing was loaded, otherwise a
        list holding the existing entries followed by the new images.
    """
    new_images = load_dataset_images(*args, **kwargs)
    if new_images is None or len(new_images) == 0:
        return existing_images
    if existing_images is None:
        existing_images = []
    existing_images += new_images
    gr.Info(f"Total images: {len(existing_images)}")
    return existing_images


def make_input_images_section(rows=1, cols=3, height="450px", advanced=False, is_random=False, allow_download=False, markdown=True, n_example_images=100):
    """Build the input-image gallery together with its loading controls.

    Args:
        rows, cols, height: gallery layout.
        advanced: show the dataset browser instead of the example sets first.
        is_random: initial state of the class-filter / random-seed widgets.
        allow_download: add Pack/Download buttons for the gallery.
        markdown: emit the "Input Images" heading.
        n_example_images: images loaded by each example-set button.

    Returns:
        ``(input_gallery, submit_button, clear_images_button,
        dataset_dropdown, num_images_slider, random_seed_slider,
        load_images_button)``.
    """
    # NOTE(review): the nesting of widgets under the `with` containers below
    # was reconstructed from statement order — verify the layout.
    if markdown:
        gr.Markdown('### Input Images')
    input_gallery = gr.Gallery(value=None, label="Input images", show_label=True, elem_id="input_images", columns=[cols], rows=[rows], object_fit="contain", height=height, type="pil", show_share_button=False,
                               format="webp")

    submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary')
    with gr.Row():
        clear_images_button = gr.Button("🗑️ Clear", elem_id='clear_button', variant='stop')
        clear_images_button.click(fn=lambda: gr.update(value=None), outputs=[input_gallery])
        upload_button = gr.UploadButton(elem_id="upload_button", label="⬆️ Upload", variant='secondary', file_types=["image"], file_count="multiple")

    def convert_to_pil_and_append(images, new_images):
        # Accepts a single PIL image, a single path, or a list of paths.
        if images is None:
            images = []
        if new_images is None:
            return images
        if isinstance(new_images, Image.Image):
            images.append(new_images)
        if isinstance(new_images, list):
            images += [Image.open(new_image) for new_image in new_images]
        if isinstance(new_images, str):
            images.append(Image.open(new_images))
        gr.Info(f"Total images: {len(images)}")
        return images
    upload_button.upload(convert_to_pil_and_append, inputs=[input_gallery, upload_button], outputs=[input_gallery])

    if allow_download:
        create_file_button, download_button = add_download_button(input_gallery, "input_images")

    gr.Markdown('### Load Datasets')
    advanced_radio = gr.Radio(["Basic", "Advanced"], label="Datasets Menu", value="Advanced" if advanced else "Basic", elem_id="advanced-radio", show_label=True)

    with gr.Column() as basic_block:
        # gr.Markdown('### Example Image Sets')
        def make_example(name, images, dataset_name):
            with gr.Row():
                button = gr.Button("Load\n"+name, elem_id=f"example-{name}", elem_classes="small-button", variant='secondary', size="sm", scale=1, min_width=60)
                gallery = gr.Gallery(value=images, label=name, show_label=True, columns=[3], rows=[1], interactive=False, height=80, scale=8, object_fit="cover", min_width=140, allow_preview=False)
                button.click(fn=lambda: gr.update(value=load_dataset_images(True, dataset_name, n_example_images, is_random=True, seed=42)), outputs=[input_gallery])
            return gallery, button

        example_items = [
            ("EgoExo", ['./images/egoexo1.jpg', './images/egoexo3.jpg', './images/egoexo2.jpg'], "EgoExo"),
            ("Ego", ['./images/egothink1.jpg', './images/egothink2.jpg', './images/egothink3.jpg'], "EgoThink/EgoThink"),
            ("Face", ['./images/face1.jpg', './images/face2.jpg', './images/face3.jpg'], "nielsr/CelebA-faces"),
            ("Pose", ['./images/pose1.jpg', './images/pose2.jpg', './images/pose3.jpg'], "sayakpaul/poses-controlnet-dataset"),
            # ("CatDog", ['./images/catdog1.jpg', './images/catdog2.jpg', './images/catdog3.jpg'], "microsoft/cats_vs_dogs"),
            # ("Bird", ['./images/bird1.jpg', './images/bird2.jpg', './images/bird3.jpg'], "Multimodal-Fatima/CUB_train"),
            # ("ChestXray", ['./images/chestxray1.jpg', './images/chestxray2.jpg', './images/chestxray3.jpg'], "hongrui/mimic_chest_xray_v_1"),
            ("MRI", ['./images/brain1.jpg', './images/brain2.jpg', './images/brain3.jpg'], "sartajbhuvaji/Brain-Tumor-Classification"),
            ("Kanji", ['./images/kanji1.jpg', './images/kanji2.jpg', './images/kanji3.jpg'], "yashvoladoddi37/kanjienglish"),
        ]
        for name, images, dataset_name in example_items:
            make_example(name, images, dataset_name)

    with gr.Column() as advanced_block:
        load_images_button = gr.Button("🔴 Load Images", elem_id="load-images-button", variant='primary')
        dataset_categories = list(DATASETS.keys())
        default_cat = dataset_categories[0]

        def get_choices(cat):
            return [tup[0] for tup in DATASETS[cat]]
        default_choices = get_choices(default_cat)

        with gr.Row():
            dataset_radio = gr.Radio(dataset_categories, label="Dataset Category", value=default_cat, elem_id="dataset-radio", show_label=True, min_width=600)
            dataset_dropdown = gr.Dropdown(default_choices, label="Dataset name", value=default_choices[0], elem_id="dataset", min_width=400)
            dataset_radio.change(fn=lambda x: gr.update(choices=get_choices(x), value=get_choices(x)[0]), inputs=dataset_radio, outputs=dataset_dropdown)
            num_images_slider = gr.Slider(1, 1000, step=1, label="Number of images", value=10, elem_id="num_images", min_width=200)

        if not is_random:
            filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
            filter_by_class_text = gr.Textbox(label="Class to select", value="97,0", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. \n(1000 classes)", visible=True)
            is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
            random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=True)
        else:
            filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox")
            filter_by_class_text = gr.Textbox(label="Class to select", value="97,0", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=False)
            is_random_checkbox = gr.Checkbox(label="Random shuffle", value=True, elem_id="random_seed_checkbox")
            random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=42, elem_id="random_seed", visible=True)

        # add functionality, save and load images to profile
        with gr.Accordion("Saved Image Profiles", open=False) as profile_accordion:
            with gr.Row():
                profile_text = gr.Textbox(label="Profile name", placeholder="Type here: Profile name to save/load/delete", elem_id="profile-name", scale=6, show_label=False)
                list_profiles_button = gr.Button("📋 List", elem_id="list-profile-button", variant='secondary', scale=3)
            with gr.Row():
                save_profile_button = gr.Button("💾 Save", elem_id="save-profile-button", variant='secondary')
                load_profile_button = gr.Button("📂 Load", elem_id="load-profile-button", variant='secondary')
                delete_profile_button = gr.Button("🗑️ Delete", elem_id="delete-profile-button", variant='secondary')

            class OnDiskProfiles:
                """Named image lists pickled as files inside ``profile_dir``."""

                def __init__(self, profile_dir="demo_profiles"):
                    os.makedirs(profile_dir, exist_ok=True)
                    self.profile_dir = profile_dir

                def _profile_path(self, profile_name):
                    # The name comes from a free-text box: reject anything that
                    # is empty, hidden, or contains a path component, so save /
                    # load / delete can never leave profile_dir.
                    name = (profile_name or "").strip()
                    if not name or name.startswith(".") or os.path.basename(name) != name:
                        raise gr.Error(f"Invalid profile name: {profile_name!r}.")
                    return os.path.join(self.profile_dir, name)

                def list_profiles(self):
                    # remove hidden files
                    profiles = [p for p in os.listdir(self.profile_dir) if not p.startswith(".")]
                    if len(profiles) == 0:
                        return "No profiles found."
                    profile_text = "\n".join(profiles)
                    n_files = len(profiles)
                    profile_text = f"Number of profiles: {n_files}\n---------\n" + profile_text
                    return profile_text

                def save_profile(self, profile_name, images):
                    profile_path = self._profile_path(profile_name)
                    if os.path.exists(profile_path):
                        raise gr.Error(f"Profile {profile_name} already exists.")
                    if images is None or len(images) == 0:
                        # an empty profile could not be appended on load
                        raise gr.Error("No images to save.")
                    with open(profile_path, "wb") as f:
                        pickle.dump(images, f)
                    gr.Info(f"Profile {profile_name} saved.")
                    return profile_path

                def load_profile(self, profile_name, existing_images):
                    profile_path = self._profile_path(profile_name)
                    if not os.path.exists(profile_path):
                        raise gr.Error(f"Profile {profile_name} not found.")
                    # SECURITY: pickle executes code on load; only files that
                    # this app wrote into profile_dir must ever be read here.
                    with open(profile_path, "rb") as f:
                        images = pickle.load(f)
                    gr.Info(f"Profile {profile_name} loaded.")
                    if existing_images is None:
                        existing_images = []
                    return existing_images + images

                def delete_profile(self, profile_name):
                    profile_path = self._profile_path(profile_name)
                    if not os.path.isfile(profile_path):
                        raise gr.Error(f"Profile {profile_name} not found.")
                    os.remove(profile_path)
                    gr.Info(f"Profile {profile_name} deleted.")
                    return profile_path

            home_dir = os.path.expanduser("~")
            default_dir = os.path.join(home_dir, ".cache")
            cache_dir = os.environ.get("DEMO_PROFILE_CACHE_DIR", default_dir)
            cache_dir = os.path.join(cache_dir, "demo_profiles")
            on_disk_profiles = OnDiskProfiles(cache_dir)
            save_profile_button.click(fn=lambda name, images: on_disk_profiles.save_profile(name, images), inputs=[profile_text, input_gallery])
            load_profile_button.click(fn=lambda name, existing_images: gr.update(value=on_disk_profiles.load_profile(name, existing_images)), inputs=[profile_text, input_gallery], outputs=[input_gallery])
            delete_profile_button.click(fn=lambda name: on_disk_profiles.delete_profile(name), inputs=profile_text)
            list_profiles_button.click(fn=lambda: gr.Info(on_disk_profiles.list_profiles(), duration=0))

    # exactly one of the two dataset menus is visible at a time
    if advanced:
        advanced_block.visible = True
        basic_block.visible = False
    else:
        advanced_block.visible = False
        basic_block.visible = True

    # change visibility
    advanced_radio.change(fn=lambda x: gr.update(visible=x=="Advanced"), inputs=advanced_radio, outputs=[advanced_block])
    advanced_radio.change(fn=lambda x: gr.update(visible=x=="Basic"), inputs=advanced_radio, outputs=[basic_block])

    def find_num_classes(dataset_name):
        # second tuple element in DATASETS is the class count (None = no labels)
        for entries in DATASETS.values():
            for name, num_classes in entries:
                if name == dataset_name:
                    return num_classes
        return None

    def change_filter_options(dataset_name):
        num_classes = find_num_classes(dataset_name)
        if num_classes is None:
            return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
                    gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
        return (gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox", visible=True),
                gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=True))
    dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])

    def change_filter_by_class(is_filter, dataset_name):
        num_classes = find_num_classes(dataset_name)
        return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. \n({num_classes} classes)", visible=is_filter)
    filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)

    def change_random_seed(is_random):
        return gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=is_random)
    is_random_checkbox.change(fn=change_random_seed, inputs=is_random_checkbox, outputs=random_seed_slider)

    load_images_button.click(load_and_append,
                             inputs=[input_gallery, advanced_radio, dataset_dropdown, num_images_slider,
                                     filter_by_class_checkbox, filter_by_class_text,
                                     is_random_checkbox, random_seed_slider],
                             outputs=[input_gallery])

    return input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button
# def random_rotate_rgb_gallery(images):
#     if images is None or len(images) == 0:
#         gr.Warning("No images selected.")
#         return []
#     # read webp images
#     images = [Image.open(image[0]).convert("RGB") for image in images]
#     images = [np.array(image).astype(np.float32) for image in images]
#     images = np.stack(images)
#     images = torch.tensor(images) / 255
#     position = np.random.choice([1, 2, 4, 5, 6])
#     images = rotate_rgb_cube(images, position)
#     images = to_pil_images(images, resize=False)
#     return images


def protect_original_image_in_plot(original_image, rotated_images):
    """Restore the left strip of a grid plot after a colour transform.

    Only images whose axes 1 and 2 measure exactly 332 x 1542 are touched; for
    those, the first 190 entries along axis 2 of ``rotated_images`` are
    overwritten in place with the values from ``original_image``.

    Returns:
        ``rotated_images`` (the same object that was passed in).
    """
    expected_hw = (332, 1542)
    actual_hw = (original_image.shape[1], original_image.shape[2])
    if actual_hw == expected_hw:
        strip = slice(None, 190)
        rotated_images[:, :, strip] = original_image[:, :, strip]
    return rotated_images


def sequence_rotate_rgb_gallery(images):
    """Cyclically permute the colour channels of every gallery image.

    Args:
        images: gallery value, a sequence of ``(path, caption)`` tuples.

    Returns:
        The converted images via ``to_pil_images``; an empty list (after a
        warning) when the gallery is ``None`` or empty.
    """
    if images is None or len(images) == 0:
        gr.Warning("No images selected.")
        return []

    # read webp images
    arrays = [
        np.array(Image.open(entry[0]).convert("RGB")).astype(np.float32)
        for entry in images
    ]
    pixels = torch.tensor(np.stack(arrays)) / 255
    untouched = pixels.clone()

    # permutation matrix applied to the trailing (channel) axis
    channel_shift = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).float()
    shifted = pixels @ channel_shift
    shifted = protect_original_image_in_plot(untouched, shifted)
    return to_pil_images(shifted, resize=False)
def flip_rgb_gallery(images, axis=0):
    """Invert the colours (``1 - value``) of every gallery image.

    Args:
        images: gallery value, a sequence of ``(path, caption)`` tuples.
        axis: accepted for interface compatibility; not used by the body.

    Returns:
        The converted images via ``to_pil_images``; an empty list (after a
        warning) when the gallery is ``None`` or empty.
    """
    if images is None or len(images) == 0:
        gr.Warning("No images selected.")
        return []

    # read webp images
    arrays = [
        np.array(Image.open(entry[0]).convert("RGB")).astype(np.float32)
        for entry in images
    ]
    pixels = torch.tensor(np.stack(arrays)) / 255
    untouched = pixels.clone()

    inverted = 1 - pixels
    inverted = protect_original_image_in_plot(untouched, inverted)
    return to_pil_images(inverted, resize=False)


def add_rotate_flip_buttons(output_gallery):
    """Add Rotate and Flip buttons that rewrite ``output_gallery`` in place.

    Returns:
        ``(rotate_button, flip_button)``.
    """
    gallery_io = [output_gallery]
    with gr.Row():
        rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
        rotate_button.click(sequence_rotate_rgb_gallery, inputs=gallery_io, outputs=gallery_io)
        flip_button = gr.Button("🔃 Flip", elem_id="flip_button", variant='secondary')
        flip_button.click(flip_rgb_gallery, inputs=gallery_io, outputs=gallery_io)
    return rotate_button, flip_button
def add_download_button(gallery, filename_prefix="output"):
    """Add "Pack" and "Download" buttons that export a gallery as a zip file.

    "Pack" writes every gallery image plus 3x5 contact sheets into a zip under
    ``/tmp/gallery_download`` and hands the path to the download button.

    Args:
        gallery: gallery component whose value is exported.
        filename_prefix: prefix of the generated zip file name.

    Returns:
        (create_file_button, download_button)
    """
    grid_rows, grid_cols = 3, 5
    images_per_grid = grid_rows * grid_cols

    def make_3x5_plot(images):
        """Render PIL images into contact sheets of 15 images (3 rows x 5 columns)."""
        plot_list = []
        for start in range(0, len(images), images_per_grid):
            chunk = images[start:start + images_per_grid]
            # The grid must hold a whole chunk: zip() below stops at the
            # shorter sequence, so a smaller grid silently drops images.
            fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(15, 9))
            try:
                for ax in axs.flatten():
                    ax.axis("off")
                for ax, img in zip(axs.flatten(), chunk):
                    ax.imshow(img.convert("RGB"))
                fig.tight_layout(h_pad=0.5, w_pad=0.3)

                # Render in memory; no temporary file to leak on error.
                buf = io.BytesIO()
                fig.savefig(buf, format="png", bbox_inches='tight', dpi=144)
                buf.seek(0)
                # convert() loads the pixels, so the result outlives `buf`.
                plot_list.append(Image.open(buf).convert("RGB"))
            finally:
                plt.close(fig)
        return plot_list

    def delete_file_after_delay(file_path, delay):
        """Remove ``file_path`` after ``delay`` seconds, if it still exists."""
        def delete_file():
            if os.path.exists(file_path):
                os.remove(file_path)

        timer = threading.Timer(delay, delete_file)
        # Daemon thread: a pending clean-up must not keep the process alive.
        timer.daemon = True
        timer.start()

    def write_jpeg(zipf, arcname, img):
        """Encode ``img`` as JPEG and store it in the archive under ``arcname``."""
        buf = io.BytesIO()
        img.convert("RGB").save(buf, format="JPEG")
        zipf.writestr(arcname, buf.getvalue())

    def create_zip_file(images, filename_prefix=filename_prefix):
        """Pack gallery images and their contact sheets into a zip file.

        Returns an update for the download button, or ``None`` (with a UI
        warning) when the gallery is empty.
        """
        if images is None or len(images) == 0:
            gr.Warning("No images selected.")
            return None
        gr.Info("Creating zip file for download...")

        images = [image[0] for image in images]
        if isinstance(images[0], str):
            images = [Image.open(image) for image in images]

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f"/tmp/gallery_download/{filename_prefix}_{timestamp}.zip"
        os.makedirs(os.path.dirname(zip_filename), exist_ok=True)

        plots = make_3x5_plot(images)

        with zipfile.ZipFile(zip_filename, 'w') as zipf:
            for i, img in enumerate(images):
                write_jpeg(zipf, f"single_{i:04d}.jpg", img)
            for i, plot in enumerate(plots):
                write_jpeg(zipf, f"grid_{i:04d}.jpg", plot)

        # Schedule the deletion of the zip file after 24 hours (86400 seconds)
        delete_file_after_delay(zip_filename, 86400)
        gr.Info(f"File is ready for download: {os.path.basename(zip_filename)}")
        return gr.update(value=zip_filename, interactive=True)

    with gr.Row():
        create_file_button = gr.Button("📦 Pack", elem_id="create_file_button", variant='secondary')
        download_button = gr.DownloadButton(label="📥 Download", value=None, variant='secondary',
                                            elem_id="download_button", interactive=False)

        create_file_button.click(create_zip_file, inputs=[gallery], outputs=[download_button])

        def warn_on_click(filename):
            """Warn when Download is clicked before a file has been packed."""
            if filename is None:
                gr.Warning("No file to download, please `📦 Pack` first.")
            interactive = filename is not None
            return gr.update(interactive=interactive)

        download_button.click(warn_on_click, inputs=[download_button], outputs=[download_button])

    return create_file_button, download_button


def make_output_images_section(markdown=True, button=True):
    """Create the read-only output gallery, optionally with a heading and buttons.

    Args:
        markdown: when true, add an "Output Images" heading above the gallery.
        button: when true, add the Rotate/Flip buttons below the gallery.

    Returns:
        The output gallery component.
    """
    if markdown:
        gr.Markdown('### Output Images')
    output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True,
                                elem_id="ncut", columns=[3], rows=[1], object_fit="contain",
                                height="450px", show_share_button=True, interactive=False)
    if button:
        add_rotate_flip_buttons(output_gallery)
    return output_gallery
def make_parameters_section(is_lisa=False, model_ratio=True, ncut_parameter_dropdown=True, tsne_parameter_dropdown=True):
    """Build the backbone, NCUT and visualization parameter widgets.

    Creates Gradio components, so it is meant to be called while a layout
    context is open (this file calls it inside ``with demo:``).

    Args:
        is_lisa: build the fixed LISA backbone widgets instead of the
            backbone-family radio and model dropdown.
        model_ratio: visibility of the backbone-family radio (parameter name
            kept as-is for callers).
        ncut_parameter_dropdown: visibility of the "Advanced Parameters: NCUT"
            accordion.
        tsne_parameter_dropdown: visibility of the "Advanced Parameters:
            Visualization" accordion.

    Returns:
        A list of 20 components, in this order: model_dropdown, layer_slider,
        node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider,
        num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection,
        ncut_make_orthogonal, embedding_method_dropdown,
        embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
        perplexity_slider, n_neighbors_slider, min_dist_slider,
        sampling_method_dropdown, ncut_metric_dropdown, positive_prompt,
        negative_prompt.
    """
    gr.Markdown("### Parameters Help")

    from ncut_pytorch.backbone import list_models
    model_names = sorted(list_models())

    # One spelling for the LISA model so the dropdown value and the
    # comparison in change_layer_slider cannot drift apart.
    lisa_model_name = "LISA(xinlai/LISA-7B-v1)"
    lisa_layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"]
    backbone_layer_types = ["attn: attention output", "mlp: mlp output", "block: sum of residual"]
    positive_placeholder = "e.g. 'a photo of Gibson Les Pual guitar'"
    negative_placeholder = "e.g. 'a photo from egocentric view'"

    def get_filtered_model_names(name):
        """Return the model names containing ``name`` (case-insensitive)."""
        return [m for m in model_names if name.lower() in m.lower()]

    def get_default_model_name(name):
        """Pick the default model of a family, or None when nothing matches.

        The second match is preferred when there is more than one.
        """
        matches = get_filtered_model_names(name)
        if not matches:
            return None
        return matches[1] if len(matches) > 1 else matches[0]

    if is_lisa:
        model_dropdown = gr.Dropdown([lisa_model_name], label="Backbone", value=lisa_model_name, elem_id="model_name")
        layer_slider = gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False)
        positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder=positive_placeholder, visible=False)
        negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder=negative_placeholder, visible=False)
        node_type_dropdown = gr.Dropdown(lisa_layer_names, label="LISA (SAM) decoder: Layer and Node", value="dec_1_block", elem_id="node_type")
    else:
        model_radio = gr.Radio(["CLIP", "DiNO", "Diffusion", "ImageNet", "MAE", "SAM", "Rand"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True, visible=model_ratio)
        model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False)
        model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)),
                           inputs=model_radio, outputs=[model_dropdown])
        layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
        positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder=positive_placeholder)
        positive_prompt.visible = False
        negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder=negative_placeholder)
        negative_prompt.visible = False
        node_type_dropdown = gr.Dropdown(backbone_layer_types, label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type")

    num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for smaller clusters')

    def change_layer_slider(model_name):
        """Return (layer slider, node-type dropdown) matching the chosen model."""
        lowered = model_name.lower()
        # SD2, UNET
        if "stable" in lowered and "diffusion" in lowered:
            from ncut_pytorch.backbone import SD_KEY_DICT
            default_layer = 'up_2_resnets_1_block' if 'diffusion-3' not in model_name else 'block_23'
            return (gr.Slider(1, 49, step=1, label="Diffusion: Timestep (Noise)", value=5, elem_id="layer", visible=True, info="Noise level, 50 is max noise"),
                    gr.Dropdown(SD_KEY_DICT[model_name], label="Diffusion: Layer and Node", value=default_layer, elem_id="node_type", info="U-Net (v1, v2) or DiT (v3)"))

        if model_name == lisa_model_name:
            return (gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False, info=""),
                    gr.Dropdown(lisa_layer_names, label="LISA decoder: Layer and Node", value="dec_1_block", elem_id="node_type"))

        # Models missing from LAYER_DICT fall back to 12 layers.
        n_layers = LAYER_DICT[model_name] if model_name in LAYER_DICT else 12
        return (gr.Slider(1, n_layers, step=1, label="Backbone: Layer index", value=n_layers, elem_id="layer", visible=True, info=""),
                gr.Dropdown(backbone_layer_types, label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?"))

    model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=[layer_slider, node_type_dropdown])

    def change_prompt_text(model_name):
        """Show the prompt boxes only for models that accept a text prompt."""
        # `promptable_diffusion_models` is defined elsewhere in this file.
        show = model_name in promptable_diffusion_models
        return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder=positive_placeholder, visible=show),
                gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder=negative_placeholder, visible=show))

    model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt])

    with gr.Accordion("Advanced Parameters: NCUT", open=False, visible=ncut_parameter_dropdown):
        gr.Markdown("Docs: How to Get Better Segmentation")
        affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation")
        num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
        sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method")
        ncut_metric_dropdown = gr.Radio(["euclidean", "cosine", "rbf"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric")
        ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
        ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=True, elem_id="ncut_indirect_connection", info="Add indirect connection to the sub-sampled graph")
        ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization")

    with gr.Accordion("Advanced Parameters: Visualization", open=False, visible=tsne_parameter_dropdown):
        embedding_method_dropdown = gr.Radio(["tsne_3d", "umap_3d", "umap_sphere", "tsne_2d", "umap_2d"], label="Coloring method", value="tsne_3d", elem_id="embedding_method")
        embedding_metric_dropdown = gr.Radio(["euclidean", "cosine"], label="t-SNE/UMAP: metric", value="cosine", elem_id="embedding_metric")
        num_sample_tsne_slider = gr.Slider(100, 10000, step=100, label="t-SNE/UMAP: num_sample", value=300, elem_id="num_sample_tsne", info="Nyström approximation")
        knn_tsne_slider = gr.Slider(1, 100, step=1, label="t-SNE/UMAP: KNN", value=10, elem_id="knn_tsne", info="Nyström approximation")
        perplexity_slider = gr.Slider(10, 1000, step=10, label="t-SNE: perplexity", value=150, elem_id="perplexity")
        n_neighbors_slider = gr.Slider(10, 1000, step=10, label="UMAP: n_neighbors", value=150, elem_id="n_neighbors")
        min_dist_slider = gr.Slider(0.1, 1, step=0.1, label="UMAP: min_dist", value=0.1, elem_id="min_dist")

    return [model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
            affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
            embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider,
            perplexity_slider, n_neighbors_slider, min_dist_slider,
            sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt]


# Strip all default styling from the element with id `unlock_button`
# and from the form container that holds it.
custom_css = """
#unlock_button {
    all: unset !important;
}
.form:has(#unlock_button) {
    all: unset !important;
}
"""

demo = gr.Blocks(
    theme=gr.themes.Base(spacing_size='md', text_size='lg', primary_hue='blue', neutral_hue='slate', secondary_hue='pink'),
    # fill_width=False,
    # title="ncut-pytorch",
    css=custom_css,
)
def rotate_state(arr):
    """Apply the Rotate button's channel permutation to the cached RGB array.

    Maps (R, G, B) to (B, R, G) along the last axis. An empty or missing
    state (nothing computed yet) is returned unchanged instead of failing
    in the matrix product.
    """
    if arr is None or np.size(arr) == 0:
        return arr
    rotation_matrix = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)
    return arr @ rotation_matrix


def flip_state(arr):
    """Apply the Flip button's color inversion (``1 - value``) to the cached RGB array."""
    return 1 - arr


def add_rotate_flip_buttons_with_state(output_gallery, tsne3d_rgb):
    """Add Rotate/Flip buttons that recolor the gallery and its cached RGB state.

    Each button registers two click handlers: one rewrites the gallery
    images, the other applies the same color transform to ``tsne3d_rgb``.

    Returns:
        (rotate_button, flip_button)
    """
    with gr.Row():
        rotate_button = gr.Button("🔄 Rotate", elem_id="rotate_button", variant='secondary')
        rotate_button.click(sequence_rotate_rgb_gallery, inputs=[output_gallery], outputs=[output_gallery])
        rotate_button.click(rotate_state, inputs=[tsne3d_rgb], outputs=[tsne3d_rgb])

        flip_button = gr.Button("🔃 Flip", elem_id="flip_button", variant='secondary')
        flip_button.click(flip_rgb_gallery, inputs=[output_gallery], outputs=[output_gallery])
        flip_button.click(flip_state, inputs=[tsne3d_rgb], outputs=[tsne3d_rgb])
    return rotate_button, flip_button


def __run_fn(*args, **kwargs):
    """Run ``run_fn`` and also return its RGB output as gallery images.

    Returns:
        (eigvecs, rgb, rgb_gallery, logging_str), where ``rgb_gallery`` is
        ``to_pil_images(rgb)``.
    """
    # NOTE(review): the caller binds normalize_eigvec_return=True, so the
    # eigenvectors are presumably already normalized by run_fn — confirm.
    eigvecs, rgb, logging_str = run_fn(*args, **kwargs)
    rgb_gallery = to_pil_images(rgb)
    return eigvecs, rgb, rgb_gallery, logging_str
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown ], outputs=[eigvecs, tsne3d_rgb, output_gallery, logging_text], ) with gr.Column(scale=5, min_width=200): gr.Markdown('---') gr.Markdown('

Help

') gr.Markdown('---') with gr.Accordion("Instructions", open=True): gr.Markdown(""" 1. Load Dataset (left). 2. Choose parameters (middle). 3. 🔴 RUN NCUT. 4. 🔴 RUN tree. 5. Interact and Inspect (scroll down). """) with gr.Accordion("Methods: NCUT Embedding", open=False): gr.Markdown("### Documentation: How NCUT works") gr.Markdown(""" 1. Run Backbone, feature extraction for each image. 2. Vectorize latent-pixels, concatenate all the images. 3. Run NCUT, on one big graph of all the images. 4. Run spectral-tSNE on the NCUT eigenvectors. 5. Plot the 3D spectral-tSNE as RGB. """) with gr.Accordion("Methods: spectral-tSNE tree", open=False): gr.Markdown(""" 1. Farthest Point Sampling (FPS) on the eigenvectors. 2. spectral-tSNE (2D) on the FPS sampled points. 3. Hierarchical clustering (tree) on the FPS sampled points. """) gr.Markdown('---') run_hierarchical_button = gr.Button("🔴 RUN tree", elem_id="run_hierarchical", variant='primary') with gr.Accordion("Hierarchical Structure Parameters:", open=True): num_sample_fps_slider = gr.Slider(1, 5000, step=1, label="FPS: num_sample", value=1000, elem_id="num_sample_fps") tsne_perplexity_slider = gr.Slider(1, 1000, step=1, label="t-SNE: perplexity", value=500, elem_id="perplexity_tsne") fps_hc_seed_slider = gr.Slider(0, 1000, step=1, label="Seed", value=0, elem_id="fps_hc_seed") tree_method_radio = gr.Radio(["eigvecs", "tsne"], label="Tree Method (input type)", value="tsne", elem_id="tree_method", info="What's the input to build tree? 
def plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, k, hightlight_idx=None, highlight_connections=False):
    """Draw the 2D embedding with its tree edges and return it as a PIL image.

    Args:
        tsne_embed: array of 2D point coordinates, one row per point.
        edges: 2D integer array; each row holds the two point indices of an edge.
        fps_tsne3d_rgb: per-point colors passed to ``scatter``.
        k: index of the first edge to draw; edges ``[0, k)`` are skipped.
        hightlight_idx: optional index of a point to mark with a large cyan dot.
        highlight_connections: when true (and a point is highlighted), also
            mark the points returned by ``find_connected_component`` for the
            drawn edges.

    Returns:
        PIL image of the rendered figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    try:
        ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne3d_rgb)

        # Edge lengths in embedding space. No max over lengthes[k:] is taken:
        # it was unused and raised on an empty slice when k >= len(edges).
        lengthes = np.linalg.norm(tsne_embed[edges[:, 0]] - tsne_embed[edges[:, 1]], axis=1)
        diag_length = np.linalg.norm(tsne_embed.max(axis=0) - tsne_embed.min(axis=0))

        for i_edge in range(k, len(edges)):
            edge = edges[i_edge]
            # Edges longer than 10% of the bounding-box diagonal are drawn
            # fully transparent, i.e. hidden.
            alpha = 0.0 if lengthes[i_edge] > diag_length * 0.1 else 0.7
            ax.plot(tsne_embed[edge, 0], tsne_embed[edge, 1], 'k-', lw=1, alpha=alpha)

        if hightlight_idx is not None:
            if highlight_connections:
                from fps_cluster import find_connected_component
                connected_nodes = find_connected_component(edges[k:, :], hightlight_idx)
                ax.scatter(tsne_embed[connected_nodes, 0], tsne_embed[connected_nodes, 1], s=50,
                           c=fps_tsne3d_rgb[connected_nodes], marker='D', edgecolor='deeppink', linewidth=1)
            ax.scatter(tsne_embed[hightlight_idx, 0], tsne_embed[hightlight_idx, 1], s=200,
                       c='cyan', marker='o', edgecolor='black', linewidth=1)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        ax.set_xlim(tsne_embed[:, 0].min() * 1.1, tsne_embed[:, 0].max() * 1.1)
        ax.set_ylim(tsne_embed[:, 1].min() * 1.1, tsne_embed[:, 1].max() * 1.1)

        # Remove the white space around the plot
        fig.tight_layout(pad=0)

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        image = np.array(Image.open(buf))
        buf.close()
    finally:
        # Always release the figure, even when drawing fails.
        plt.close(fig)
    return Image.fromarray(image)


def pad_image_with_border(image, border_color, border_width):
    """Return ``image`` surrounded by a solid border of ``border_width`` pixels.

    Args:
        image: array indexed [H, W, C].
        border_color: per-channel color, cast to the image dtype.
        border_width: border thickness in pixels; 0 returns an unpadded copy.
    """
    h, w = image.shape[:2]
    new_image = np.empty((h + 2 * border_width, w + 2 * border_width, image.shape[2]), dtype=image.dtype)
    new_image[:, :] = border_color
    # Explicit end indices (not negative ones) so border_width == 0 works.
    new_image[border_width:border_width + h, border_width:border_width + w] = image
    return new_image


def get_top1_heatmap_for_each_dot(images, eigvecs, fps_eigvecs, max_display_dots, fps_tsne_rgb, tsne_embed):
    """For each displayed dot, blend its strongest-responding image with a heatmap.

    At most ``max_display_dots`` dots are kept; when there are more, a subset
    is chosen by farthest point sampling on ``tsne_embed``. For every kept
    dot, the image with the highest mean response is blended 50/50 with the
    dot's heatmap and framed in the dot's color.

    Args:
        images: list of PIL images, one per row of ``eigvecs``.
        eigvecs: per-pixel features indexed [B, H, W, C].
        fps_eigvecs: features of the sampled dots, [N, C].
        max_display_dots: upper bound on the number of dots to process.
        fps_tsne_rgb: per-dot colors, scaled by 255 for the border.
        tsne_embed: 2D coordinates of the dots, used only for sub-sampling.

    Returns:
        (list of padded uint8 images, indices of the dots that were kept)
    """
    n_dots = fps_eigvecs.shape[0]
    if n_dots > max_display_dots:
        import fpsample
        dots_idx = fpsample.bucket_fps_kdline_sampling(tsne_embed, max_display_dots, 5).astype(np.int64)
    else:
        dots_idx = np.arange(n_dots)
    fps_eigvecs = fps_eigvecs[dots_idx]
    fps_tsne_rgb = fps_tsne_rgb[dots_idx]

    # NOTE(review): one caller passes a CPU torch tensor here; convert so the
    # product below is a plain NumPy operation — confirm no GPU tensors arrive.
    eigvecs = np.asarray(eigvecs)
    heatmaps = eigvecs @ fps_eigvecs.T  # [B, H, W, C] @ [C, N] -> [B, H, W, N]
    value = heatmaps.mean(1).mean(1)  # [B, N]
    top1_image_idxs = value.argmax(axis=0)  # [N]

    top1_image_blended = []
    cm = matplotlib.colormaps['hot']
    border_width = 20
    for i_fps, image_idx in enumerate(top1_image_idxs):
        heatmap = cm(heatmaps[image_idx, :, :, i_fps])
        heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
        image = images[image_idx].convert("RGB").resize((256, 256))
        heatmap = Image.fromarray(heatmap).resize((256, 256)).convert("RGB")
        blended = 0.5 * np.array(image) + 0.5 * np.array(heatmap)
        blended = np.clip(blended, 0, 255).astype(np.uint8)
        border_color = fps_tsne_rgb[i_fps, :3] * 255
        top1_image_blended.append(pad_image_with_border(blended, border_color, border_width))
    return top1_image_blended, dots_idx
def plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne_rgb, max_display_dots=100):
    """Draw the 2D embedding with a blended image thumbnail on selected dots.

    Thumbnails come from ``get_top1_heatmap_for_each_dot``.

    Returns:
        PIL image of the rendered figure.
    """
    top1_image_blended, dots_idx = get_top1_heatmap_for_each_dot(
        images, eigvecs, fps_eigvecs, max_display_dots, fps_tsne_rgb, tsne_embed)

    fig, ax = plt.subplots(1, 1, figsize=(15, 15))
    try:
        ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne_rgb)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        ax.set_xlim(tsne_embed[:, 0].min() * 1.1, tsne_embed[:, 0].max() * 1.1)
        ax.set_ylim(tsne_embed[:, 1].min() * 1.1, tsne_embed[:, 1].max() * 1.1)

        # Place each thumbnail at its dot's coordinates.
        for i, (x, y) in enumerate(tsne_embed[dots_idx]):
            imgbox = OffsetImage(np.array(top1_image_blended[i]), zoom=0.15)
            ax.add_artist(AnnotationBbox(imgbox, (x, y), frameon=False))

        ax.scatter(tsne_embed[:, 0], tsne_embed[:, 1], s=20, c=fps_tsne_rgb)

        # Remove the white space around the plot
        fig.tight_layout(pad=0)

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        image = np.array(Image.open(buf))
        buf.close()
    finally:
        # Always release the figure, even when drawing fails.
        plt.close(fig)
    return Image.fromarray(image)


def run_fps_tsne_hierarchical(image_gallery, eigvecs, num_sample_fps, perplexity_tsne, tsne3d_rgb, seed=0, tree_method='eigvecs', max_display_dots=300):
    """Sample points, embed them with 2D t-SNE, build a tree and plot it.

    Args:
        image_gallery: gallery value; each item is indexed with ``[0]`` to get
            an image or an image path.
        eigvecs: NCUT eigenvectors; flattened to rows over the last axis.
        num_sample_fps: number of points for farthest point sampling.
        perplexity_tsne: requested t-SNE perplexity.
        tsne3d_rgb: per-pixel colors; reshaped to rows of 3.
        seed: seed for torch, NumPy and t-SNE.
        tree_method: ``'eigvecs'`` (cosine tree on the sampled eigenvectors)
            or ``'tsne'`` (euclidean tree on the 2D embedding).
        max_display_dots: cap on the thumbnails in the large plot.

    Returns:
        (tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, tree image,
        update for the image prompter holding the large plot)

    Raises:
        gr.Error: when NCUT has not been run yet or ``tree_method`` is unknown.
    """
    # Raise rather than warn-and-return: the click handler wires seven
    # outputs, so a bare `return` cannot satisfy it.
    if len(eigvecs) == 0:
        raise gr.Error("Please run NCUT first.")
    if tree_method not in ('eigvecs', 'tsne'):
        raise gr.Error(f"Unknown tree method: {tree_method!r}")

    images = [image[0] for image in image_gallery]
    if isinstance(images[0], str):
        images = [Image.open(image) for image in images]

    eigvecs = torch.tensor(eigvecs)
    _eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1])

    gr.Info("Running FPS, t-SNE, and Hierarchical Clustering...", 3)

    from ncut_pytorch.ncut_pytorch import farthest_point_sampling
    from sklearn.manifold import TSNE
    from fps_cluster import build_tree

    torch.manual_seed(seed)
    np.random.seed(seed)
    fps_idx = farthest_point_sampling(_eigvecs, num_sample_fps)
    fps_eigvecs = _eigvecs[fps_idx].numpy()

    tsne3d_rgb = tsne3d_rgb.reshape(-1, 3)
    fps_tsne3d_rgb = tsne3d_rgb[fps_idx]

    # scikit-learn requires perplexity < n_samples; clamp instead of failing
    # when few points were sampled.
    n_fps = fps_eigvecs.shape[0]
    perplexity = perplexity_tsne
    if perplexity >= n_fps:
        perplexity = max(1, n_fps - 1)
        gr.Warning(f"t-SNE perplexity reduced to {perplexity} (must be smaller than the {n_fps} sampled points).")

    np.random.seed(seed)
    tsne_embed = TSNE(
        n_components=2,
        perplexity=perplexity,
        metric='cosine',
        random_state=seed,
    ).fit_transform(fps_eigvecs)

    # Rescale each axis to [-1, 1].
    lo = tsne_embed.min(axis=0)
    hi = tsne_embed.max(axis=0)
    tsne_embed = (tsne_embed - lo) / (hi - lo) * 2 - 1

    if tree_method == 'eigvecs':
        edges = build_tree(fps_eigvecs, dist='cosine')
    else:  # 'tsne', validated above
        edges = build_tree(tsne_embed, dist='euclidean')

    pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne3d_rgb, 0)
    big_pil_image = plot_tsne_with_image_heatmaps(images, eigvecs, fps_eigvecs, tsne_embed, fps_tsne3d_rgb, max_display_dots)

    return (tsne_embed, edges, fps_eigvecs, fps_tsne3d_rgb, fps_idx, pil_image,
            gr.update(value={'image': big_pil_image, 'points': []}, interactive=True))

↓ interactively inspect the hierarchical structure

') gr.Markdown('---') # big_tsne_plot = gr.Image(label="spectral-tSNE tree [+ Cluster Heatmap]", elem_id="big_tsne_plot", interactive=False, format='png') tsne_image_plot = ImagePrompter(show_label=True, elem_id="tsne_image_plot", interactive=False, label="spectral-tSNE tree [+ Cluster Heatmap]") run_hierarchical_button.click( run_fps_tsne_hierarchical, inputs=[input_gallery, eigvecs, num_sample_fps_slider, tsne_perplexity_slider, tsne3d_rgb, fps_hc_seed_slider, tree_method_radio], outputs=[tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, tsne_plot, tsne_image_plot], ) with gr.Row(): with gr.Column(scale=5, min_width=200) as tsne_select: gr.Markdown('---') gr.Markdown('

Please click on the image below ↓

') gr.Markdown('---') tsne_prompt_image = ImagePrompter(show_label=True, elem_id="tsne_prompt_image", interactive=False, label="spectral-tSNE tree") # copy plot to tsne_prompt_image on change # tsne_plot.change(fn=lambda x: gr.update(value={'image': x}, interactive=True), # inputs=[tsne_plot], outputs=[tsne_prompt_image]) with gr.Column(scale=5, min_width=200) as image_select: gr.Markdown('---') gr.Markdown('

Please click on the image below ↓

def update_image_prompt(image_slider, output_gallery):
    """Show the gallery image selected by the slider in the image prompter.

    Args:
        image_slider: slider value, used as an index into the gallery.
        output_gallery: gallery value; each item is indexed with ``[0]``.

    Returns:
        An update that clears and disables the prompter when the gallery is
        empty or missing; otherwise one that loads the selected image.
    """
    if not output_gallery:
        return gr.update(value=None, interactive=False)
    # The slider maximum is only refreshed when the gallery changes, so a
    # stale value can exceed the gallery length; clamp it to a valid index.
    image_idx = min(max(int(image_slider), 0), len(output_gallery) - 1)
    image = output_gallery[image_idx][0]
    return gr.update(value={'image': image}, interactive=True)

Please click on the image above ↑

') gr.Markdown('---') tsne_non_prompt_image = gr.Image(label="spectral-tSNE tree", elem_id="tsne_non_prompt_image", interactive=False, format='png') with gr.Column(scale=5, min_width=200): gr.Markdown('

Help

def updaste_tsne_plot_change_granularity(granularity, tsne_embed, edges, fps_tsne_rgb, tsne_prompt_image=None):
    """Redraw the tree plot for a new granularity.

    Args:
        granularity: index of the first edge to draw (``k`` in ``plot_tsne_tree``).
        tsne_embed: 2D point coordinates from the tree run.
        edges: tree edges from the tree run.
        fps_tsne_rgb: per-point colors.
        tsne_prompt_image: unused. Optional because one of the listeners
            bound to this function supplies only the first four inputs.

    Returns:
        An update carrying the new plot and a label showing ``k``, or a no-op
        update when the tree has not been computed yet.
    """
    # The state starts as an empty array; there is nothing to draw until the
    # tree has been run.
    if len(tsne_embed) == 0:
        return gr.update()
    pil_image = plot_tsne_tree(tsne_embed, edges, fps_tsne_rgb, granularity)
    return gr.update(value=pil_image, label=f"spectral-tSNE tree [k={granularity}]")
def make_one_output_row(output_row_occupy, i_row=1):
    """Create one inspection result row: tree image, log box, delete button, heatmaps.

    Args:
        output_row_occupy: state component passed through the delete handler.
        i_row: row number, used in labels and element ids.

    Returns:
        (row container, tree image, heatmap gallery, log text box)
    """
    def delete_a_row(output_row_occupy, i_row=1):
        # Hides the row; the occupancy list is passed through unchanged.
        # output_row_occupy[i_row-1] = False
        return output_row_occupy, gr.update(visible=False)

    with gr.Row() as inspect_output_row:
        # Left column: tree plot, log text and the per-row delete button.
        with gr.Column(scale=5, min_width=200):
            output_tree_image = gr.Image(
                label=f"spectral-tSNE tree [row#{i_row}]",
                elem_id="output_image",
                interactive=False,
            )
            text_block = gr.Textbox(
                "",
                label="Logging",
                elem_id=f"logging_{i_row}",
                type="text",
                placeholder="Logging information",
                autofocus=False,
                autoscroll=False,
                lines=2,
                show_label=False,
            )
            delete_button = gr.Button(
                "❌ Delete",
                elem_id=f"delete_button_{i_row}",
                variant='secondary',
            )
        # Right column: heatmaps for the selected cluster.
        with gr.Column(scale=10, min_width=200):
            heatmap_gallery = gr.Gallery(
                format='png',
                value=[],
                label=f"Cluster Heatmap [row#{i_row}]",
                show_label=True,
                elem_id="heatmap",
                columns=[6],
                rows=[1],
                object_fit="contain",
                height="550px",
                show_share_button=True,
                interactive=False,
            )
        on_delete = partial(delete_a_row, i_row=i_row)
        delete_button.click(on_delete, output_row_occupy, outputs=[output_row_occupy, inspect_output_row])

    return inspect_output_row, output_tree_image, heatmap_gallery, text_block
def delete_all_output(output_row_occupy):
    """Mark every output row free, reset the row counter and hide all rows.

    Returns:
        (list of False per row, 0, one hide-update per row)
    """
    n_rows = len(output_row_occupy)
    output_row_occupy = [False] * n_rows
    return output_row_occupy, 0, *[gr.update(visible=False) for _ in range(n_rows)]


def relative_xy_last_positive(prompts):
    """Return the last positive click as (x, y) fractions of the image size.

    Args:
        prompts: dict with ``'image'`` (path, array, or anything
            ``np.asarray`` accepts) and ``'points'`` (rows of at least six
            numbers: x and y in columns 0-1, label in column 2, kind in
            column 5).

    Returns:
        ``(x / width, y / height)`` of the last row with kind 4.0 and label
        1.0, or ``([], [])`` when ``points`` is empty.

    Raises:
        ValueError: when no row has label 1.0.
    """
    image = prompts['image']
    points = np.asarray(prompts['points'])
    if points.shape[0] == 0:
        return [], []

    points = points[points[:, 5] == 4.0]
    is_positive = points[:, 2] == 1.0
    if not is_positive.any():
        raise ValueError("No blue point is selected.")

    if isinstance(image, str):
        # Only the size is needed; avoid decoding the whole file.
        with Image.open(image) as opened:
            w, h = opened.size
    else:
        h, w = np.asarray(image).shape[:2]

    last_positive_idx = np.where(is_positive)[0][-1]
    x, y = points[last_positive_idx, :2].tolist()
    return x / w, y / h


def find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed):
    """Map a click on the tree plot to the nearest embedded point.

    The click fraction is converted to data coordinates with the same 1.1x
    margin that the plotting code applies to the axis limits; y is flipped
    because image rows grow downward.

    Returns:
        (index of the nearest point, (x fraction, y fraction) of the click)
    """
    x, y = relative_xy_last_positive(tsne_prompt)
    _x_ratio, _y_ratio = x, y
    x_vmax = tsne2d_embed[:, 0].max() * 1.1
    x_vmin = tsne2d_embed[:, 0].min() * 1.1
    y_vmax = tsne2d_embed[:, 1].max() * 1.1
    y_vmin = tsne2d_embed[:, 1].min() * 1.1
    x = x * (x_vmax - x_vmin) + x_vmin
    y = (1 - y) * (y_vmax - y_vmin) + y_vmin
    dist = np.linalg.norm(tsne2d_embed - np.array([x, y]), axis=1)
    closest_idx = np.argmin(dist)
    return closest_idx, (_x_ratio, _y_ratio)


def find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs):
    """Map a click on an image to the sampled point most similar to that pixel.

    Args:
        image_prompt: prompter value for the clicked image.
        i_image: index of the clicked image in ``eigvecs``.
        eigvecs: per-pixel features indexed [B, H, W, C].
        fps_eigvecs: features of the sampled points, [N, C].

    Returns:
        (index with the largest dot product, (x fraction, y fraction))
    """
    x, y = relative_xy_last_positive(image_prompt)
    _x_ratio, _y_ratio = x, y
    _eigvec = eigvecs[i_image]
    h, w = _eigvec.shape[:2]
    # A click on the right/bottom edge gives a fraction of 1.0, which would
    # index one past the end; clamp to the last row/column.
    col = min(int(x * w), w - 1)
    row = min(int(y * h), h - 1)
    eigvec = _eigvec[row, col]
    sim = fps_eigvecs @ eigvec
    closest_idx = np.argmax(sim)
    return closest_idx, (_x_ratio, _y_ratio)


def find_closest_fps_point(prompt_radio, tsne_image_prompt, tsne_prompt, image_prompt, i_image, tsne2d_embed, eigvecs, fps_eigvecs):
    """Resolve the user's click to a sampled point, according to the prompt mode.

    Args:
        prompt_radio: ``"Tree"``, ``"Image"`` or ``"Tree [+Image]"``.
        tsne_image_prompt: prompter value used for ``"Tree [+Image]"``.
        tsne_prompt: prompter value used for ``"Tree"``.
        image_prompt: prompter value used for ``"Image"``.
        i_image, tsne2d_embed, eigvecs, fps_eigvecs: forwarded to the
            mode-specific resolver.

    Returns:
        (index of the chosen point, (x fraction, y fraction) of the click)

    Raises:
        gr.Error: for an unknown mode, or when no usable click was found.
    """
    if prompt_radio not in ("Tree", "Image", "Tree [+Image]"):
        raise gr.Error(f"Unknown prompt mode: {prompt_radio!r}")
    try:
        if prompt_radio == "Tree":
            return find_closest_fps_point_for_tsne_tree_plot(tsne_prompt, tsne2d_embed)
        if prompt_radio == "Image":
            return find_closest_fps_point_for_image_prompt(image_prompt, i_image, eigvecs, fps_eigvecs)
        return find_closest_fps_point_for_tsne_tree_plot(tsne_image_prompt, tsne2d_embed)
    except Exception as err:
        # Any failure here is reported as a missing click; the original
        # exception stays chained for debugging.
        raise gr.Error("""No blue point is selected.
Please left-click on the image to select a blue point.
After reloading the image (e.g., change granularity), please use the eraser to remove the previous point, then click on the image to select a blue point.""") from err
dist.reshape(b, h, w) # gr.Info(f"dist: min={dist.min().item()}, max={dist.max().item()}, mean={dist.mean().item()}", 3) # similarity = 1 - dist hot_map = matplotlib.colormaps['hot'] heatmap = hot_map(similarity)[..., :3] # B H W 3 heatmap_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True) # overlay input images on the heatmap input_images = [x[0] for x in input_gallery] if isinstance(input_images[0], str): input_images = [Image.open(x) for x in input_images] for i, img in enumerate(input_images): _img = img.resize((256, 256)).convert('RGB') _heatmap = heatmap_images[i].resize((256, 256)).convert('RGB') blend = np.array(_img) * 0.5 + np.array(_heatmap) * 0.5 blend = Image.fromarray(blend.astype(np.uint8)) heatmap_images[i] = blend # find the output slot # search from the last row found_flag = False for i_slot in range(max_rows-1, -1, -1): if not output_row_occupy[i_slot]: found_flag = True break if not found_flag: i_slot = 0 gr.Warning("Output slots are full, Overwriting the first row. 
Please use '❌ Delete All Output' to clear all outputs.") output_row_occupy[i_slot] = True # tree_label = f"spectral-tSNE tree [row#{max_rows-output_slot}] k={granularity} idx={closest_idx} n={len(connected_idxs)}" tree_label = f"spectral-tSNE tree [row#{current_output_row+1}]" heatmap_label = f"Cluster Heatmap [row#{current_output_row+1}] k={granularity} n={len(connected_idxs)} xy=[{_x:.2f}, {_y:.2f}] idx={closest_idx}" # update the output slots output_rows = [gr.update() for _ in range(max_rows)] output_tsne_plots = [gr.update() for _ in range(max_rows)] output_heatmaps = [gr.update() for _ in range(max_rows)] output_texts = [gr.update() for _ in range(max_rows)] output_rows[i_slot] = gr.update(visible=True) output_tsne_plots[i_slot] = gr.update(value=output_tsne_plot, label=tree_label) output_heatmaps[i_slot] = gr.update(value=heatmap_images, label=heatmap_label) output_texts[i_slot] = gr.update(value=logging_text) # gr.Info(f"Output in [row#{max_rows-output_slot}]", 3) logging_text += f"\nOutput: [row#{current_output_row+1}]" current_output_row += 1 return *output_rows, *output_tsne_plots, *output_heatmaps, *output_texts, current_output_row, output_row_occupy, logging_text run_inspection_button.click( run_inspection, inputs=[tsne_image_plot, tsne_prompt_image, image_plot, prompt_radio, current_output_row, tsne_2d_points, edges, fps_eigvecs, fps_tsne_rgb, fps_indices, granularity_slider, eigvecs, image_slider, tsne3d_rgb, input_gallery, output_row_occupy], outputs=inspect_output_rows + output_tree_images + heatmap_galleries + text_blocks + [current_output_row, output_row_occupy, inspect_logging_text], ) with gr.Tab('PlayGround (eig)', visible=True) as test_playground_tab2: eigvecs = gr.State(np.array([])) with gr.Row(): with gr.Column(scale=5, min_width=200): gr.Markdown("### Step 1: Load Images") input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = 
make_input_images_section(n_example_images=10) submit_button.visible = False num_images_slider.value = 30 false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False) no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) with gr.Column(scale=5, min_width=200): gr.Markdown("### Step 2a: Run Backbone and NCUT") with gr.Accordion(label="Backbone Parameters", visible=True, open=False): [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section(ncut_parameter_dropdown=False, tsne_parameter_dropdown=False) num_eig_slider.value = 1024 num_eig_slider.visible = False submit_button = gr.Button("🔴 RUN NCUT", elem_id="run_ncut", variant='primary') logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False) submit_button.click( partial(run_fn, n_ret=1, only_eigvecs=True), inputs=[ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown ], outputs=[eigvecs, logging_text], ) gr.Markdown("### Step 2b: Pick an Image and Draw a Point") from gradio_image_prompter import 
ImagePrompter image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True) load_one_image_button = gr.Button("🔴 Load Image", elem_id="load_one_image_button", variant='primary') gr.Markdown("""
🖱️ Left Click: Foreground
""") prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False) def update_prompt_image(original_images, index): images = original_images if images is None: return gr.update() total_len = len(images) if total_len == 0: return gr.update() if index >= total_len: index = total_len - 1 return gr.update(value={'image': images[index][0], 'points': []}, interactive=True) load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1]) input_gallery.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1]) input_gallery.change(fn=lambda x: gr.update(maximum=len(x)-1), inputs=[input_gallery], outputs=[image1_slider]) image1_slider.change(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1]) child_idx = gr.State([]) current_idx = gr.State(None) n_eig = gr.State(64) with gr.Column(scale=5, min_width=200): gr.Markdown("### Step 3: Check groupping") child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True) child_distance_slider.visible = False overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True) n_eig_slider = gr.Slider(0, 1024, step=1, label="Number of Eigenvectors", value=256, elem_id="n_eig_slider", interactive=True) run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary') gr.Markdown("1. 🔴 RUN
2. repeat: [+num_eigvecs] / [-num_eigvecs]") with gr.Row(): doublue_eigs_button = gr.Button("[+num_eigvecs]", elem_id="doublue_eigs_button", variant='secondary') half_eigs_button = gr.Button("[-num_eigvecs]", elem_id="half_eigs_button", variant='secondary') current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2]) def relative_xy(prompts): image = prompts['image'] points = np.asarray(prompts['points']) if points.shape[0] == 0: return [], [] is_point = points[:, 5] == 4.0 points = points[is_point] is_positive = points[:, 2] == 1.0 is_negative = points[:, 2] == 0.0 xy = points[:, :2].tolist() if isinstance(image, str): image = Image.open(image) image = np.array(image) h, w = image.shape[:2] new_xy = [(x/w, y/h) for x, y in xy] # print(new_xy) return new_xy, is_positive def xy_eigvec(prompts, image_idx, eigvecs): eigvec = eigvecs[image_idx] xy, is_positive = relative_xy(prompts) for i, (x, y) in enumerate(xy): if not is_positive[i]: continue x = int(x * eigvec.shape[1]) y = int(y * eigvec.shape[0]) return eigvec[y, x], (y, x) from ncut_pytorch.ncut_pytorch import _transform_heatmap def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True): left = eigvecs[..., :n_eig] if flat_idx is not None: right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx] y, x = None, None else: right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs) right = right[:n_eig] left = F.normalize(left, p=2, dim=-1) _right = F.normalize(right, p=2, dim=-1) heatmap = left @ _right.unsqueeze(-1) heatmap = heatmap.squeeze(-1) # heatmap = 1 - heatmap # heatmap = _transform_heatmap(heatmap) if raw_heatmap: return heatmap # apply hot colormap and covert to PIL image 256x256 # gr.Info(f"heatmap vmin: {heatmap.min()}, vmax: {heatmap.max()}, mean: {heatmap.mean()}") heatmap = heatmap.cpu().numpy() hot_map = matplotlib.colormaps['hot'] heatmap = 
hot_map(heatmap) pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True) if overlay_image: overlaied_images = [] for i_image in range(len(images)): rgb_image = images[i_image].resize((256, 256)) rgb_image = np.array(rgb_image) heatmap_image = np.array(pil_images[i_image])[..., :3] blend_image = 0.5 * rgb_image + 0.5 * heatmap_image blend_image = Image.fromarray(blend_image.astype(np.uint8)) overlaied_images.append(blend_image) pil_images = overlaied_images return pil_images, (y, x) @torch.no_grad() def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True): gr.Info(f"current number of eigenvectors: {n_eig}", 2) eigvecs = torch.tensor(eigvecs) image1_slider = min(image1_slider, len(images)-1) images = [image[0] for image in images] if isinstance(images[0], str): images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images] current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image) return current_heatmap def doublue_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True): n_eig = int(n_eig*2) n_eig = min(n_eig, eigvecs.shape[-1]) n_eig = max(n_eig, 1) return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image=overlay_image) def half_eigs_wrapper(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True): n_eig = int(n_eig/2) n_eig = min(n_eig, eigvecs.shape[-1]) n_eig = max(n_eig, 1) return gr.update(value=n_eig), run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image=overlay_image) none_placeholder = gr.State(None) run_button.click( run_heatmap, inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, 
child_distance_slider, none_placeholder, overlay_image_checkbox], outputs=[current_plot], )
                # NOTE(review): the original line breaks and indentation of this section
                # were lost; the nesting below is reconstructed from syntax alone.
                # The first line above closes a call that starts on the previous line.

                # Both buttons update the eigenvector-count slider and the heatmap gallery.
                doublue_eigs_button.click(
                    doublue_eigs_wrapper,
                    inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, none_placeholder, overlay_image_checkbox],
                    outputs=[n_eig_slider, current_plot],
                )
                half_eigs_button.click(
                    half_eigs_wrapper,
                    inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig_slider, child_distance_slider, current_idx, overlay_image_checkbox],
                    outputs=[n_eig_slider, current_plot],
                )

    # 'AlignedCut' tab: input images on the left, output gallery and parameters on the right.
    with gr.Tab('AlignedCut'):
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
                num_images_slider.value = 30
                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
            with gr.Column(scale=5, min_width=200):
                output_gallery = make_output_images_section()
                # cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
                [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section()
                # Hidden placeholder components fill the input positions this tab does not expose.
                false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False)
                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)
                submit_button.click(
                    partial(run_fn, n_ret=1,
plot_clusters=False),
                    inputs=[ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown ],
                    outputs=[output_gallery, logging_text],
                    api_name="API_AlignedCut",
                    scroll_to_output=True,
                )

    # 'AlignedCut (Advanced)' tab (hidden by default): adds eigenvector-magnitude and
    # cluster galleries, each with a download button.
    with gr.Tab('AlignedCut (Advanced)', visible=False) as tab_alignedcut_advanced:
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
                num_images_slider.value = 100
                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False, lines=20)
            with gr.Column(scale=5, min_width=200):
                output_gallery = make_output_images_section()
                add_download_button(output_gallery, "ncut_embed")
                norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
                add_download_button(norm_gallery, "eig_norm")
                cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
                add_download_button(cluster_gallery, "clusters")
                # The unpacking target list continues on the next line of the file.
                [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section() num_eig_slider.value = 100 false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False) no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) submit_button.click( partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True), inputs=[ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown ], outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text], scroll_to_output=True, ) with gr.Tab('NCut'): gr.Markdown('#### NCut (Legacy), not aligned, no Nyström approximation') gr.Markdown('Each image is solved independently, color is not aligned across images') gr.Markdown('---') gr.Markdown('

NCut vs. AlignedCut

') with gr.Row(): with gr.Column(scale=5, min_width=200): gr.Markdown('#### Pros') gr.Markdown('- Easy Solution. Use less eigenvectors.') gr.Markdown('- Exact solution. No Nyström approximation.') with gr.Column(scale=5, min_width=200): gr.Markdown('#### Cons') gr.Markdown('- Not aligned. Distance is not preserved across images. No pseudo-labeling or correspondence.') gr.Markdown('- Poor complexity scaling. Unable to handle large number of pixels.') gr.Markdown('---') with gr.Row(): with gr.Column(scale=5, min_width=200): gr.Markdown(' ') with gr.Column(scale=5, min_width=200): gr.Markdown('color is not aligned across images 👇') with gr.Row(): with gr.Column(scale=5, min_width=200): input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section() with gr.Column(scale=5, min_width=200): output_gallery = make_output_images_section() [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section() old_school_ncut_checkbox = gr.Checkbox(label="Old school NCut", value=True, elem_id="old_school_ncut") invisible_list = [old_school_ncut_checkbox, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, num_sample_tsne_slider, knn_tsne_slider, sampling_method_dropdown, ncut_metric_dropdown] for item in invisible_list: item.visible = False # logging text box logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information") false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", 
visible=False)
                # NOTE(review): the original line breaks and indentation of this section
                # were lost; the nesting below is reconstructed from syntax alone.
                # The first line above closes a call that starts on the previous line.
                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)

                # Same input order as the AlignedCut tabs, plus the "Old school NCut" checkbox at the end.
                submit_button.click(
                    run_fn,
                    inputs=[ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, old_school_ncut_checkbox ],
                    outputs=[output_gallery, logging_text],
                    api_name="API_NCut",
                )

    # 'RecursiveCut' tab: one input column and one output gallery per recursion level,
    # followed by a row holding the recursion config and the shared parameters.
    with gr.Tab('RecursiveCut'):
        gr.Markdown('NCUT can be applied recursively, the eigenvectors from previous iteration is the input for the next iteration NCUT. ')
        gr.Markdown('__Recursive NCUT__ can amplify or weaken the connections, depending on the `affinity_focal_gamma` setting, please see [Documentation](https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/#recursive-ncut)')

        gr.Markdown('---')

        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section()
                num_images_slider.value = 100
                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('### Output (Recursion #1)')
                l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
                add_rotate_flip_buttons(l1_gallery)
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('### Output (Recursion #2)')
                l2_gallery = gr.Gallery(format='png', value=[],
label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
                add_rotate_flip_buttons(l2_gallery)
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('### Output (Recursion #3)')
                l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
                add_rotate_flip_buttons(l3_gallery)
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                with gr.Accordion("➡️ Recursion config", open=True):
                    l1_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #1: N eigenvectors", value=100, elem_id="l1_num_eig")
                    l2_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #2: N eigenvectors", value=50, elem_id="l2_num_eig")
                    l3_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #3: N eigenvectors", value=50, elem_id="l3_num_eig")
                    metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="Recursion distance metric", value="cosine", elem_id="recursion_metric")
                    l1_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #1: Affinity focal gamma", value=0.7, elem_id="recursion_l1_gamma")
                    l2_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #2: Affinity focal gamma", value=0.7, elem_id="recursion_l2_gamma")
                    l3_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #3: Affinity focal gamma", value=0.5, elem_id="recursion_l3_gamma")
            with gr.Column(scale=5, min_width=200):
                [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown,
positive_prompt, negative_prompt ] = make_parameters_section()
                # The per-level sliders above replace these two shared controls in this tab.
                num_eig_slider.visible = False
                affinity_focal_gamma_slider.visible = False
                # Hidden constant-valued components passed as fixed inputs to run_fn.
                true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder")
                true_placeholder.visible = False
                false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder")
                false_placeholder.visible = False
                number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder")
                number_placeholder.visible = False
                no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False)

                # n_ret=3: one gallery per recursion level.
                submit_button.click(
                    partial(run_fn, n_ret=3),
                    inputs=[ input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, false_placeholder, number_placeholder, true_placeholder, l2_num_eig_slider, l3_num_eig_slider, metric_dropdown, l1_affinity_focal_gamma_slider, l2_affinity_focal_gamma_slider, l3_affinity_focal_gamma_slider ],
                    outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text],
                    api_name="API_RecursiveCut"
                )

    # 'RecursiveCut (Advanced)' tab (hidden by default).
    with gr.Tab('RecursiveCut (Advanced)', visible=False) as tab_recursivecut_advanced:
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True)
                num_images_slider.value = 100
                logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", lines=20)
            # The argument list of this `with` header continues on the next line of the file.
            with gr.Column(scale=5,
min_width=200):
                # NOTE(review): the original line breaks and indentation of this section
                # were lost; the nesting below is reconstructed from syntax alone.
                # The first line above finishes a `with gr.Column(` header that starts on
                # the previous line.

                # Recursion #1: embedding, eigenvector-magnitude and cluster galleries,
                # each with a download button.
                gr.Markdown('### Output (Recursion #1)')
                l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
                add_rotate_flip_buttons(l1_gallery)
                add_download_button(l1_gallery, "ncut_embed_recur1")
                l1_norm_gallery = gr.Gallery(value=[], label="Recursion #1 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
                add_download_button(l1_norm_gallery, "eig_norm_recur1")
                l1_cluster_gallery = gr.Gallery(value=[], label="Recursion #1 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
                add_download_button(l1_cluster_gallery, "clusters_recur1")
            # Recursion #2: same three galleries.
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('### Output (Recursion #2)')
                l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
                add_rotate_flip_buttons(l2_gallery)
                add_download_button(l2_gallery, "ncut_embed_recur2")
                l2_norm_gallery = gr.Gallery(value=[], label="Recursion #2 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
                add_download_button(l2_norm_gallery, "eig_norm_recur2")
                l2_cluster_gallery = gr.Gallery(value=[], label="Recursion #2 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
                add_download_button(l2_cluster_gallery, "clusters_recur2")
            # Recursion #3: same three galleries.
            with gr.Column(scale=5, min_width=200):
                gr.Markdown('### Output (Recursion #3)')
                l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False)
                add_rotate_flip_buttons(l3_gallery)
                add_download_button(l3_gallery, "ncut_embed_recur3")
                l3_norm_gallery = gr.Gallery(value=[], label="Recursion #3 Eigenvector Magnitude", show_label=True, elem_id="eig_norm", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False)
                add_download_button(l3_norm_gallery, "eig_norm_recur3")
                l3_cluster_gallery = gr.Gallery(value=[], label="Recursion #3 Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[4], object_fit="contain", height='450px', show_share_button=True, preview=False, interactive=False)
                add_download_button(l3_cluster_gallery, "clusters_recur3")

        # Second row: per-level recursion settings on the left, shared parameters on the right.
        with gr.Row():
            with gr.Column(scale=5, min_width=200):
                with gr.Accordion("➡️ Recursion config", open=True):
                    l1_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #1: N eigenvectors", value=100, elem_id="l1_num_eig")
                    l2_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #2: N eigenvectors", value=50, elem_id="l2_num_eig")
                    l3_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #3: N eigenvectors", value=50, elem_id="l3_num_eig")
                    metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="Recursion distance metric", value="cosine", elem_id="recursion_metric")
                    l1_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #1: Affinity focal gamma", value=0.7, elem_id="recursion_l1_gamma")
                    l2_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #2: Affinity focal gamma", value=0.7, elem_id="recursion_l2_gamma")
                    l3_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #3: Affinity focal gamma", value=0.5, elem_id="recursion_l3_gamma")
            with gr.Column(scale=5, min_width=200):
                # The unpacking target list continues on the next line of the file.
                [ model_dropdown, layer_slider,
node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section() num_eig_slider.visible = False affinity_focal_gamma_slider.visible = False true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder") true_placeholder.visible = False false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder") false_placeholder.visible = False number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder") number_placeholder.visible = False no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) submit_button.click( partial(run_fn, n_ret=9, advanced=True), inputs=[ input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, false_placeholder, number_placeholder, true_placeholder, l2_num_eig_slider, l3_num_eig_slider, metric_dropdown, l1_affinity_focal_gamma_slider, l2_affinity_focal_gamma_slider, l3_affinity_focal_gamma_slider ], outputs=[l1_gallery, l2_gallery, l3_gallery, l1_norm_gallery, l2_norm_gallery, l3_norm_gallery, l1_cluster_gallery, l2_cluster_gallery, l3_cluster_gallery, logging_text], ) with gr.Tab('Video', visible=True) as tab_video: with gr.Row(): 
with gr.Column(scale=5, min_width=200): video_input_gallery, submit_button, clear_video_button, max_frame_number = make_input_video_section() with gr.Column(scale=5, min_width=200): video_output_gallery = gr.Video(value=None, label="NCUT Embedding", elem_id="ncut", height="auto", show_share_button=False) [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section() num_sample_tsne_slider.value = 1000 perplexity_slider.value = 500 n_neighbors_slider.value = 500 knn_tsne_slider.value = 20 # logging text box logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information") clear_video_button.click(lambda x: (None, None), outputs=[video_input_gallery, video_output_gallery]) place_holder_false = gr.Checkbox(label="Place holder", value=False, elem_id="place_holder_false") place_holder_false.visible = False false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False) no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) submit_button.click( run_fn, inputs=[ video_input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, 
ncut_metric_dropdown, place_holder_false, max_frame_number ], outputs=[video_output_gallery, logging_text], api_name="API_VideoCut", ) with gr.Tab('Text'): try: from app_text import make_demo except ImportError: print("Debugging") from draft_gradio_app_text import make_demo make_demo() with gr.Tab('Vision-Language', visible=False) as tab_lisa: gr.Markdown('[LISA](https://arxiv.org/pdf/2308.00692) is a vision-language model. Input a text prompt and image, LISA generate segmentation masks.') gr.Markdown('In the mask decoder layers, LISA updates the image features w.r.t. the text prompt') gr.Markdown('This page aims to see how the text prompt affects the image features') gr.Markdown('---') gr.Markdown('

Color is aligned across 3 prompts. NCUT is computed on the concatenated features from 3 prompts.

') with gr.Row(): with gr.Column(scale=5, min_width=200): gr.Markdown('### Output (Prompt #1)') l1_gallery = gr.Gallery(format='png', value=[], label="Prompt #1", show_label=False, elem_id="ncut_p1", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False) prompt1 = gr.Textbox(label="Input Prompt #1", elem_id="prompt1", value="where is the person, include the clothes, don't include the guitar and chair", lines=3) with gr.Column(scale=5, min_width=200): gr.Markdown('### Output (Prompt #2)') l2_gallery = gr.Gallery(format='png', value=[], label="Prompt #2", show_label=False, elem_id="ncut_p2", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False) prompt2 = gr.Textbox(label="Input Prompt #2", elem_id="prompt2", value="where is the Gibson Les Pual guitar", lines=3) with gr.Column(scale=5, min_width=200): gr.Markdown('### Output (Prompt #3)') l3_gallery = gr.Gallery(format='png', value=[], label="Prompt #3", show_label=False, elem_id="ncut_p3", columns=[3], rows=[5], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False) prompt3 = gr.Textbox(label="Input Prompt #3", elem_id="prompt3", value="where is the floor", lines=3) with gr.Row(): with gr.Column(scale=5, min_width=200): input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section() with gr.Column(scale=5, min_width=200): [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section(is_lisa=True) logging_text = 
gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information") galleries = [l1_gallery, l2_gallery, l3_gallery] true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder", visible=False) submit_button.click( partial(run_fn, n_ret=len(galleries)), inputs=[ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, true_placeholder, prompt1, prompt2, prompt3, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown ], outputs=galleries + [logging_text], ) with gr.Tab('Model Aligned', visible=False) as tab_aligned: gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)') gr.Markdown('---') gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.') gr.Markdown('NCUT is computed on the concatenated graph of all models, layers, and images. Color is **aligned** across all models and layers.') gr.Markdown('') gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? 
Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn") gr.Markdown('---') with gr.Row(): with gr.Column(scale=5, min_width=200): input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section() num_images_slider.value = 100 with gr.Column(scale=5, min_width=200): output_gallery = make_output_images_section() gr.Markdown('### TIP1: use the `full-screen` button, and use `arrow keys` to navigate') gr.Markdown('---') gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)') gr.Markdown('Layer type: attention output (attn), without sum of residual') gr.Markdown('### TIP2: for large image set, please increase the `num_sample` for t-SNE and NCUT') gr.Markdown('---') [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section(model_ratio=False) model_dropdown.value = "AlignedThreeModelAttnNodes" model_dropdown.visible = False layer_slider.visible = False node_type_dropdown.visible = False num_sample_ncut_slider.value = 10000 num_sample_tsne_slider.value = 1000 # logging text box logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information") false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False) no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) submit_button.click( run_fn, inputs=[ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, 
positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown ], # outputs=galleries + [logging_text], outputs=[output_gallery, logging_text], ) with gr.Tab('Model Aligned (Advanced)', visible=False) as tab_model_aligned_advanced: gr.Markdown('This page reproduce the results from the paper [AlignedCut](https://arxiv.org/abs/2406.18344)') gr.Markdown('---') gr.Markdown('**Features are aligned across models and layers.** A linear alignment transform is trained for each model/layer, learning signal comes from 1) fMRI brain activation and 2) segmentation preserving eigen-constraints.') gr.Markdown('NCUT is computed on the concatenated graph of all models, layers, and images. Color is **aligned** across all models and layers.') gr.Markdown('') gr.Markdown("To see a good pattern, you will need to load 100~1000 images. 100 images need 10sec for RTX4090. Running out of HuggingFace GPU Quota? 
Try [Demo](https://ncut-pytorch.readthedocs.io/en/latest/demo/) hosted at UPenn") gr.Markdown('---') gr.Markdown('### Output (Recursion #1)') l1_gallery = gr.Gallery(format='png', value=[], label="Recursion #1", show_label=True, elem_id="ncut_l1", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True) add_rotate_flip_buttons(l1_gallery) add_download_button(l1_gallery, "modelaligned_recur1") gr.Markdown('### Output (Recursion #2)') l2_gallery = gr.Gallery(format='png', value=[], label="Recursion #2", show_label=True, elem_id="ncut_l2", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True) add_rotate_flip_buttons(l2_gallery) add_download_button(l2_gallery, "modelaligned_recur2") gr.Markdown('### Output (Recursion #3)') l3_gallery = gr.Gallery(format='png', value=[], label="Recursion #3", show_label=True, elem_id="ncut_l3", columns=[100], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False, preview=True) add_rotate_flip_buttons(l3_gallery) add_download_button(l3_gallery, "modelaligned_recur3") with gr.Row(): with gr.Column(scale=5, min_width=200): input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True) num_images_slider.value = 100 with gr.Column(scale=5, min_width=200): with gr.Accordion("➡️ Recursion config", open=True): l1_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #1: N eigenvectors", value=100, elem_id="l1_num_eig") l2_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #2: N eigenvectors", value=50, elem_id="l2_num_eig") l3_num_eig_slider = gr.Slider(1, 1000, step=1, label="Recursion #3: N eigenvectors", value=50, elem_id="l3_num_eig") metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="Recursion distance metric", 
value="cosine", elem_id="recursion_metric") l1_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #1: Affinity focal gamma", value=0.5, elem_id="recursion_l1_gamma") l2_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #2: Affinity focal gamma", value=0.5, elem_id="recursion_l2_gamma") l3_affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Recursion #3: Affinity focal gamma", value=0.5, elem_id="recursion_l3_gamma") gr.Markdown('---') gr.Markdown('Model: CLIP(ViT-B-16/openai), DiNOv2reg(dinov2_vitb14_reg), MAE(vit_base)') gr.Markdown('Layer type: attention output (attn), without sum of residual') [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section(model_ratio=False) num_eig_slider.visible = False affinity_focal_gamma_slider.visible = False model_dropdown.value = "AlignedThreeModelAttnNodes" model_dropdown.visible = False layer_slider.visible = False node_type_dropdown.visible = False num_sample_ncut_slider.value = 10000 num_sample_tsne_slider.value = 1000 # logging text box logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information") true_placeholder = gr.Checkbox(label="True placeholder", value=True, elem_id="true_placeholder") true_placeholder.visible = False false_placeholder = gr.Checkbox(label="False placeholder", value=False, elem_id="false_placeholder") false_placeholder.visible = False number_placeholder = gr.Number(0, label="Number placeholder", elem_id="number_placeholder") number_placeholder.visible = False no_prompt = 
gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) submit_button.click( partial(run_fn, n_ret=3, advanced=True), inputs=[ input_gallery, model_dropdown, layer_slider, l1_num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, false_placeholder, number_placeholder, true_placeholder, l2_num_eig_slider, l3_num_eig_slider, metric_dropdown, l1_affinity_focal_gamma_slider, l2_affinity_focal_gamma_slider, l3_affinity_focal_gamma_slider ], outputs=[l1_gallery, l2_gallery, l3_gallery, logging_text], ) with gr.Tab('Compare Models'): def add_one_model(i_model=1): with gr.Column(scale=5, min_width=200) as col: gr.Markdown(f'### Output Images') output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False) submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary') add_rotate_flip_buttons(output_gallery) [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section() # logging text box logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", 
placeholder="Logging information") false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False) no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) submit_button.click( run_fn, inputs=[ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown ], outputs=[output_gallery, logging_text] ) return col with gr.Row(): with gr.Column(scale=5, min_width=200): input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section() submit_button.visible = False for i in range(3): add_one_model() # Create rows and buttons in a loop rows = [] buttons = [] for i in range(4): row = gr.Row(visible=False) rows.append(row) with row: for j in range(4): with gr.Column(scale=5, min_width=200): add_one_model() button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3) buttons.append(button) if i > 0: # Reveal the current row and next button buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=row) buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=button) # Hide the current button buttons[i - 1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[i - 1]) # Last button only reveals the last row and hides itself buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1]) buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1]) with gr.Tab('Compare Models (Advanced)', 
visible=False) as tab_compare_models_advanced: target_images = gr.State([]) input_images = gr.State([]) def add_mlp_fitting_buttons(output_gallery, mlp_gallery, target_images=target_images, input_images=input_images): with gr.Row(): # mark_as_target_button = gr.Button("mark target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary') # mark_as_input_button = gr.Button("mark input", elem_id=f"mark_as_input_button_{output_gallery.elem_id}", variant='secondary') mark_as_target_button = gr.Button("🎯 Mark Target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary') fit_to_target_button = gr.Button("🔴 [MLP] Fit", elem_id=f"fit_to_target_button_{output_gallery.elem_id}", variant='primary') def mark_fn(images, text="target"): if images is None: raise gr.Error("No images selected") if len(images) == 0: raise gr.Error("No images selected") num_images = len(images) gr.Info(f"Marked {num_images} images as {text}") images = [(Image.open(tup[0]), []) for tup in images] return images mark_as_target_button.click(partial(mark_fn, text="target"), inputs=[output_gallery], outputs=[target_images]) # mark_as_input_button.click(partial(mark_fn, text="input"), inputs=[output_gallery], outputs=[input_images]) with gr.Accordion("➡️ MLP Parameters", open=False): num_layers_slider = gr.Slider(2, 10, step=1, label="Number of Layers", value=3, elem_id=f"num_layers_slider_{output_gallery.elem_id}") width_slider = gr.Slider(128, 4096, step=128, label="Width", value=512, elem_id=f"width_slider_{output_gallery.elem_id}") batch_size_slider = gr.Slider(32, 4096, step=32, label="Batch Size", value=128, elem_id=f"batch_size_slider_{output_gallery.elem_id}") lr_slider = gr.Slider(1e-6, 1, step=1e-6, label="Learning Rate", value=3e-4, elem_id=f"lr_slider_{output_gallery.elem_id}") fitting_steps_slider = gr.Slider(1000, 100000, step=1000, label="Fitting Steps", value=30000, elem_id=f"fitting_steps_slider_{output_gallery.elem_id}") fps_sample_slider 
= gr.Slider(128, 50000, step=128, label="FPS Sample", value=10240, elem_id=f"fps_sample_slider_{output_gallery.elem_id}") segmentation_loss_lambda_slider = gr.Slider(0, 100, step=0.01, label="Segmentation Preserving Loss Lambda", value=1, elem_id=f"segmentation_loss_lambda_slider_{output_gallery.elem_id}") fit_to_target_button.click( run_mlp_fit, inputs=[output_gallery, target_images, num_layers_slider, width_slider, batch_size_slider, lr_slider, fitting_steps_slider, fps_sample_slider, segmentation_loss_lambda_slider], outputs=[mlp_gallery], ) def add_one_model(i_model=1): with gr.Column(scale=5, min_width=200) as col: gr.Markdown(f'### Output Images') output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False) submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary') add_rotate_flip_buttons(output_gallery) add_download_button(output_gallery, f"ncut_embed") mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True, elem_id=f"mlp{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_fullscreen_button=True, interactive=False) add_mlp_fitting_buttons(output_gallery, mlp_gallery) add_download_button(mlp_gallery, f"mlp_color_align") norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True, elem_id=f"eig_norm{i_model}", columns=[3], rows=[1], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False) add_download_button(norm_gallery, f"eig_norm") cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id=f"clusters{i_model}", columns=[2], rows=[4], object_fit="contain", height="450px", show_share_button=True, preview=False, interactive=False) add_download_button(cluster_gallery, f"clusters") [ model_dropdown, layer_slider, 
node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section() # logging text box logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information") false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False) no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) submit_button.click( partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True, advanced=True), inputs=[ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown ], outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text] ) output_gallery.change(lambda x: gr.update(value=x), inputs=[output_gallery], outputs=[mlp_gallery]) return output_gallery galleries = [] with gr.Row(): with gr.Column(scale=5, min_width=200): input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(allow_download=True) submit_button.visible = False for i in range(3): g = add_one_model() galleries.append(g) # Create rows and buttons in a loop rows = [] buttons = [] for i in range(4): 
row = gr.Row(visible=False) rows.append(row) with row: for j in range(4): with gr.Column(scale=5, min_width=200): g = add_one_model() galleries.append(g) button = gr.Button("➕ Add Compare", elem_id=f"add_button_{i}", visible=False if i > 0 else True, scale=3) buttons.append(button) if i > 0: # Reveal the current row and next button buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=row) buttons[i - 1].click(fn=lambda x: gr.update(visible=True), outputs=button) # Hide the current button buttons[i - 1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[i - 1]) # Last button only reveals the last row and hides itself buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1]) buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1]) with gr.Tab('Directed (dev)', visible=False) as tab_directed_ncut: target_images = gr.State([]) input_images = gr.State([]) def add_mlp_fitting_buttons(output_gallery, mlp_gallery, target_images=target_images, input_images=input_images): with gr.Row(): # mark_as_target_button = gr.Button("mark target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary') # mark_as_input_button = gr.Button("mark input", elem_id=f"mark_as_input_button_{output_gallery.elem_id}", variant='secondary') mark_as_target_button = gr.Button("🎯 Mark Target", elem_id=f"mark_as_target_button_{output_gallery.elem_id}", variant='secondary') fit_to_target_button = gr.Button("🔴 [MLP] Fit", elem_id=f"fit_to_target_button_{output_gallery.elem_id}", variant='primary') def mark_fn(images, text="target"): if images is None: raise gr.Error("No images selected") if len(images) == 0: raise gr.Error("No images selected") num_images = len(images) gr.Info(f"Marked {num_images} images as {text}") images = [(Image.open(tup[0]), []) for tup in images] return images mark_as_target_button.click(partial(mark_fn, text="target"), inputs=[output_gallery], outputs=[target_images]) # 
mark_as_input_button.click(partial(mark_fn, text="input"), inputs=[output_gallery], outputs=[input_images]) with gr.Accordion("➡️ MLP Parameters", open=False): num_layers_slider = gr.Slider(2, 10, step=1, label="Number of Layers", value=3, elem_id=f"num_layers_slider_{output_gallery.elem_id}") width_slider = gr.Slider(128, 4096, step=128, label="Width", value=512, elem_id=f"width_slider_{output_gallery.elem_id}") batch_size_slider = gr.Slider(32, 4096, step=32, label="Batch Size", value=128, elem_id=f"batch_size_slider_{output_gallery.elem_id}") lr_slider = gr.Slider(1e-6, 1, step=1e-6, label="Learning Rate", value=3e-4, elem_id=f"lr_slider_{output_gallery.elem_id}") fitting_steps_slider = gr.Slider(1000, 100000, step=1000, label="Fitting Steps", value=30000, elem_id=f"fitting_steps_slider_{output_gallery.elem_id}") fps_sample_slider = gr.Slider(128, 50000, step=128, label="FPS Sample", value=10240, elem_id=f"fps_sample_slider_{output_gallery.elem_id}") segmentation_loss_lambda_slider = gr.Slider(0, 100, step=0.01, label="Segmentation Preserving Loss Lambda", value=1, elem_id=f"segmentation_loss_lambda_slider_{output_gallery.elem_id}") fit_to_target_button.click( run_mlp_fit, inputs=[output_gallery, target_images, num_layers_slider, width_slider, batch_size_slider, lr_slider, fitting_steps_slider, fps_sample_slider, segmentation_loss_lambda_slider], outputs=[mlp_gallery], ) def make_parameters_section_2model(model_ratio=True): gr.Markdown("### Parameters Help") from ncut_pytorch.backbone import list_models, get_demo_model_names model_names = list_models() model_names = sorted(model_names) # only CLIP DINO MAE is implemented for q k v ok_models = ["CLIP(ViT", "DiNO(", "MAE("] model_names = [m for m in model_names if any(ok in m for ok in ok_models)] def get_filtered_model_names(name): return [m for m in model_names if name.lower() in m.lower()] def get_default_model_name(name): lst = get_filtered_model_names(name) if len(lst) > 1: return lst[1] return lst[0] 
model_radio = gr.Radio(["CLIP", "DiNO", "MAE"], label="Backbone", value="DiNO", elem_id="model_radio", show_label=True, visible=model_ratio) model_dropdown = gr.Dropdown(get_filtered_model_names("DiNO"), label="", value="DiNO(dino_vitb8_448)", elem_id="model_name", show_label=False) model_radio.change(fn=lambda x: gr.update(choices=get_filtered_model_names(x), value=get_default_model_name(x)), inputs=model_radio, outputs=[model_dropdown]) layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer") positive_prompt = gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'") positive_prompt.visible = False negative_prompt = gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'") negative_prompt.visible = False node_type_dropdown = gr.Dropdown(['q', 'k', 'v'], label="Left-side Node Type", value="q", elem_id="node_type", info="In directed case, left-side SVD eigenvector is taken") node_type_dropdown2 = gr.Dropdown(['q', 'k', 'v'], label="Right-side Node Type", value="k", elem_id="node_type2") head_index_text = gr.Textbox(value='all', label="Head Index", elem_id="head_index", type="text", info="which attention heads to use, comma separated, e.g. 
0,1,2") make_symmetric = gr.Checkbox(label="Make Symmetric", value=False, elem_id="make_symmetric", info="make the graph symmetric by A = (A + A.T) / 2") num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for smaller clusters') def change_layer_slider(model_name): # SD2, UNET if "stable" in model_name.lower() and "diffusion" in model_name.lower(): from ncut_pytorch.backbone import SD_KEY_DICT default_layer = 'up_2_resnets_1_block' if 'diffusion-3' not in model_name else 'block_23' return (gr.Slider(1, 49, step=1, label="Diffusion: Timestep (Noise)", value=5, elem_id="layer", visible=True, info="Noise level, 50 is max noise"), gr.Dropdown(SD_KEY_DICT[model_name], label="Diffusion: Layer and Node", value=default_layer, elem_id="node_type", info="U-Net (v1, v2) or DiT (v3)")) if model_name == "LISSL(xinlai/LISSL-7B-v1)": layer_names = ["dec_0_input", "dec_0_attn", "dec_0_block", "dec_1_input", "dec_1_attn", "dec_1_block"] default_layer = "dec_1_block" return (gr.Slider(1, 6, step=1, label="LISA decoder: Layer index", value=6, elem_id="layer", visible=False, info=""), gr.Dropdown(layer_names, label="LISA decoder: Layer and Node", value=default_layer, elem_id="node_type")) layer_dict = LAYER_DICT if model_name in layer_dict: value = layer_dict[model_name] return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="") else: value = 12 return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True, info="") model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=layer_slider) def change_prompt_text(model_name): if model_name in promptable_diffusion_models: return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=True), gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 
'a photo from egocentric view'", visible=True)) return (gr.Textbox(label="Prompt (Positive)", elem_id="prompt", placeholder="e.g. 'a photo of Gibson Les Pual guitar'", visible=False), gr.Textbox(label="Prompt (Negative)", elem_id="prompt", placeholder="e.g. 'a photo from egocentric view'", visible=False)) model_dropdown.change(fn=change_prompt_text, inputs=model_dropdown, outputs=[positive_prompt, negative_prompt]) with gr.Accordion("Advanced Parameters: NCUT", open=False): gr.Markdown("Docs: How to Get Better Segmentation") affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for shaper segmentation") num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation") # sampling_method_dropdown = gr.Dropdown(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method", info="Nyström approximation") sampling_method_dropdown = gr.Radio(["QuickFPS", "random"], label="NCUT: Sampling method", value="QuickFPS", elem_id="sampling_method") # ncut_metric_dropdown = gr.Dropdown(["euclidean", "cosine"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric") ncut_metric_dropdown = gr.Radio(["euclidean", "cosine", "rbf"], label="NCUT: Distance metric", value="cosine", elem_id="ncut_metric") ncut_knn_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation") ncut_indirect_connection = gr.Checkbox(label="indirect_connection", value=False, elem_id="ncut_indirect_connection", info="TODO: Indirect connection is not implemented for directed NCUT", interactive=False) ncut_make_orthogonal = gr.Checkbox(label="make_orthogonal", value=False, elem_id="ncut_make_orthogonal", info="Apply post-hoc eigenvectors orthogonalization") with gr.Accordion("Advanced Parameters: Visualization", open=False): # 
def add_one_model(i_model=1):
    """Build one model column: galleries, parameter widgets and RUN wiring.

    Args:
        i_model: Suffix appended to the ``elem_id`` of the galleries and of
            the RUN button created here.

    Returns:
        The "NCUT Embedding" output gallery of the new column.
    """
    # NOTE(review): the original indentation was lost; everything below is
    # assumed to live inside the column -- confirm against the running UI.
    with gr.Column(scale=5, min_width=200):
        gr.Markdown('### Output Images')
        output_gallery = gr.Gallery(format='png', value=[], label="NCUT Embedding", show_label=True,
                                    elem_id=f"ncut{i_model}", columns=[3], rows=[1],
                                    object_fit="contain", height="450px",
                                    show_fullscreen_button=True, interactive=False)
        submit_button = gr.Button("🔴 RUN", elem_id=f"submit_button{i_model}", variant='primary')
        add_rotate_flip_buttons(output_gallery)
        add_download_button(output_gallery, "ncut_embed")

        mlp_gallery = gr.Gallery(format='png', value=[], label="MLP color align", show_label=True,
                                 elem_id=f"mlp{i_model}", columns=[3], rows=[1],
                                 object_fit="contain", height="450px",
                                 show_fullscreen_button=True, interactive=False)
        add_mlp_fitting_buttons(output_gallery, mlp_gallery)
        add_download_button(mlp_gallery, "mlp_color_align")

        norm_gallery = gr.Gallery(value=[], label="Eigenvector Magnitude", show_label=True,
                                  elem_id=f"eig_norm{i_model}", columns=[3], rows=[1],
                                  object_fit="contain", height="450px",
                                  show_share_button=True, preview=False, interactive=False)
        add_download_button(norm_gallery, "eig_norm")

        cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True,
                                     elem_id=f"clusters{i_model}", columns=[2], rows=[4],
                                     object_fit="contain", height="450px",
                                     show_share_button=True, preview=False, interactive=False)
        add_download_button(cluster_gallery, "clusters")

        (
            model_dropdown, layer_slider, node_type_dropdown, node_type_dropdown2, head_index_text,
            make_symmetric, num_eig_slider,
            affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider,
            ncut_indirect_connection, ncut_make_orthogonal,
            embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider,
            knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider,
            sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt,
        ) = make_parameters_section_2model()

        # logging text box
        logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging",
                                  type="text", placeholder="Logging information")

        # Hidden constant inputs that fill run_fn's positional slots this
        # column does not expose. The checkbox used to be created twice with
        # the same elem_id; only one instance was ever wired, so one suffices.
        false_placeholder = gr.Checkbox(label="False", value=False,
                                        elem_id="false_placeholder", visible=False)
        no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text",
                               placeholder="", visible=False)

        # `input_gallery` is taken from the enclosing scope and must exist by
        # the time this function is called. The order of `inputs` mirrors
        # run_fn's positional parameters and must not be changed.
        submit_button.click(
            partial(run_fn, n_ret=3, plot_clusters=True, alignedcut_eig_norm_plot=True,
                    advanced=True, directed=True),
            inputs=[
                input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
                positive_prompt, negative_prompt,
                false_placeholder, no_prompt, no_prompt, no_prompt,
                affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider,
                ncut_indirect_connection, ncut_make_orthogonal,
                embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider,
                knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider,
                sampling_method_dropdown, ncut_metric_dropdown,
                *[false_placeholder for _ in range(9)],
                node_type_dropdown2, head_index_text, make_symmetric
            ],
            outputs=[output_gallery, cluster_gallery, norm_gallery, logging_text]
        )

        # Mirror every new NCUT result into the MLP gallery.
        output_gallery.change(lambda x: gr.update(value=x), inputs=[output_gallery],
                              outputs=[mlp_gallery])

    return output_gallery
outputs=buttons[i - 1]) # Last button only reveals the last row and hides itself buttons[-1].click(fn=lambda x: gr.update(visible=True), outputs=rows[-1]) buttons[-1].click(fn=lambda x: gr.update(visible=False), outputs=buttons[-1]) with gr.Tab('Application'): gr.Markdown("Draw some points on the image to find corrsponding segments in other images. E.g. click on one face to segment all the face. [Video Tutorial](https://ncut-pytorch.readthedocs.io/en/latest/gallery_application/)") with gr.Row(): with gr.Column(scale=5, min_width=200): gr.Markdown("### Step 0: Load Images") input_gallery, submit_button, clear_images_button, dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_input_images_section(markdown=False) submit_button.visible = False num_images_slider.value = 30 logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False) with gr.Column(scale=5, min_width=200): gr.Markdown("### Step 1: NCUT Embedding") output_gallery = make_output_images_section(markdown=False, button=False) submit_button = gr.Button("🔴 RUN", elem_id="submit_button", variant='primary') add_rotate_flip_buttons(output_gallery) [ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt ] = make_parameters_section() false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False) no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) submit_button.click( partial(run_fn, n_ret=1), inputs=[ input_gallery, 
model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, positive_prompt, negative_prompt, false_placeholder, no_prompt, no_prompt, no_prompt, affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown ], outputs=[output_gallery, logging_text], ) with gr.Column(scale=5, min_width=200): gr.Markdown("### Step 2a: Pick an Image") from gradio_image_prompter import ImagePrompter image_type_radio = gr.Radio(["Original", "NCUT"], label="Image Display Type", value="Original", elem_id="image_type_radio") with gr.Row(): image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True) image2_slider = gr.Slider(0, 100, step=1, label="Image#2 Index", value=1, elem_id="image2_slider", interactive=True) image3_slider = gr.Slider(0, 100, step=1, label="Image#3 Index", value=2, elem_id="image3_slider", interactive=True) load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary') gr.Markdown("### Step 2b: Draw Points") gr.Markdown("""
🖱️ Left Click: Foreground
🖱️ Middle Click: Background

Top Right Buttons:
: Remove Last Point
: Clear All Points
def update_prompt_image(original_images, ncut_images, image_type, index):
    """Load one gallery image into an interactive ImagePrompter.

    Args:
        original_images: Value of the input gallery (items are indexed with
            ``[0]`` to get the image).
        ncut_images: Value of the NCUT output gallery, same layout.
        image_type: ``"Original"`` selects ``original_images``; any other
            value selects ``ncut_images``.
        index: Requested image index; values past the end are clamped to the
            last image.

    Returns:
        A fresh ``ImagePrompter`` with the chosen image and no points, or
        ``None`` when the selected gallery is ``None`` or empty.
    """
    images = original_images if image_type == "Original" else ncut_images
    if images is None or len(images) == 0:
        return None
    last_index = len(images) - 1
    picked = last_index if index >= len(images) else index
    return ImagePrompter(value={'image': images[picked][0], 'points': []}, interactive=True)
def relative_xy(prompts):
    """Convert ImagePrompter point prompts to image-relative coordinates.

    Args:
        prompts: ImagePrompter value: a dict with an ``'image'`` entry (file
            path or array-like image) and a ``'points'`` entry (rows of
            numbers). ``None``, an empty dict, or a missing image / points
            entry is treated as "no points".

    Returns:
        ``(xy, is_positive)`` where ``xy`` is a list of ``(x / width,
        y / height)`` tuples and ``is_positive`` flags the rows whose third
        column equals ``1.0``. Both are empty lists when there is nothing to
        convert.
    """
    # An ImagePrompter that was never loaded hands over no value at all.
    if not prompts:
        return [], []
    image = prompts.get('image')
    raw_points = prompts.get('points')
    if image is None or raw_points is None:
        return [], []
    points = np.asarray(raw_points)
    if points.shape[0] == 0:
        return [], []
    # Keep only rows whose sixth column is 4.0 (presumably the marker for
    # click points as opposed to boxes -- TODO confirm against ImagePrompter).
    points = points[points[:, 5] == 4.0]
    is_positive = points[:, 2] == 1.0
    xy = points[:, :2].tolist()
    if isinstance(image, str):
        image = Image.open(image)
    h, w = np.array(image).shape[:2]
    return [(x / w, y / h) for x, y in xy], is_positive


def xy_rgb(prompts, image_idx, ncut_images):
    """Read the NCUT colour under every prompt point of one image.

    Args:
        prompts: ImagePrompter value, see ``relative_xy``.
        image_idx: Index of the prompted image in ``ncut_images``; indices
            past the end are clamped to the last image, matching what
            ``update_prompt_image`` displayed to the user.
        ncut_images: List of images exposing ``width``, ``height`` and
            ``getpixel``.

    Returns:
        A list of ``(pixel_value, is_positive)`` tuples, empty when there are
        no points or no images.
    """
    xy, is_positive = relative_xy(prompts)
    # Check for points first so an unused slider position cannot raise.
    if len(xy) == 0 or len(ncut_images) == 0:
        return []
    image = ncut_images[min(image_idx, len(ncut_images) - 1)]
    rgbs = []
    for (x, y), positive in zip(xy, is_positive):
        # A click on the right/bottom edge maps to `width`/`height`, which is
        # one past the last valid pixel -- clamp into the image.
        px = min(max(int(x * image.width), 0), image.width - 1)
        py = min(max(int(y * image.height), 0), image.height - 1)
        rgbs.append((image.getpixel((px, py)), positive))
    return rgbs


def run_crop(original_images, ncut_images, prompts1, prompts2, prompts3,
             image_idx1, image_idx2, image_idx3,
             crop_expand, distance_threshold, distance_power, area_threshold, overlay_image,
             negative_distance_threshold, fg_contrast, bg_contrast):
    """Segment all images by NCUT-colour similarity to the prompt points.

    Every prompt point contributes the NCUT colour under it; pixels of all
    NCUT images are thresholded on their colour distance to that point.
    Foreground masks are OR-ed together, background masks are subtracted,
    small components are removed, and the surviving components are cropped
    from the original images.

    Args:
        original_images: Input gallery value (items indexed with ``[0]``).
        ncut_images: NCUT gallery value, same layout; all images are assumed
            to share one size and to have 3 channels.
        prompts1, prompts2, prompts3: ImagePrompter values.
        image_idx1, image_idx2, image_idx3: Gallery index of each prompter.
        crop_expand: Factor by which each crop box is enlarged.
        distance_threshold: Mask threshold for foreground points.
        distance_power: Unused; kept so the Gradio wiring stays valid.
        area_threshold: Unused; kept so the Gradio wiring stays valid.
        overlay_image: Blend the mask with the original image when true.
        negative_distance_threshold: Mask threshold for background points.
        fg_contrast: Exponent applied to foreground heatmaps.
        bg_contrast: Exponent applied to background heatmaps.

    Returns:
        ``(combined_masks, cropped_images)``: two lists of PIL images. Both
        are empty when there are no NCUT images yet.

    Raises:
        gr.Error: If no foreground point was drawn.
    """
    # Two galleries are wired as outputs, so always hand back two values.
    if not ncut_images:
        return [], []
    ncut_images = [image[0] for image in ncut_images]
    if isinstance(ncut_images[0], str):
        ncut_images = [Image.open(image) for image in ncut_images]

    rgbs = (
        xy_rgb(prompts1, image_idx1, ncut_images)
        + xy_rgb(prompts2, image_idx2, ncut_images)
        + xy_rgb(prompts3, image_idx3, ncut_images)
    )

    # Flatten all images into one (num_pixels, 3) tensor scaled to [0, 1].
    ncut_images = [np.array(image).astype(np.float32) for image in ncut_images]
    ncut_pixels = [image.reshape(-1, 3) for image in ncut_images]
    h, w = ncut_images[0].shape[:2]
    ncut_pixels = torch.tensor(np.array(ncut_pixels).reshape(-1, 3)) / 255

    def to_mask(heatmap, threshold, gamma):
        """Turn a flat distance vector into per-image boolean masks."""
        heatmap = (heatmap - heatmap.mean()) / heatmap.std()
        heatmap = heatmap.double()
        heatmap = torch.exp(heatmap)
        # Invert so that small distances give large values.
        heatmap = 1 / heatmap ** gamma
        if heatmap.shape[0] > 10000:
            # Estimate the quantiles on a fixed 10000-element subsample.
            # NOTE: this reseeds NumPy's global RNG.
            np.random.seed(0)
            random_idx = np.random.choice(heatmap.shape[0], 10000, replace=False)
            sample = heatmap[random_idx]
            vmin, vmax = sample.quantile(0.01), sample.quantile(0.99)
        else:
            vmin, vmax = heatmap.quantile(0.01), heatmap.quantile(0.99)
        heatmap = (heatmap - vmin) / (vmax - vmin)
        heatmap = heatmap.reshape(len(ncut_images), h, w)
        return heatmap > threshold

    positive_masks, negative_masks = [], []
    for rgb, is_positive in rgbs:
        rgb = torch.tensor(rgb).float() / 255
        distance = (ncut_pixels - rgb[None]).norm(dim=-1)
        distance = distance.squeeze(-1)
        if is_positive:
            positive_masks.append(to_mask(distance, distance_threshold, fg_contrast))
        else:
            negative_masks.append(to_mask(distance, negative_distance_threshold, bg_contrast))
    if len(positive_masks) == 0:
        raise gr.Error("No prompt points. Please draw some points on the image.")
    positive_mask = torch.stack(positive_masks).any(dim=0)
    if len(negative_masks) > 0:
        negative_mask = torch.stack(negative_masks).any(dim=0)
        positive_mask = positive_mask & ~negative_mask

    # convert to uint8 masks (0 / 255), one per image
    mask = positive_mask.cpu().numpy()
    mask = mask.astype(np.uint8) * 255

    import cv2

    def get_bboxes_and_clean_mask(mask, min_area=500):
        """
        Args:
        - mask: A numpy image of a binary mask with 255 for the object and 0 for the background.
        - min_area: Minimum area for a connected component to be considered valid (default 500).

        Returns:
        - bounding_boxes: List of bounding boxes for valid objects (x, y, width, height).
        - cleaned_pil_mask: A Pillow image of the cleaned mask, with small components removed.
        """
        # Ensure the mask is binary (0 or 255)
        mask = np.where(mask > 127, 255, 0).astype(np.uint8)
        # Remove small noise using morphological operations (denoising)
        kernel = np.ones((5, 5), np.uint8)
        cleaned_mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        # Find connected components in the cleaned mask
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(cleaned_mask, connectivity=8)
        final_cleaned_mask = np.zeros_like(cleaned_mask)
        bounding_boxes = []
        for i in range(1, num_labels):  # Skip label 0 (background)
            x, y, w, h, area = stats[i]
            if area >= min_area:
                # Plain ints keep the later drawing / cropping arithmetic
                # independent of NumPy scalar types.
                bounding_boxes.append((int(x), int(y), int(w), int(h)))
                final_cleaned_mask[labels == i] = 255
        cleaned_pil_mask = Image.fromarray(final_cleaned_mask)
        return bounding_boxes, cleaned_pil_mask

    bboxs, filtered_masks = zip(*[get_bboxes_and_clean_mask(_mask) for _mask in mask])

    original_images = [image[0] for image in original_images]
    if isinstance(original_images[0], str):
        original_images = [Image.open(image) for image in original_images]

    # combine the masks, also draw the bounding boxes
    combined_masks = []
    for i_image in range(len(mask)):
        noisy_mask = np.array(mask[i_image])
        bbox = bboxs[i_image]
        clean_mask = np.array(filtered_masks[i_image])
        # Removed (noisy) regions stay faintly visible at 40% intensity.
        combined_mask = noisy_mask * 0.4 + clean_mask
        combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
        if overlay_image:
            # add empty red and green channel
            combined_mask = np.stack([np.zeros_like(combined_mask), np.zeros_like(combined_mask), combined_mask], axis=-1)
            _image = original_images[i_image].convert("RGB").resize((combined_mask.shape[1], combined_mask.shape[0]))
            _image = np.array(_image)
            combined_mask = 0.5 * combined_mask + 0.5 * _image
            combined_mask = np.clip(combined_mask, 0, 255).astype(np.uint8)
        for x, y, w, h in bbox:
            cv2.rectangle(combined_mask, (x - 1, y - 1), (x + w + 2, y + h + 2), (255, 0, 0), 2)
        combined_masks.append(Image.fromarray(combined_mask))

    def extend_the_mask(xywh, factor=1.5):
        """Scale a box by ``factor`` around its centre."""
        x, y, w, h = xywh
        x -= w * (factor - 1) / 2
        y -= h * (factor - 1) / 2
        w *= factor
        h *= factor
        return x, y, w, h

    def resize_the_mask(xywh, original_size, target_size):
        """Map a box from ``original_size`` to ``target_size`` (both ``(width, height)``)."""
        x, y, w, h = xywh
        x *= target_size[0] / original_size[0]
        y *= target_size[1] / original_size[1]
        w *= target_size[0] / original_size[0]
        h *= target_size[1] / original_size[1]
        return int(x), int(y), int(w), int(h)

    def crop_image(image, xywh, mask_w, mask_h, factor=1.0):
        """Crop the mask-space box ``xywh`` out of the full-size ``image``."""
        x, y, w, h = resize_the_mask(xywh, (mask_w, mask_h), image.size)
        _x, _y, _w, _h = extend_the_mask((x, y, w, h), factor=factor)
        return image.crop((_x, _y, _x + _w, _y + _h))

    # PIL's `size` is (width, height).
    mask_w, mask_h = filtered_masks[0].size
    cropped_images = []
    for _image, _bboxs in zip(original_images, bboxs):
        for _bbox in _bboxs:
            cropped_images.append(crop_image(_image, _bbox, mask_w, mask_h, factor=crop_expand))

    return combined_masks, cropped_images
submit_button.visible = False # num_images_slider.value = 30 # [ # model_dropdown, layer_slider, node_type_dropdown, num_eig_slider, # affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, # embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, # perplexity_slider, n_neighbors_slider, min_dist_slider, # sampling_method_dropdown, ncut_metric_dropdown, positive_prompt, negative_prompt # ] = make_parameters_section(ncut_parameter_dropdown=False) # num_eig_slider.value = 1000 # num_eig_slider.visible = False # logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False) # false_placeholder = gr.Checkbox(label="False", value=False, elem_id="false_placeholder", visible=False) # no_prompt = gr.Textbox("", label="", elem_id="empty_placeholder", type="text", placeholder="", visible=False) # submit_button.click( # partial(run_fn, n_ret=1, only_eigvecs=True), # inputs=[ # input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown, # positive_prompt, negative_prompt, # false_placeholder, no_prompt, no_prompt, no_prompt, # affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal, # embedding_method_dropdown, embedding_metric_dropdown, num_sample_tsne_slider, knn_tsne_slider, # perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown, ncut_metric_dropdown # ], # outputs=[eigvecs, logging_text], # ) # with gr.Column(scale=5, min_width=200): # gr.Markdown("### Step 2a: Pick an Image") # from gradio_image_prompter import ImagePrompter # with gr.Row(): # image1_slider = gr.Slider(0, 100, step=1, label="Image#1 Index", value=0, elem_id="image1_slider", interactive=True) # load_one_image_button = gr.Button("🔴 Load", elem_id="load_one_image_button", variant='primary') # 
gr.Markdown("### Step 2b: Draw a Point") # gr.Markdown(""" #
# 🖱️ Left Click: Foreground
#
# """) # prompt_image1 = ImagePrompter(show_label=False, elem_id="prompt_image1", interactive=False) # def update_prompt_image(original_images, index): # images = original_images # if images is None: # return # total_len = len(images) # if total_len == 0: # return # if index >= total_len: # index = total_len - 1 # return ImagePrompter(value={'image': images[index][0], 'points': []}, interactive=True) # # return gr.Image(value=images[index][0], elem_id=f"prompt_image{randint}", interactive=True) # load_one_image_button.click(update_prompt_image, inputs=[input_gallery, image1_slider], outputs=[prompt_image1]) # child_idx = gr.State([]) # current_idx = gr.State(None) # n_eig = gr.State(64) # with gr.Column(scale=5, min_width=200): # gr.Markdown("### Step 3: Check groupping") # child_distance_slider = gr.Slider(0, 0.5, step=0.001, label="Child Distance", value=0.1, elem_id="child_distance_slider", interactive=True) # overlay_image_checkbox = gr.Checkbox(label="Overlay Image", value=True, elem_id="overlay_image_checkbox", interactive=True) # run_button = gr.Button("🔴 RUN", elem_id="run_groupping", variant='primary') # parent_plot = gr.Gallery(value=None, label="Parent", show_label=True, elem_id="parent_plot", interactive=False, rows=[1], columns=[2]) # parent_button = gr.Button("Use Parent", elem_id="run_parent") # current_plot = gr.Gallery(value=None, label="Current", show_label=True, elem_id="current_plot", interactive=False, rows=[1], columns=[2]) # with gr.Column(scale=5, min_width=200): # child_plots = [] # child_buttons = [] # for i in range(4): # child_plots.append(gr.Gallery(value=None, label=f"Child {i}", show_label=True, elem_id=f"child_plot_{i}", interactive=False, rows=[1], columns=[2])) # child_buttons.append(gr.Button(f"Use Child {i}", elem_id=f"run_child_{i}")) # def relative_xy(prompts): # image = prompts['image'] # points = np.asarray(prompts['points']) # if points.shape[0] == 0: # return [], [] # is_point = points[:, 5] == 4.0 # points = 
points[is_point] # is_positive = points[:, 2] == 1.0 # is_negative = points[:, 2] == 0.0 # xy = points[:, :2].tolist() # if isinstance(image, str): # image = Image.open(image) # image = np.array(image) # h, w = image.shape[:2] # new_xy = [(x/w, y/h) for x, y in xy] # # print(new_xy) # return new_xy, is_positive # def xy_eigvec(prompts, image_idx, eigvecs): # eigvec = eigvecs[image_idx] # xy, is_positive = relative_xy(prompts) # for i, (x, y) in enumerate(xy): # if not is_positive[i]: # continue # x = int(x * eigvec.shape[1]) # y = int(y * eigvec.shape[0]) # return eigvec[y, x], (y, x) # from ncut_pytorch.ncut_pytorch import _transform_heatmap # def _run_heatmap_fn(images, eigvecs, prompt_image_idx, prompt_points, n_eig, flat_idx=None, raw_heatmap=False, overlay_image=True): # left = eigvecs[..., :n_eig] # if flat_idx is not None: # right = eigvecs.reshape(-1, eigvecs.shape[-1])[flat_idx] # y, x = None, None # else: # right, (y, x) = xy_eigvec(prompt_points, prompt_image_idx, eigvecs) # right = right[:n_eig] # left = F.normalize(left, p=2, dim=-1) # _right = F.normalize(right, p=2, dim=-1) # heatmap = left @ _right.unsqueeze(-1) # heatmap = heatmap.squeeze(-1) # heatmap = 1 - heatmap # heatmap = _transform_heatmap(heatmap) # if raw_heatmap: # return heatmap # # apply hot colormap and covert to PIL image 256x256 # heatmap = heatmap.cpu().numpy() # hot_map = matplotlib.colormaps['hot'] # heatmap = hot_map(heatmap) # pil_images = to_pil_images(torch.tensor(heatmap), target_size=256, force_size=True) # if overlay_image: # overlaied_images = [] # for i_image in range(len(images)): # rgb_image = images[i_image].resize((256, 256)) # rgb_image = np.array(rgb_image) # heatmap_image = np.array(pil_images[i_image])[..., :3] # blend_image = 0.5 * rgb_image + 0.5 * heatmap_image # blend_image = Image.fromarray(blend_image.astype(np.uint8)) # overlaied_images.append(blend_image) # pil_images = overlaied_images # return pil_images, (y, x) # def _farthest_point_sampling( # 
features, # start_feature, # num_sample=300, # h=9, # ): # import fpsample # h = min(h, int(np.log2(features.shape[0]))) # inp = features.cpu().numpy() # inp = np.concatenate([inp, start_feature[None, :]], axis=0) # kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling( # inp, num_sample, h, start_idx=inp.shape[0] - 1 # ).astype(np.int64) # return kdline_fps_samples_idx # @torch.no_grad() # def run_heatmap(images, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx=None, overlay_image=True): # gr.Info(f"current number of eigenvectors: {n_eig}") # eigvecs = torch.tensor(eigvecs) # image1_slider = min(image1_slider, len(images)-1) # images = [image[0] for image in images] # if isinstance(images[0], str): # images = [Image.open(image[0]).convert("RGB").resize((256, 256)) for image in images] # current_heatmap, (y, x) = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, n_eig, flat_idx, overlay_image=overlay_image) # parent_heatmap, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig/2), flat_idx, overlay_image=overlay_image) # # find childs # # pca_eigvecs # _eigvecs = eigvecs.reshape(-1, eigvecs.shape[-1]) # u, s, v = torch.pca_lowrank(_eigvecs, q=8) # _n = _eigvecs.shape[0] # s /= math.sqrt(_n) # _eigvecs = u @ torch.diag(s) # if flat_idx is None: # _picked_eigvec = _eigvecs.reshape(*eigvecs.shape[:-1], 8)[image1_slider, y, x] # else: # _picked_eigvec = _eigvecs[flat_idx] # l2_distance = torch.norm(_eigvecs - _picked_eigvec, dim=-1) # average_distance = l2_distance.mean() # distance_threshold = distance_slider * average_distance # distance_mask = l2_distance < distance_threshold # masked_eigvecs = _eigvecs[distance_mask] # num_childs = min(4, masked_eigvecs.shape[0]) # assert num_childs > 0 # child_idx = _farthest_point_sampling(masked_eigvecs, _picked_eigvec, num_sample=num_childs+1) # child_idx = np.sort(child_idx)[:-1] # # convert child_idx to flat_idx # dummy_idx = torch.zeros(_eigvecs.shape[0], 
dtype=torch.bool) # dummy_idx2 = torch.zeros(int(distance_mask.sum().item()), dtype=torch.bool) # dummy_idx2[child_idx] = True # dummy_idx[distance_mask] = dummy_idx2 # child_idx = torch.where(dummy_idx)[0] # # current_child heatmap, for contrast # current_child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), flat_idx, raw_heatmap=True, overlay_image=overlay_image) # # child_heatmaps, contrast mean of current clicked point # child_heatmaps = [] # for idx in child_idx: # child_heatmap = _run_heatmap_fn(images,eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, raw_heatmap=True, overlay_image=overlay_image) # heatmap = child_heatmap - current_child_heatmap # # convert [-1, 1] to [0, 1] # heatmap = (heatmap + 1) / 2 # heatmap = heatmap.cpu().numpy() # cm = matplotlib.colormaps['bwr'] # heatmap = cm(heatmap) # # bwr with contrast # pil_images1 = to_pil_images(torch.tensor(heatmap), resize=256) # # no contrast # pil_images2, _ = _run_heatmap_fn(images, eigvecs, image1_slider, prompt_image1, int(n_eig*2), idx, overlay_image=overlay_image) # # combine contrast and no contrast # pil_images = [] # for i in range(len(pil_images1)): # pil_images.append(pil_images2[i]) # pil_images.append(pil_images1[i]) # child_heatmaps.append(pil_images) # return parent_heatmap, current_heatmap, *child_heatmaps, child_idx.tolist() # # def debug_fn(eigvecs): # # shape = eigvecs.shape # # gr.Info(f"eigvecs shape: {shape}") # # run_button.click( # # debug_fn, # # inputs=[eigvecs], # # outputs=[], # # ) # none_placeholder = gr.State(None) # run_button.click( # run_heatmap, # inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, none_placeholder, overlay_image_checkbox], # outputs=[parent_plot, current_plot, *child_plots, child_idx], # ) # def run_paraent(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx=None, overlay_image=True): # n_eig = int(n_eig/2) # return n_eig, 
*run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, current_idx, overlay_image) # parent_button.click( # run_paraent, # inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, current_idx, overlay_image_checkbox], # outputs=[n_eig, parent_plot, current_plot, *child_plots, child_idx], # ) # def run_child(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, child_idx=[], overlay_image=True, i_child=0): # n_eig = int(n_eig*2) # flat_idx = child_idx[i_child] # return n_eig, flat_idx, *run_heatmap(input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, distance_slider, flat_idx, overlay_image) # for i in range(4): # child_buttons[i].click( # partial(run_child, i_child=i), # inputs=[input_gallery, eigvecs, image1_slider, prompt_image1, n_eig, child_distance_slider, child_idx, overlay_image_checkbox], # outputs=[n_eig, current_idx, parent_plot, current_plot, *child_plots, child_idx], # ) with gr.Tab('📄About'): with gr.Column(): gr.Markdown("**This demo is for Python package `ncut-pytorch`, please visit the [Documentation](https://ncut-pytorch.readthedocs.io/)**") gr.Markdown("**All the models and functions used for this demo are in the Python package `ncut-pytorch`**") gr.Markdown("---") gr.Markdown("---") gr.Markdown("**Normalized Cuts**, aka. spectral clustering, is a graphical method to analyze data grouping in the affinity eigenvector space. 
It has been widely used for unsupervised segmentation in the 2000s.") gr.Markdown("*Normalized Cuts and Image Segmentation, Jianbo Shi and Jitendra Malik, 2000*") gr.Markdown("---") gr.Markdown("**We have improved NCut, with some advanced features:**") gr.Markdown("- **Nyström** Normalized Cut, is a new approximation algorithm developed for large-scale graph cuts, a large-graph of million nodes can be processed in under 10s (cpu) or 2s (gpu).") gr.Markdown("- **spectral-tSNE** visualization, a new method to visualize the high-dimensional eigenvector space with 3D RGB cube. Color is aligned across images, color infers distance in representation.") gr.Markdown("*paper in prep, Yang 2024*") gr.Markdown("*AlignedCut: Visual Concepts Discovery on Brain-Guided Universal Feature Space, Huzheng Yang, James Gee\*, and Jianbo Shi\*, 2024*") gr.Markdown("---") gr.Markdown("---") gr.Markdown('

We thank HuggingFace for hosting this demo.

') # unlock the hidden tab with gr.Row(): with gr.Column(scale=5): gr.Markdown("") with gr.Column(scale=5): hidden_button = gr.Checkbox(label="🤗", value=False, elem_id="unlock_button", visible=True, interactive=True) with gr.Column(scale=5): gr.Markdown("") n_smiles = gr.State(0) unlock_value = 6 def update_smile(n_smiles): n_smiles = n_smiles + 1 n_smiles = unlock_value if n_smiles > unlock_value else n_smiles if n_smiles == unlock_value - 2: gr.Info("click one more time to unlock", 2) if n_smiles == unlock_value: label = "🔓 unlocked" return n_smiles, gr.update(label=label, value=True, interactive=False) label = ["😊"] * n_smiles label = "".join(label) return n_smiles, gr.update(label=label, value=False) def unlock_tabs(n_smiles, n_tab=1): if n_smiles == unlock_value: gr.Info("🔓 unlocked tabs", 2) return [gr.update(visible=True)] * n_tab return [gr.update()] * n_tab hidden_button.change(update_smile, [n_smiles], [n_smiles, hidden_button]) hidden_tabs = [tab_alignedcut_advanced, tab_model_aligned_advanced, tab_recursivecut_advanced, tab_compare_models_advanced, tab_directed_ncut, tab_aligned, tab_lisa] hidden_button.change(partial(unlock_tabs, n_tab=len(hidden_tabs)), [n_smiles], hidden_tabs) with gr.Row(): gr.Markdown("**This demo is for Python package `ncut-pytorch`, [Documentation](https://ncut-pytorch.readthedocs.io/)**") # for local development if os.path.exists("/hf_token.txt"): os.environ["HF_ACCESS_TOKEN"] = open("/hf_token.txt").read().strip() if DOWNLOAD_ALL_MODELS_DATASETS: from ncut_pytorch.backbone import download_all_models # t1 = threading.Thread(target=download_all_models).start() # t1.join() # t3 = threading.Thread(target=download_all_datasets).start() # t3.join() download_all_models() download_all_datasets() from ncut_pytorch.backbone_text import download_all_models # t2 = threading.Thread(target=download_all_models).start() # t2.join() download_all_models() demo.launch(share=True) # # %% # # debug # # change working directory to "/" # 
os.chdir("/") # images = [(Image.open(image), None) for image in default_images] # ret = run_fn(images, num_eig=30) # # %% # %% # %% # %% # %% # %% # %% # %% # %% # %% # %%