Spaces:
Sleeping
Sleeping
| """ | |
| DINO Correspondence Analysis Module | |
| This module provides functions for analyzing visual correspondences between images | |
| using DINO features, normalized cuts (NCut), and clustering techniques. | |
| """ | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from scipy.optimize import linear_sum_assignment | |
| from einops import rearrange | |
| from extract_features import image_inverse_transform | |
| from ipadapter_model import image_grid | |
| from ncut_pytorch import ncut_fn, kway_ncut, convert_to_lab_color | |
| from ncut_pytorch.color import tsne_color | |
| from ncut_pytorch.utils.gamma import find_gamma_by_degree | |
| # ===== Core NCut and Clustering Functions ===== | |
| def _get_compute_device(tensor: torch.Tensor) -> str: | |
| device_type = tensor.device.type | |
| return device_type if device_type in {"cpu", "cuda"} else "cpu" | |
def _run_kway_ncut_on_cpu(eigenvectors: torch.Tensor, n_clusters: int) -> torch.Tensor:
    """Run the k-way NCut discretization on the CPU.

    Only the leading ``n_clusters`` eigenvector columns are used. They are
    moved to host memory first so the discretization avoids CUDA kernels that
    are unsupported on HF ZeroGPU.
    """
    leading_eigenvectors = eigenvectors[:, :n_clusters]
    host_eigenvectors = leading_eigenvectors.cpu()
    return kway_ncut(host_eigenvectors, device="cpu")
def ncut_tsne_multiple_images(image_embeds, n_eig=50, gamma=None, degree=0.5):
    """
    Run NCut jointly over the tokens of all images and color them with t-SNE.

    image_embeds is (batch, length, channels)
    gamma is estimated with find_gamma_by_degree(tokens, degree) when None
    return (eigenvectors, colors), each reshaped back to (batch, length, -1);
    the colors are the t-SNE output passed through convert_to_lab_color
    """
    n_images, _, _ = image_embeds.shape
    tokens = image_embeds.flatten(end_dim=-2)
    device = _get_compute_device(tokens)

    affinity_gamma = find_gamma_by_degree(tokens, degree) if gamma is None else gamma
    eigvecs, _ = ncut_fn(tokens, n_eig=n_eig, gamma=affinity_gamma, device=device)

    colors = convert_to_lab_color(
        tsne_color(eigvecs, n_dim=3, device=device, perplexity=50)
    )

    def split_per_image(flat):
        # Undo the flattening: (batch * length, c) -> (batch, length, c)
        return rearrange(flat, '(b l) c -> b l c', b=n_images)

    colors = split_per_image(colors)
    eigvecs = split_per_image(eigvecs)
    return eigvecs, colors
def _kway_cluster_single_image(image_embeds, n_clusters, gamma=None, degree=0.5):
    """
    Run NCut on the tokens of one image and discretize into n_clusters.

    image_embeds is (length, channels)
    gamma is estimated with find_gamma_by_degree when None; a caller-supplied
    gamma is multiplied by the summed per-channel variance of the tokens
    return the result of the CPU k-way NCut step for this image
    """
    _, _ = image_embeds.shape  # input must be exactly 2-D
    tokens = image_embeds.flatten(end_dim=-2)
    device = _get_compute_device(tokens)

    if gamma is not None:
        affinity_gamma = gamma * image_embeds.var(0).sum().item()
    else:
        affinity_gamma = find_gamma_by_degree(tokens, degree)

    # Request a few spare eigenvectors, capped by the number of tokens
    n_eig = min(2 * n_clusters + 6, tokens.shape[0] // 2 - 1)
    eigvecs, _ = ncut_fn(tokens, n_eig=n_eig, gamma=affinity_gamma, device=device)
    return _run_kway_ncut_on_cpu(eigvecs, n_clusters)
def kway_cluster_per_image(image_embeds, n_clusters, gamma=None, degree=0.5):
    """
    Cluster every image on its own with k-way NCut.

    image_embeds is (batch, length, channels)
    return (batch, length, clusters)
    """
    n_images = image_embeds.shape[0]
    per_image_clusters = [
        _kway_cluster_single_image(image_embeds[img_idx], n_clusters, gamma, degree)
        for img_idx in range(n_images)
    ]
    return torch.stack(per_image_clusters)
def kway_cluster_multiple_images(image_embeds, n_clusters, gamma=None, degree=0.5):
    """
    Cluster the tokens of all images together with a single k-way NCut.

    image_embeds is (batch, length, channels)
    gamma is estimated with find_gamma_by_degree when None
    return (batch, length, clusters)
    """
    n_images, _, _ = image_embeds.shape
    tokens = image_embeds.flatten(end_dim=-2)
    device = _get_compute_device(tokens)

    affinity_gamma = find_gamma_by_degree(tokens, degree) if gamma is None else gamma

    # Request a few spare eigenvectors, capped by the number of tokens
    n_eig = min(2 * n_clusters + 6, tokens.shape[0] // 2 - 1)
    eigvecs, _ = ncut_fn(tokens, n_eig=n_eig, gamma=affinity_gamma, device=device)

    joint_clusters = _run_kway_ncut_on_cpu(eigvecs, n_clusters)
    return rearrange(joint_clusters, '(b l) c -> b l c', b=n_images)
| # ===== Color and Visualization Functions ===== | |
def get_discrete_colors_from_clusters(joint_colors, cluster_eigenvectors):
    """
    Replace every token color with the mean color of its cluster.

    joint_colors is a NumPy array (batch, length, 3) with values scaled so
    that multiplying by 255 gives the displayed intensity
    cluster_eigenvectors is a tensor (batch, length, clusters); a token's
    cluster is the argmax over the last axis
    return a uint8 array with the shape of joint_colors
    """
    flat_colors = np.zeros_like(joint_colors)
    for img_idx, colors in enumerate(joint_colors):
        labels = cluster_eigenvectors[img_idx].cpu().numpy().argmax(-1)
        quantized = np.zeros_like(colors)
        # Labels come from an argmax, so only non-empty clusters show up here
        for cluster_idx in np.unique(labels):
            members = labels == cluster_idx
            quantized[members] = colors[members].mean(0)
        flat_colors[img_idx] = quantized
    return (flat_colors * 255).astype(np.uint8)
| # ===== Center Matching Functions ===== | |
def get_cluster_center_features(image_embeds, cluster_labels, n_clusters):
    """
    Compute the mean embedding of every cluster.

    Args:
        image_embeds: (length, channels) token embeddings
        cluster_labels: (length,) integer cluster label per token
            (NumPy array or tensor); labels outside [0, n_clusters) are ignored
        n_clusters: number of rows in the result
    Returns:
        torch.Tensor: (n_clusters, channels) tensor in the default dtype.
        Row i is the mean embedding of the tokens labelled i; rows of empty
        clusters are filled with the sentinel value 114514.
    """
    # The sentinel is written in the output dtype rather than the embedding
    # dtype: built as float16 it would overflow to inf and make the distance
    # matrix used for matching invalid.
    empty_cluster_sentinel = 114514.0
    center_features = torch.full(
        (n_clusters, image_embeds.shape[-1]), empty_cluster_sentinel
    )
    for cluster_idx in range(n_clusters):
        cluster_mask = cluster_labels == cluster_idx
        if cluster_mask.any():
            center_features[cluster_idx] = image_embeds[cluster_mask].mean(0)
    return center_features
def cosine_similarity(matrix_a, matrix_b):
    """
    Pairwise cosine similarity between the rows of two matrices.

    matrix_a is (n, channels), matrix_b is (m, channels)
    return (n, m); rows with zero norm are not special-cased
    """
    def unit_rows(matrix):
        return matrix / matrix.norm(dim=-1, keepdim=True)

    return unit_rows(matrix_a) @ unit_rows(matrix_b).T
def hungarian_match_centers(center_features1, center_features2):
    """
    Match cluster centers with a minimum-cost assignment on Euclidean distance.

    return the column indices from linear_sum_assignment, i.e. the center in
    center_features2 assigned to each matched row of center_features1
    """
    cost_matrix = torch.cdist(center_features1, center_features2)
    assignment = linear_sum_assignment(cost_matrix.detach().cpu().numpy())
    return assignment[1]
def argmin_matching(center_features1, center_features2):
    """
    Match each center in center_features1 to its nearest center in
    center_features2 (Euclidean distance). Several rows may share a target.

    return a NumPy index array with one entry per row of center_features1
    """
    pairwise = torch.cdist(center_features1, center_features2)
    return pairwise.detach().cpu().numpy().argmin(axis=-1)
def match_cluster_centers(image_embed1, image_embed2, eigvec1, eigvec2,
                          match_method='hungarian'):
    """
    Match the clusters of image 1 to the clusters of image 2 by their centers.

    Each token is assigned to the argmax column of its eigenvector row, the
    mean embedding per cluster is computed, and the centers are matched.

    Args:
        image_embed1, image_embed2: (length, channels) embeddings
        eigvec1, eigvec2: (length, clusters) cluster eigenvectors
        match_method: 'hungarian' or 'argmin'
    Returns:
        Index array mapping image-1 clusters to image-2 clusters
    Raises:
        ValueError: if match_method is not one of the supported names
    """
    centers = []
    for embed, eigvec in ((image_embed1, eigvec1), (image_embed2, eigvec2)):
        labels = eigvec.argmax(-1).cpu().numpy()
        centers.append(get_cluster_center_features(embed, labels, eigvec.shape[-1]))
    centers1, centers2 = centers

    if match_method == 'hungarian':
        return hungarian_match_centers(centers1, centers2)
    if match_method == 'argmin':
        return argmin_matching(centers1, centers2)
    raise ValueError(f"Unknown match_method: {match_method}")
def match_centers_three_images(image_embeds, eigenvectors, match_method='hungarian'):
    """
    Chain cluster matching over three images (A2 -> A1 -> B1).

    Args:
        image_embeds (torch.Tensor): Embeddings for 3 images [A2, A1, B1]
        eigenvectors (torch.Tensor): Cluster eigenvectors for the same 3 images
        match_method (str): Matching method passed to match_cluster_centers
    Returns:
        tuple: (A2_to_A1_mapping, A1_to_B1_mapping)
    """
    def match_pair(source, target):
        return match_cluster_centers(
            image_embeds[source], image_embeds[target],
            eigenvectors[source], eigenvectors[target],
            match_method=match_method,
        )

    a2_to_a1 = match_pair(0, 1)
    a1_to_b1 = match_pair(1, 2)
    return a2_to_a1, a1_to_b1
def match_centers_two_images(image_embed1, image_embed2, eigvec1, eigvec2,
                             match_method='hungarian'):
    """Two-image convenience wrapper that forwards to match_cluster_centers."""
    mapping = match_cluster_centers(
        image_embed1,
        image_embed2,
        eigvec1,
        eigvec2,
        match_method=match_method,
    )
    return mapping
| # ===== Two-Step Clustering Functions ===== | |
def kway_cluster_per_image_two_step(
    image_embeds,
    n_superclusters,
    n_subclusters_per_supercluster,
    supercluster_gamma=None,
    subcluster_gamma=None,
    degree=0.5
):
    """
    Perform 2-step hierarchical clustering on each image separately.
    First finds superclusters, then subdivides each supercluster into subclusters.

    Subcluster columns are laid out supercluster by supercluster: supercluster
    ``s`` always owns columns ``s * n_subclusters_per_supercluster`` up to
    ``(s + 1) * n_subclusters_per_supercluster - 1``, even when it is empty.

    Args:
        image_embeds: (batch, length, channels) - Image embeddings
        n_superclusters: Number of coarse superclusters to find
        n_subclusters_per_supercluster: Number of subclusters within each supercluster
        supercluster_gamma: Gamma parameter for supercluster NCut (None = auto)
        subcluster_gamma: Gamma parameter for subcluster NCut (None = auto)
        degree: Degree parameter for gamma estimation
    Returns:
        tuple: (supercluster_eigenvectors, subcluster_eigenvectors, subcluster_to_supercluster_mapping)
            - supercluster_eigenvectors: (batch, length, n_superclusters)
            - subcluster_eigenvectors: (batch, length, total_subclusters), one-hot per token
            - subcluster_to_supercluster_mapping: (batch, total_subclusters) mapping each subcluster to its supercluster
    """
    batch_size, n_tokens = image_embeds.shape[0], image_embeds.shape[1]
    total_subclusters = n_superclusters * n_subclusters_per_supercluster

    # Step 1: Compute superclusters for each image
    supercluster_eigenvectors = torch.stack([
        _kway_cluster_single_image(
            image_embeds[i], n_superclusters, supercluster_gamma, degree
        )
        for i in range(batch_size)
    ])

    # The column layout is fixed, so every image shares the same mapping
    img_mapping = [
        supercluster_idx
        for supercluster_idx in range(n_superclusters)
        for _ in range(n_subclusters_per_supercluster)
    ]

    # Step 2: For each supercluster in each image, compute subclusters
    subcluster_eigenvectors = []
    subcluster_to_supercluster_mapping = []
    for img_idx in range(batch_size):
        supercluster_labels = supercluster_eigenvectors[img_idx].argmax(-1)
        img_subcluster_eigvec = torch.zeros(
            (n_tokens, total_subclusters), device=image_embeds.device
        )
        for supercluster_idx in range(n_superclusters):
            token_indices = torch.where(supercluster_labels == supercluster_idx)[0]
            n_members = token_indices.shape[0]
            if n_members == 0:
                # Empty supercluster: its reserved columns stay all-zero, and
                # the columns of later superclusters are not shifted.
                continue

            rows = token_indices.to(image_embeds.device)
            if n_members <= n_subclusters_per_supercluster:
                # Too few tokens - each token becomes its own subcluster and
                # the remaining columns of this supercluster stay empty
                subcluster_labels = torch.arange(n_members)
            else:
                subcluster_labels = _kway_cluster_single_image(
                    image_embeds[img_idx][rows],
                    n_subclusters_per_supercluster,
                    subcluster_gamma,
                    degree
                ).argmax(-1)

            # Local subcluster label -> global column inside this supercluster's slot
            first_column = supercluster_idx * n_subclusters_per_supercluster
            columns = (first_column + subcluster_labels).to(image_embeds.device)
            img_subcluster_eigvec[rows, columns] = 1.0

        subcluster_eigenvectors.append(img_subcluster_eigvec)
        subcluster_to_supercluster_mapping.append(torch.tensor(img_mapping))

    subcluster_eigenvectors = torch.stack(subcluster_eigenvectors)
    subcluster_to_supercluster_mapping = torch.stack(subcluster_to_supercluster_mapping)
    return supercluster_eigenvectors, subcluster_eigenvectors, subcluster_to_supercluster_mapping
def match_centers_two_step(
    image_embed1,
    image_embed2,
    supercluster_eigvec1,
    supercluster_eigvec2,
    subcluster_eigvec1,
    subcluster_eigvec2,
    subcluster_to_supercluster_mapping1,
    subcluster_to_supercluster_mapping2,
    supercluster_match_method='hungarian',
    subcluster_match_method='hungarian'
):
    """
    Match clusters using 2-step hierarchical approach.
    First matches superclusters, then matches subclusters only within matched superclusters.

    Args:
        image_embed1, image_embed2: Image embeddings (length, channels)
        supercluster_eigvec1, supercluster_eigvec2: Supercluster eigenvectors (length, n_superclusters)
        subcluster_eigvec1, subcluster_eigvec2: Subcluster eigenvectors (length, total_subclusters),
            expected in the layout returned by kway_cluster_per_image_two_step
        subcluster_to_supercluster_mapping1, subcluster_to_supercluster_mapping2: (total_subclusters,)
        supercluster_match_method: Matching method for superclusters
        subcluster_match_method: Matching method for subclusters
    Returns:
        np.ndarray: Mapping from image1 subclusters to image2 subclusters
    Raises:
        ValueError: if a match method is not 'hungarian' or 'argmin'
    """
    n_superclusters = supercluster_eigvec1.shape[-1]
    n_subclusters_total = subcluster_eigvec1.shape[-1]

    def local_cluster_labels(subcluster_eigvec, group_indices):
        # Label every token with the position of its subcluster inside
        # group_indices, or -1 when the token belongs to another group.
        # An argmax over only the group's columns would return 0 for every
        # outside token (all-zero row) and pull those tokens into the center
        # of the group's first subcluster.
        global_labels = subcluster_eigvec.argmax(-1).cpu()
        is_assigned = (subcluster_eigvec != 0).any(-1).cpu()
        local_labels = torch.full_like(global_labels, -1)
        for local_idx, global_idx in enumerate(group_indices):
            local_labels[is_assigned & (global_labels == int(global_idx))] = local_idx
        return local_labels

    # Step 1: Match superclusters
    supercluster_mapping = match_cluster_centers(
        image_embed1, image_embed2,
        supercluster_eigvec1, supercluster_eigvec2,
        match_method=supercluster_match_method
    )

    # Step 2: For each matched supercluster pair, match subclusters within them
    subcluster_mapping = np.zeros(n_subclusters_total, dtype=np.int64)
    for supercluster1_idx in range(n_superclusters):
        # Find which supercluster in image2 this maps to
        supercluster2_idx = int(supercluster_mapping[supercluster1_idx])

        # Find all subclusters belonging to these superclusters
        subclusters1_indices = np.where(
            (subcluster_to_supercluster_mapping1 == supercluster1_idx).cpu().numpy()
        )[0]
        subclusters2_indices = np.where(
            (subcluster_to_supercluster_mapping2 == supercluster2_idx).cpu().numpy()
        )[0]

        if len(subclusters1_indices) == 0 or len(subclusters2_indices) == 0:
            # Nothing to match against: image-1 subclusters of this
            # supercluster (if any) fall back to target index 0.
            subcluster_mapping[subclusters1_indices] = 0
            continue

        # Compute cluster centers using only the tokens of these subclusters
        center_features1 = get_cluster_center_features(
            image_embed1,
            local_cluster_labels(subcluster_eigvec1, subclusters1_indices),
            len(subclusters1_indices)
        )
        center_features2 = get_cluster_center_features(
            image_embed2,
            local_cluster_labels(subcluster_eigvec2, subclusters2_indices),
            len(subclusters2_indices)
        )

        # Match subclusters within this supercluster pair
        if subcluster_match_method == 'hungarian':
            local_mapping = hungarian_match_centers(center_features1, center_features2)
        elif subcluster_match_method == 'argmin':
            local_mapping = argmin_matching(center_features1, center_features2)
        else:
            raise ValueError(f"Unknown subcluster_match_method: {subcluster_match_method}")

        # Convert local mapping to global subcluster indices
        for local_idx, global_idx1 in enumerate(subclusters1_indices):
            subcluster_mapping[global_idx1] = subclusters2_indices[local_mapping[local_idx]]

    return subcluster_mapping
def kway_cluster_per_image_two_step_fgbg(
    image_embeds,
    n_foreground_subclusters,
    n_background_subclusters,
    supercluster_gamma=None,
    subcluster_gamma=None,
    degree=0.5
):
    """
    Perform 2-step hierarchical clustering with automatic foreground/background separation.
    First separates foreground (FG) and background (BG) using 2 clusters, identifying FG
    by the cluster with highest max eigenvector value. Then subdivides FG and BG separately.

    Subcluster columns are laid out BG first: columns
    ``0 .. n_background_subclusters - 1`` are background, the rest foreground.

    Args:
        image_embeds: (batch, length, channels) - Image embeddings
        n_foreground_subclusters: Number of subclusters within foreground
        n_background_subclusters: Number of subclusters within background
        supercluster_gamma: Gamma parameter for FG/BG clustering (None = auto)
        subcluster_gamma: Gamma parameter for subcluster NCut (None = auto)
        degree: Degree parameter for gamma estimation
    Returns:
        tuple: (supercluster_eigenvectors, subcluster_eigenvectors, subcluster_to_supercluster_mapping, fg_indices)
            - supercluster_eigenvectors: (batch, length, 2) - the 2-way clustering as returned
              by the NCut step; columns are NOT reordered, use fg_indices to find FG
            - subcluster_eigenvectors: (batch, length, total_subclusters), one-hot per token
            - subcluster_to_supercluster_mapping: (batch, total_subclusters) - 0=BG, 1=FG
            - fg_indices: (batch,) - which supercluster index is foreground for each image
    """
    batch_size, n_tokens = image_embeds.shape[0], image_embeds.shape[1]
    total_subclusters = n_background_subclusters + n_foreground_subclusters

    # Step 1: Compute FG/BG separation for each image
    supercluster_eigenvectors = []
    fg_indices = []
    for i in range(batch_size):
        eigenvector = _kway_cluster_single_image(
            image_embeds[i], n_clusters=2, gamma=supercluster_gamma, degree=degree
        )
        supercluster_eigenvectors.append(eigenvector)
        # Identify foreground: cluster with highest max eigenvector value
        fg_indices.append(eigenvector.max(0).values.argmax().item())
    supercluster_eigenvectors = torch.stack(supercluster_eigenvectors)
    fg_indices = torch.tensor(fg_indices)

    # The column layout is fixed (BG block, then FG block) for every image
    img_mapping = [0] * n_background_subclusters + [1] * n_foreground_subclusters

    # Step 2: For each image, compute subclusters within FG and BG
    subcluster_eigenvectors = []
    subcluster_to_supercluster_mapping = []
    for img_idx in range(batch_size):
        supercluster_labels = supercluster_eigenvectors[img_idx].argmax(-1)
        fg_idx = fg_indices[img_idx].item()
        bg_idx = 1 - fg_idx

        img_subcluster_eigvec = torch.zeros(
            (n_tokens, total_subclusters), device=image_embeds.device
        )
        # (supercluster index, first column, number of subclusters), BG then FG
        groups = (
            (bg_idx, 0, n_background_subclusters),
            (fg_idx, n_background_subclusters, n_foreground_subclusters),
        )
        for supercluster_idx, first_column, n_subclusters in groups:
            token_indices = torch.where(supercluster_labels == supercluster_idx)[0]
            n_members = token_indices.shape[0]
            if n_members == 0:
                # Empty supercluster: its reserved columns stay all-zero
                continue

            # Keep every index tensor on the embedding device so row/column
            # indexing never mixes CPU and accelerator indices.
            rows = token_indices.to(image_embeds.device)
            if n_members <= n_subclusters:
                # Too few tokens - each token becomes its own subcluster and
                # the remaining columns of this group stay empty
                subcluster_labels = torch.arange(n_members)
            else:
                subcluster_labels = _kway_cluster_single_image(
                    image_embeds[img_idx][rows],
                    n_subclusters,
                    subcluster_gamma,
                    degree
                ).argmax(-1)

            columns = (first_column + subcluster_labels).to(image_embeds.device)
            img_subcluster_eigvec[rows, columns] = 1.0

        subcluster_eigenvectors.append(img_subcluster_eigvec)
        subcluster_to_supercluster_mapping.append(torch.tensor(img_mapping))

    subcluster_eigenvectors = torch.stack(subcluster_eigenvectors)
    subcluster_to_supercluster_mapping = torch.stack(subcluster_to_supercluster_mapping)
    return supercluster_eigenvectors, subcluster_eigenvectors, subcluster_to_supercluster_mapping, fg_indices
def match_centers_two_step_fgbg(
    image_embed1,
    image_embed2,
    subcluster_eigvec1,
    subcluster_eigvec2,
    subcluster_to_supercluster_mapping1,
    subcluster_to_supercluster_mapping2,
    n_background_subclusters,
    n_foreground_subclusters,
    background_match_method='hungarian',
    foreground_match_method='hungarian'
):
    """
    Match clusters using 2-step FG/BG hierarchical approach.
    FG and BG are automatically matched (no need for supercluster matching).
    Subclusters are matched within their respective FG or BG groups.

    Args:
        image_embed1, image_embed2: Image embeddings (length, channels)
        subcluster_eigvec1, subcluster_eigvec2: Subcluster eigenvectors (length, total_subclusters),
            expected in the layout returned by kway_cluster_per_image_two_step_fgbg
        subcluster_to_supercluster_mapping1, subcluster_to_supercluster_mapping2: (total_subclusters,) - 0=BG, 1=FG
        n_background_subclusters: Number of background subclusters
        n_foreground_subclusters: Number of foreground subclusters
        background_match_method: Matching method for background subclusters
        foreground_match_method: Matching method for foreground subclusters
    Returns:
        np.ndarray: Mapping from image1 subclusters to image2 subclusters
    Raises:
        ValueError: if a match method is not 'hungarian' or 'argmin'
    """
    total_subclusters = n_background_subclusters + n_foreground_subclusters
    subcluster_mapping = np.zeros(total_subclusters, dtype=np.int64)

    def local_cluster_labels(subcluster_eigvec, group_indices):
        # Label every token with the position of its subcluster inside
        # group_indices, or -1 when the token belongs to the other group.
        # An argmax over only the group's columns would return 0 for every
        # outside token (all-zero row) and pull those tokens into the center
        # of the group's first subcluster.
        global_labels = subcluster_eigvec.argmax(-1).cpu()
        is_assigned = (subcluster_eigvec != 0).any(-1).cpu()
        local_labels = torch.full_like(global_labels, -1)
        for local_idx, global_idx in enumerate(group_indices):
            local_labels[is_assigned & (global_labels == int(global_idx))] = local_idx
        return local_labels

    # Process BG (supercluster_type=0) and FG (supercluster_type=1) separately
    for supercluster_type in [0, 1]:  # 0=BG, 1=FG
        # Find subclusters belonging to this supercluster type
        subclusters1_indices = np.where(
            (subcluster_to_supercluster_mapping1 == supercluster_type).cpu().numpy()
        )[0]
        subclusters2_indices = np.where(
            (subcluster_to_supercluster_mapping2 == supercluster_type).cpu().numpy()
        )[0]

        if len(subclusters1_indices) == 0 or len(subclusters2_indices) == 0:
            # Nothing to match against: image-1 subclusters of this group
            # (if any) fall back to target index 0.
            subcluster_mapping[subclusters1_indices] = 0
            continue

        # Compute cluster centers using only the tokens of this FG/BG group
        center_features1 = get_cluster_center_features(
            image_embed1,
            local_cluster_labels(subcluster_eigvec1, subclusters1_indices),
            len(subclusters1_indices)
        )
        center_features2 = get_cluster_center_features(
            image_embed2,
            local_cluster_labels(subcluster_eigvec2, subclusters2_indices),
            len(subclusters2_indices)
        )

        # Match subclusters within this FG/BG group
        match_method = foreground_match_method if supercluster_type == 1 else background_match_method
        if match_method == 'hungarian':
            local_mapping = hungarian_match_centers(center_features1, center_features2)
        elif match_method == 'argmin':
            local_mapping = argmin_matching(center_features1, center_features2)
        else:
            raise ValueError(f"Unknown match_method: {match_method}")

        # Convert local mapping to global subcluster indices
        for local_idx, global_idx1 in enumerate(subclusters1_indices):
            subcluster_mapping[global_idx1] = subclusters2_indices[local_mapping[local_idx]]

    return subcluster_mapping
| # ===== Visualization Functions ===== | |
def plot_cluster_masks(image, eigenvector, cluster_order, hw=16):
    """
    Render one dimmed copy of the image per cluster, highlighting its tokens.

    # image is (c, h, w)
    # eigenvector is (1 + hw*hw, n_eig); the first token is dropped before
    #   the remaining ones are reshaped to an (hw, hw) grid
    # cluster_order is (n_eig), the order of the clusters
    # returns a list of 128x128 PIL images, one per entry of cluster_order
    """
    tile_size = (128, 128)
    nearest = Image.Resampling.NEAREST

    base_img = image_inverse_transform(image).resize(tile_size, resample=nearest)
    pixels = np.array(base_img).astype(np.float32) / 255
    token_labels = eigenvector.argmax(-1)

    tiles = []
    for cluster_idx in cluster_order:
        # Binary token mask on the (hw, hw) grid, upscaled to the tile size
        token_mask = (token_labels == cluster_idx).cpu().numpy()[1:].reshape(hw, hw)
        mask_img = Image.fromarray((token_mask * 255).astype(np.uint8)).resize(
            tile_size, resample=nearest
        )
        weights = np.array(mask_img).astype(np.float32) / 255
        weights[weights == 0] = 0.1  # Dim non-masked areas

        # Broadcast the weights over the three color channels
        weights_3ch = np.stack([weights] * 3, axis=-1)
        highlighted = (pixels * weights_3ch * 255).astype(np.uint8)
        tiles.append(Image.fromarray(highlighted))
    return tiles
def create_image_grid_row(image, eigenvector, cluster_order, discrete_colors,
                          hw=16, n_cols=10):
    """
    Build the grid rows for one image.

    Every row starts with the image itself and its discrete-color NCut map,
    followed by n_cols cluster-mask tiles; the last row is padded with black
    tiles. discrete_colors is (1 + hw*hw, 3) uint8; its first entry is dropped.
    Returns a list of rows, each a list of 128x128 PIL images.
    """
    tile_size = (128, 128)
    nearest = Image.Resampling.NEAREST

    tiles = plot_cluster_masks(image, eigenvector, cluster_order, hw)

    base_img = image_inverse_transform(image).resize(tile_size, resample=nearest)
    color_map = discrete_colors[1:].reshape(hw, hw, 3)
    ncut_img = Image.fromarray(color_map).resize(tile_size, resample=nearest)

    # Pad the tiles so they fill complete rows
    remainder = len(tiles) % n_cols
    if remainder:
        blank = Image.fromarray(np.zeros((128, 128, 3), dtype=np.uint8))
        tiles.extend([blank] * (n_cols - remainder))

    leading = [base_img, ncut_img]
    n_rows = len(tiles) // n_cols
    return [
        leading + tiles[row * n_cols:(row + 1) * n_cols]
        for row in range(n_rows)
    ]
def create_multi_image_grid(images, eigenvectors, cluster_orders, discrete_colors,
                            hw=16, n_cols=10):
    """
    Build grid rows for several images and interleave them.

    The result lists row 0 of every image, then row 1 of every image, and so
    on; the number of rows is taken from the first image.
    """
    rows_per_image = [
        create_image_grid_row(image, eigvec, cluster_order, discrete_rgb, hw, n_cols)
        for image, eigvec, cluster_order, discrete_rgb in zip(
            images, eigenvectors, cluster_orders, discrete_colors
        )
    ]

    n_rows = len(rows_per_image[0])
    return [
        image_rows[row_idx]
        for row_idx in range(n_rows)
        for image_rows in rows_per_image
    ]
def get_correspondence_plot(images, eigenvectors, cluster_orders, discrete_colors,
                            hw=16, n_cols=10):
    """
    Assemble the final correspondence figure for a set of images.

    n_cols is capped at the number of clusters (last axis of eigenvectors).
    The interleaved rows from create_multi_image_grid are flattened and handed
    to image_grid together with the resulting row and column counts.
    """
    tiles_per_row = min(n_cols, eigenvectors.shape[-1])
    rows = create_multi_image_grid(
        images, eigenvectors, cluster_orders, discrete_colors, hw, tiles_per_row
    )

    grid_rows = len(rows)
    grid_cols = len(rows[0])

    # Flatten all images and create final grid
    tiles = [tile for row in rows for tile in row]
    return image_grid(tiles, grid_rows, grid_cols)