# VibeSpace / dino_correspondence.py
# Last commit by huzey: "Run kway ncut on CPU" (ba2b064)
"""
DINO Correspondence Analysis Module
This module provides functions for analyzing visual correspondences between images
using DINO features, normalized cuts (NCut), and clustering techniques.
"""
import numpy as np
import torch
from PIL import Image
from scipy.optimize import linear_sum_assignment
from einops import rearrange
from extract_features import image_inverse_transform
from ipadapter_model import image_grid
from ncut_pytorch import ncut_fn, kway_ncut, convert_to_lab_color
from ncut_pytorch.color import tsne_color
from ncut_pytorch.utils.gamma import find_gamma_by_degree
# ===== Core NCut and Clustering Functions =====
def _get_compute_device(tensor: torch.Tensor) -> str:
device_type = tensor.device.type
return device_type if device_type in {"cpu", "cuda"} else "cpu"
def _run_kway_ncut_on_cpu(eigenvectors: torch.Tensor, n_clusters: int) -> torch.Tensor:
    """Discretize the leading ``n_clusters`` eigenvectors with ``kway_ncut`` on CPU.

    The discretization is pinned to CPU to avoid unsupported CUDA kernels on
    HF ZeroGPU; the selected columns are copied to CPU before the call.
    """
    leading_eigenvectors = eigenvectors[:, :n_clusters].cpu()
    return kway_ncut(leading_eigenvectors, device="cpu")
def ncut_tsne_multiple_images(image_embeds, n_eig=50, gamma=None, degree=0.5):
    """
    Run one joint NCut over the tokens of all images and colour them via t-SNE.

    Args:
        image_embeds: tensor of shape (batch, length, channels); unpacking
            the shape raises if the input is not 3-D.
        n_eig: number of NCut eigenvectors to compute.
        gamma: affinity parameter for ``ncut_fn``. When None it is estimated
            with ``find_gamma_by_degree(tokens, degree)``; a provided value is
            used as-is (no rescaling).
        degree: target degree, used only when ``gamma`` is None.

    Returns:
        (eigenvectors, rgb_colors), both reshaped to (batch, length, ·).
    """
    batch_size, _, _ = image_embeds.shape
    tokens = image_embeds.flatten(end_dim=-2)
    device = _get_compute_device(tokens)
    if gamma is None:
        gamma = find_gamma_by_degree(tokens, degree)
    eigenvectors, _eigenvalues = ncut_fn(tokens, n_eig=n_eig, gamma=gamma, device=device)
    colors = convert_to_lab_color(
        tsne_color(eigenvectors, n_dim=3, device=device, perplexity=50)
    )
    # Undo the flattening so each image gets its own slice again.
    pattern = '(b l) c -> b l c'
    rgb_colors = rearrange(colors, pattern, b=batch_size)
    eigenvectors = rearrange(eigenvectors, pattern, b=batch_size)
    return eigenvectors, rgb_colors
def _kway_cluster_single_image(image_embeds, n_clusters, gamma=None, degree=0.5):
    """
    Cluster the tokens of one image with NCut followed by k-way discretization.

    Args:
        image_embeds: 2-D tensor (length, channels); unpacking the shape
            raises otherwise.
        n_clusters: number of clusters requested from ``kway_ncut``.
        gamma: when None, estimated by ``find_gamma_by_degree``. Otherwise it
            is treated as a relative value and multiplied by the summed
            per-channel variance of ``image_embeds``.
        degree: target degree, used only when ``gamma`` is None.

    Returns:
        The cluster-assignment tensor from ``kway_ncut`` (discretized on CPU).

    Note:
        ``n_eig`` is capped at ``length // 2 - 1``, so for very small inputs
        it can be smaller than ``n_clusters`` (or even non-positive).
    """
    _length, _channels = image_embeds.shape
    tokens = image_embeds.flatten(end_dim=-2)
    device = _get_compute_device(tokens)
    if gamma is not None:
        gamma = gamma * image_embeds.var(0).sum().item()
    else:
        gamma = find_gamma_by_degree(tokens, degree)
    n_eig = min(2 * n_clusters + 6, tokens.shape[0] // 2 - 1)
    eigenvectors, _ = ncut_fn(tokens, n_eig=n_eig, gamma=gamma, device=device)
    return _run_kway_ncut_on_cpu(eigenvectors, n_clusters)
def kway_cluster_per_image(image_embeds, n_clusters, gamma=None, degree=0.5):
    """
    Cluster every image independently.

    image_embeds is (batch, length, channels); the per-image results of
    ``_kway_cluster_single_image`` are stacked to (batch, length, clusters).
    """
    per_image = [
        _kway_cluster_single_image(image_embeds[idx], n_clusters, gamma, degree)
        for idx in range(image_embeds.shape[0])
    ]
    return torch.stack(per_image)
def kway_cluster_multiple_images(image_embeds, n_clusters, gamma=None, degree=0.5):
    """
    Cluster the tokens of all images jointly, so clusters are shared across images.

    Args:
        image_embeds: tensor (batch, length, channels).
        n_clusters: number of clusters requested from ``kway_ncut``.
        gamma: when None, estimated by ``find_gamma_by_degree``; a provided
            value is used as-is (unlike ``_kway_cluster_single_image``, it is
            not rescaled by the feature variance).
        degree: target degree, used only when ``gamma`` is None.

    Returns:
        Cluster assignments reshaped to (batch, length, clusters).
    """
    batch_size, _, _ = image_embeds.shape
    tokens = image_embeds.flatten(end_dim=-2)
    device = _get_compute_device(tokens)
    gamma = find_gamma_by_degree(tokens, degree) if gamma is None else gamma
    n_eig = min(2 * n_clusters + 6, tokens.shape[0] // 2 - 1)
    eigenvectors, _ = ncut_fn(tokens, n_eig=n_eig, gamma=gamma, device=device)
    clusters = _run_kway_ncut_on_cpu(eigenvectors, n_clusters)
    return rearrange(clusters, '(b l) c -> b l c', b=batch_size)
# ===== Color and Visualization Functions =====
def get_discrete_colors_from_clusters(joint_colors, cluster_eigenvectors):
    """
    Replace every token colour by the mean colour of its cluster.

    Args:
        joint_colors: array of shape (batch, length, 3); values are presumably
            in [0, 1] since they are scaled by 255 before the uint8 cast.
        cluster_eigenvectors: tensor (batch, length, n_clusters); each token
            belongs to its argmax cluster.

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``joint_colors``.
        Clusters that own no token in an image are simply skipped.
    """
    n_clusters = cluster_eigenvectors.shape[-1]
    averaged = np.zeros_like(joint_colors)
    for img_idx in range(joint_colors.shape[0]):
        token_colors = joint_colors[img_idx]
        labels = cluster_eigenvectors[img_idx].cpu().numpy().argmax(-1)
        img_out = np.zeros_like(token_colors)
        for label in range(n_clusters):
            members = labels == label
            if not members.any():
                continue
            img_out[members] = token_colors[members].mean(0)
        averaged[img_idx] = img_out
    return (averaged * 255).astype(np.uint8)
# ===== Center Matching Functions =====
def get_cluster_center_features(image_embeds, cluster_labels, n_clusters):
    """
    Mean feature vector of each cluster.

    Args:
        image_embeds: tensor (length, channels).
        cluster_labels: per-token integer labels (numpy array or tensor);
            labels outside ``range(n_clusters)`` are ignored.
        n_clusters: number of centers to return.

    Returns:
        CPU tensor (n_clusters, channels) in the default dtype. A cluster with
        no tokens gets a vector filled with the sentinel value 114514, which
        places it far away from real centers in Euclidean distance.
    """
    sentinel_value = 114514
    centers = torch.zeros((n_clusters, image_embeds.shape[-1]))
    for k in range(n_clusters):
        members = cluster_labels == k
        if not members.any():
            centers[k] = torch.ones_like(image_embeds[0]) * sentinel_value
            continue
        centers[k] = image_embeds[members].mean(0)
    return centers
def cosine_similarity(matrix_a, matrix_b):
    """
    Pairwise cosine similarity between the rows of two 2-D tensors.

    Returns a (len(matrix_a), len(matrix_b)) tensor. Zero-norm rows are not
    guarded against and yield NaN.
    """
    def _unit_rows(matrix):
        return matrix / matrix.norm(dim=-1, keepdim=True)

    return _unit_rows(matrix_a) @ _unit_rows(matrix_b).T
def hungarian_match_centers(center_features1, center_features2):
    """
    Optimal one-to-one matching of two center sets by Euclidean distance.

    Solves the linear assignment problem on the pairwise distance matrix and
    returns the numpy array of matched column indices (one per assigned row,
    rows in ascending order).
    """
    cost = torch.cdist(center_features1, center_features2).detach().cpu().numpy()
    _rows, cols = linear_sum_assignment(cost)
    return cols
def argmin_matching(center_features1, center_features2):
    """
    Greedy matching: each center of set 1 maps to its nearest center of set 2.

    Not necessarily one-to-one. Returns a numpy array of column indices.
    """
    cost = torch.cdist(center_features1, center_features2).detach().cpu().numpy()
    return cost.argmin(axis=-1)
def match_cluster_centers(image_embed1, image_embed2, eigvec1, eigvec2,
                          match_method='hungarian'):
    """
    Map every cluster of image 1 to a cluster of image 2.

    Tokens are labelled by the argmax of their eigenvector row, cluster
    centers are the mean token features, and centers are matched with
    ``match_method`` ('hungarian' for one-to-one, 'argmin' for nearest).

    Raises:
        ValueError: if ``match_method`` is not one of the two names above.
    """
    centers = []
    for embeds, eigvec in ((image_embed1, eigvec1), (image_embed2, eigvec2)):
        labels = eigvec.argmax(-1).cpu().numpy()
        centers.append(get_cluster_center_features(embeds, labels, eigvec.shape[-1]))
    for name, matcher in (('hungarian', hungarian_match_centers),
                          ('argmin', argmin_matching)):
        if match_method == name:
            return matcher(*centers)
    raise ValueError(f"Unknown match_method: {match_method}")
def match_centers_three_images(image_embeds, eigenvectors, match_method='hungarian'):
    """
    Match cluster centers along the chain A2 -> A1 -> B1.

    Args:
        image_embeds (torch.Tensor): embeddings indexed as [A2, A1, B1].
        eigenvectors (torch.Tensor): cluster eigenvectors in the same order.
        match_method (str): passed through to ``match_cluster_centers``.

    Returns:
        tuple: (A2_to_A1_mapping, A1_to_B1_mapping)
    """
    def _pair(src, dst):
        return match_cluster_centers(
            image_embeds[src], image_embeds[dst],
            eigenvectors[src], eigenvectors[dst],
            match_method=match_method
        )

    return _pair(0, 1), _pair(1, 2)
def match_centers_two_images(image_embed1, image_embed2, eigvec1, eigvec2,
                             match_method='hungarian'):
    """Convenience alias of ``match_cluster_centers`` for a single image pair."""
    mapping = match_cluster_centers(
        image_embed1, image_embed2, eigvec1, eigvec2, match_method=match_method
    )
    return mapping
# ===== Two-Step Clustering Functions =====
def kway_cluster_per_image_two_step(
    image_embeds,
    n_superclusters,
    n_subclusters_per_supercluster,
    supercluster_gamma=None,
    subcluster_gamma=None,
    degree=0.5
):
    """
    Perform 2-step hierarchical clustering on each image separately.

    First finds superclusters, then subdivides each supercluster into
    subclusters. Subcluster columns are laid out supercluster by supercluster:
    column ``s * n_subclusters_per_supercluster + k`` is subcluster ``k`` of
    supercluster ``s`` (this holds even when a supercluster is empty, whose
    columns then stay all-zero). A supercluster with no more tokens than
    ``n_subclusters_per_supercluster`` puts each token in its own subcluster;
    its remaining columns stay all-zero.

    Args:
        image_embeds: (batch, length, channels) - Image embeddings
        n_superclusters: Number of coarse superclusters to find
        n_subclusters_per_supercluster: Number of subclusters within each supercluster
        supercluster_gamma: Gamma parameter for supercluster NCut (None = auto)
        subcluster_gamma: Gamma parameter for subcluster NCut (None = auto)
        degree: Degree parameter for gamma estimation

    Returns:
        tuple: (supercluster_eigenvectors, subcluster_eigenvectors, subcluster_to_supercluster_mapping)
            - supercluster_eigenvectors: (batch, length, n_superclusters)
            - subcluster_eigenvectors: (batch, length, total_subclusters), one-hot (0/1)
            - subcluster_to_supercluster_mapping: (batch, total_subclusters) mapping each subcluster to its supercluster
    """
    batch_size, n_tokens = image_embeds.shape[0], image_embeds.shape[1]
    device = image_embeds.device
    total_subclusters = n_superclusters * n_subclusters_per_supercluster

    # Step 1: Compute superclusters for each image.
    supercluster_eigenvectors = torch.stack([
        _kway_cluster_single_image(
            image_embeds[i], n_superclusters, supercluster_gamma, degree
        )
        for i in range(batch_size)
    ])

    # Step 2: For each supercluster in each image, compute subclusters.
    subcluster_eigenvectors = []
    subcluster_to_supercluster_mapping = []
    for img_idx in range(batch_size):
        # One (supercluster_mask, local_subcluster_mask or None) entry per
        # global subcluster column. Exactly n_subclusters_per_supercluster
        # entries are added per supercluster, so columns stay aligned with
        # img_mapping (previously empty superclusters added none and shifted
        # every later column).
        img_subclusters = []
        img_mapping = []
        supercluster_labels = supercluster_eigenvectors[img_idx].argmax(-1)
        for supercluster_idx in range(n_superclusters):
            supercluster_mask = supercluster_labels == supercluster_idx
            img_mapping.extend([supercluster_idx] * n_subclusters_per_supercluster)
            n_members = int(supercluster_mask.sum())
            if n_members == 0 or n_subclusters_per_supercluster == 0:
                # Empty supercluster (or no subclusters requested): placeholder columns.
                img_subclusters.extend(
                    [(supercluster_mask, None)] * n_subclusters_per_supercluster
                )
                continue
            if n_members <= n_subclusters_per_supercluster:
                # Too few tokens - each token becomes its own subcluster. Labels
                # live on the mask's device so they can index its positions below.
                subcluster_labels = torch.arange(n_members, device=supercluster_mask.device)
            else:
                subcluster_eigvecs = _kway_cluster_single_image(
                    image_embeds[img_idx][supercluster_mask],
                    n_subclusters_per_supercluster,
                    subcluster_gamma,
                    degree
                )
                subcluster_labels = subcluster_eigvecs.argmax(-1)
            # Label values with no tokens give an all-False mask -> zero column.
            img_subclusters.extend(
                (supercluster_mask, subcluster_labels == sub_idx)
                for sub_idx in range(n_subclusters_per_supercluster)
            )

        # Convert to the one-hot (length, total_subclusters) representation.
        img_subcluster_eigvec = torch.zeros((n_tokens, total_subclusters), device=device)
        for column, (supercluster_mask, subcluster_mask) in enumerate(img_subclusters):
            if subcluster_mask is None:
                continue  # empty subcluster: column stays all zeros
            # Positions of the supercluster's tokens, then the subset in this subcluster.
            token_indices = torch.where(supercluster_mask)[0]
            member_indices = token_indices[subcluster_mask.to(token_indices.device)]
            img_subcluster_eigvec[member_indices.to(device), column] = 1.0
        subcluster_eigenvectors.append(img_subcluster_eigvec)
        subcluster_to_supercluster_mapping.append(torch.tensor(img_mapping, dtype=torch.long))

    subcluster_eigenvectors = torch.stack(subcluster_eigenvectors)
    subcluster_to_supercluster_mapping = torch.stack(subcluster_to_supercluster_mapping)
    return supercluster_eigenvectors, subcluster_eigenvectors, subcluster_to_supercluster_mapping
def match_centers_two_step(
    image_embed1,
    image_embed2,
    supercluster_eigvec1,
    supercluster_eigvec2,
    subcluster_eigvec1,
    subcluster_eigvec2,
    subcluster_to_supercluster_mapping1,
    subcluster_to_supercluster_mapping2,
    supercluster_match_method='hungarian',
    subcluster_match_method='hungarian'
):
    """
    Match clusters using 2-step hierarchical approach.

    First matches superclusters, then matches subclusters only within matched
    superclusters. Subcluster eigenvectors are assumed to be one-hot (0/1), as
    built by ``kway_cluster_per_image_two_step``: a token belongs to a group's
    subcluster only if its row has a positive entry in that group's columns.

    Args:
        image_embed1, image_embed2: Image embeddings (length, channels)
        supercluster_eigvec1, supercluster_eigvec2: Supercluster eigenvectors (length, n_superclusters)
        subcluster_eigvec1, subcluster_eigvec2: Subcluster eigenvectors (length, total_subclusters)
        subcluster_to_supercluster_mapping1, subcluster_to_supercluster_mapping2: (total_subclusters,) tensors
        supercluster_match_method: Matching method for superclusters
        subcluster_match_method: Matching method for subclusters

    Returns:
        np.ndarray: Mapping from image1 subclusters to image2 subclusters.
            Subclusters whose supercluster pair has nothing to match map to 0.

    Raises:
        ValueError: on an unknown ``subcluster_match_method`` (checked when the
            first non-empty supercluster pair is matched).
    """
    n_superclusters = supercluster_eigvec1.shape[-1]
    n_subclusters_total = subcluster_eigvec1.shape[-1]

    # Step 1: Match superclusters.
    supercluster_mapping = match_cluster_centers(
        image_embed1, image_embed2,
        supercluster_eigvec1, supercluster_eigvec2,
        match_method=supercluster_match_method
    )
    owner1 = subcluster_to_supercluster_mapping1.cpu().numpy()
    owner2 = subcluster_to_supercluster_mapping2.cpu().numpy()

    # Step 2: For each matched supercluster pair, match subclusters within them.
    subcluster_mapping = np.zeros(n_subclusters_total, dtype=np.int64)
    for supercluster1_idx in range(n_superclusters):
        supercluster2_idx = supercluster_mapping[supercluster1_idx]
        subclusters1_indices = np.where(owner1 == supercluster1_idx)[0]
        subclusters2_indices = np.where(owner2 == supercluster2_idx)[0]
        if len(subclusters1_indices) == 0 or len(subclusters2_indices) == 0:
            # Nothing to pair up; any image-1 subclusters here keep target 0.
            continue

        # Local label of each token within this group, or -1 when the token
        # belongs to none of the group's subclusters (all-zero row). A plain
        # argmax over the column subset would send those outside tokens to
        # local subcluster 0 and drag its center toward the rest of the image.
        best1, cluster_labels1 = subcluster_eigvec1[:, subclusters1_indices].max(-1)
        cluster_labels1 = cluster_labels1.masked_fill(best1 <= 0, -1).cpu()
        best2, cluster_labels2 = subcluster_eigvec2[:, subclusters2_indices].max(-1)
        cluster_labels2 = cluster_labels2.masked_fill(best2 <= 0, -1).cpu()

        center_features1 = get_cluster_center_features(
            image_embed1, cluster_labels1, len(subclusters1_indices)
        )
        center_features2 = get_cluster_center_features(
            image_embed2, cluster_labels2, len(subclusters2_indices)
        )

        # Match subclusters within this supercluster pair.
        if subcluster_match_method == 'hungarian':
            local_mapping = hungarian_match_centers(center_features1, center_features2)
        elif subcluster_match_method == 'argmin':
            local_mapping = argmin_matching(center_features1, center_features2)
        else:
            raise ValueError(f"Unknown subcluster_match_method: {subcluster_match_method}")

        # Convert local mapping to global subcluster indices.
        for local_idx, global_idx1 in enumerate(subclusters1_indices):
            subcluster_mapping[global_idx1] = subclusters2_indices[local_mapping[local_idx]]
    return subcluster_mapping
def kway_cluster_per_image_two_step_fgbg(
    image_embeds,
    n_foreground_subclusters,
    n_background_subclusters,
    supercluster_gamma=None,
    subcluster_gamma=None,
    degree=0.5
):
    """
    Perform 2-step hierarchical clustering with automatic foreground/background separation.

    First separates foreground (FG) and background (BG) using 2 clusters,
    identifying FG as the cluster whose column reaches the highest value.
    Then subdivides BG and FG separately. Subcluster columns are ordered BG
    first (``n_background_subclusters`` columns), then FG. A group with no
    more tokens than its subcluster count puts each token in its own
    subcluster; unused columns stay all-zero.

    Args:
        image_embeds: (batch, length, channels) - Image embeddings
        n_foreground_subclusters: Number of subclusters within foreground
        n_background_subclusters: Number of subclusters within background
        supercluster_gamma: Gamma parameter for FG/BG clustering (None = auto)
        subcluster_gamma: Gamma parameter for subcluster NCut (None = auto)
        degree: Degree parameter for gamma estimation

    Returns:
        tuple: (supercluster_eigenvectors, subcluster_eigenvectors, subcluster_to_supercluster_mapping, fg_indices)
            - supercluster_eigenvectors: (batch, length, 2) - column order as
              returned by the clusterer; column ``fg_indices[b]`` is FG
            - subcluster_eigenvectors: (batch, length, total_subclusters), one-hot (0/1)
            - subcluster_to_supercluster_mapping: (batch, total_subclusters) - 0=BG, 1=FG
            - fg_indices: (batch,) - which supercluster index is foreground for each image
    """
    batch_size, n_tokens = image_embeds.shape[0], image_embeds.shape[1]
    device = image_embeds.device

    # Step 1: Compute FG/BG separation for each image.
    supercluster_eigenvectors = []
    fg_indices = []
    for i in range(batch_size):
        eigenvector = _kway_cluster_single_image(
            image_embeds[i], n_clusters=2, gamma=supercluster_gamma, degree=degree
        )
        supercluster_eigenvectors.append(eigenvector)
        # Foreground: cluster with highest max eigenvector value.
        fg_indices.append(eigenvector.max(0).values.argmax().item())
    supercluster_eigenvectors = torch.stack(supercluster_eigenvectors)
    fg_indices = torch.tensor(fg_indices)

    # Step 2: For each image, compute subclusters within BG and FG.
    total_subclusters = n_background_subclusters + n_foreground_subclusters
    subcluster_eigenvectors = []
    subcluster_to_supercluster_mapping = []
    for img_idx in range(batch_size):
        # One (supercluster_mask, local_subcluster_mask or None) per column.
        img_subclusters = []
        img_mapping = []
        supercluster_labels = supercluster_eigenvectors[img_idx].argmax(-1)
        fg_idx = fg_indices[img_idx].item()
        bg_idx = 1 - fg_idx
        # BG first, then FG; supercluster_type 0 = BG, 1 = FG.
        for supercluster_type, supercluster_idx, n_subclusters in (
            (0, bg_idx, n_background_subclusters),
            (1, fg_idx, n_foreground_subclusters),
        ):
            supercluster_mask = supercluster_labels == supercluster_idx
            img_mapping.extend([supercluster_type] * n_subclusters)
            n_members = int(supercluster_mask.sum())
            if n_members == 0 or n_subclusters == 0:
                # Empty group (or no subclusters requested): placeholder columns.
                img_subclusters.extend([(supercluster_mask, None)] * n_subclusters)
                continue
            if n_members <= n_subclusters:
                # Too few tokens - each token becomes its own subcluster. Create
                # the labels on the mask's device: they index the mask's token
                # positions below, and the mask may be on CPU while the
                # embeddings are on CUDA.
                subcluster_labels = torch.arange(n_members, device=supercluster_mask.device)
            else:
                subcluster_eigvecs = _kway_cluster_single_image(
                    image_embeds[img_idx][supercluster_mask],
                    n_subclusters,
                    subcluster_gamma,
                    degree
                )
                subcluster_labels = subcluster_eigvecs.argmax(-1)
            # Label values with no tokens give an all-False mask -> zero column.
            img_subclusters.extend(
                (supercluster_mask, subcluster_labels == sub_idx)
                for sub_idx in range(n_subclusters)
            )

        # Convert to the one-hot (length, total_subclusters) representation.
        img_subcluster_eigvec = torch.zeros((n_tokens, total_subclusters), device=device)
        for column, (supercluster_mask, subcluster_mask) in enumerate(img_subclusters):
            if subcluster_mask is None:
                continue  # empty subcluster: column stays all zeros
            token_indices = torch.where(supercluster_mask)[0]
            # Align devices explicitly before indexing.
            member_indices = token_indices[subcluster_mask.to(token_indices.device)]
            img_subcluster_eigvec[member_indices.to(device), column] = 1.0
        subcluster_eigenvectors.append(img_subcluster_eigvec)
        subcluster_to_supercluster_mapping.append(torch.tensor(img_mapping, dtype=torch.long))

    subcluster_eigenvectors = torch.stack(subcluster_eigenvectors)
    subcluster_to_supercluster_mapping = torch.stack(subcluster_to_supercluster_mapping)
    return supercluster_eigenvectors, subcluster_eigenvectors, subcluster_to_supercluster_mapping, fg_indices
def match_centers_two_step_fgbg(
    image_embed1,
    image_embed2,
    subcluster_eigvec1,
    subcluster_eigvec2,
    subcluster_to_supercluster_mapping1,
    subcluster_to_supercluster_mapping2,
    n_background_subclusters,
    n_foreground_subclusters,
    background_match_method='hungarian',
    foreground_match_method='hungarian'
):
    """
    Match clusters using 2-step FG/BG hierarchical approach.

    FG and BG are automatically matched (no need for supercluster matching).
    Subclusters are matched within their respective FG or BG groups.
    Subcluster eigenvectors are assumed to be one-hot (0/1), as built by
    ``kway_cluster_per_image_two_step_fgbg``.

    Args:
        image_embed1, image_embed2: Image embeddings (length, channels)
        subcluster_eigvec1, subcluster_eigvec2: Subcluster eigenvectors (length, total_subclusters)
        subcluster_to_supercluster_mapping1, subcluster_to_supercluster_mapping2: (total_subclusters,) tensors - 0=BG, 1=FG
        n_background_subclusters: Number of background subclusters
        n_foreground_subclusters: Number of foreground subclusters
        background_match_method: Matching method for background subclusters
        foreground_match_method: Matching method for foreground subclusters

    Returns:
        np.ndarray: Mapping from image1 subclusters to image2 subclusters.
            Subclusters whose group has nothing to match map to 0.

    Raises:
        ValueError: on an unknown matching method for a non-empty group.
    """
    total_subclusters = n_background_subclusters + n_foreground_subclusters
    subcluster_mapping = np.zeros(total_subclusters, dtype=np.int64)
    owner1 = subcluster_to_supercluster_mapping1.cpu().numpy()
    owner2 = subcluster_to_supercluster_mapping2.cpu().numpy()

    # Process BG (supercluster_type=0) and FG (supercluster_type=1) separately.
    for supercluster_type in (0, 1):
        subclusters1_indices = np.where(owner1 == supercluster_type)[0]
        subclusters2_indices = np.where(owner2 == supercluster_type)[0]
        if len(subclusters1_indices) == 0 or len(subclusters2_indices) == 0:
            # Nothing to pair up; any image-1 subclusters here keep target 0.
            continue

        # Local label per token, or -1 for tokens outside this group (all-zero
        # rows in these columns). A plain argmax would put every token of the
        # other group into local subcluster 0 and corrupt its center.
        best1, cluster_labels1 = subcluster_eigvec1[:, subclusters1_indices].max(-1)
        cluster_labels1 = cluster_labels1.masked_fill(best1 <= 0, -1).cpu()
        best2, cluster_labels2 = subcluster_eigvec2[:, subclusters2_indices].max(-1)
        cluster_labels2 = cluster_labels2.masked_fill(best2 <= 0, -1).cpu()

        center_features1 = get_cluster_center_features(
            image_embed1, cluster_labels1, len(subclusters1_indices)
        )
        center_features2 = get_cluster_center_features(
            image_embed2, cluster_labels2, len(subclusters2_indices)
        )

        # Match subclusters within this FG/BG group.
        match_method = foreground_match_method if supercluster_type == 1 else background_match_method
        if match_method == 'hungarian':
            local_mapping = hungarian_match_centers(center_features1, center_features2)
        elif match_method == 'argmin':
            local_mapping = argmin_matching(center_features1, center_features2)
        else:
            raise ValueError(f"Unknown match_method: {match_method}")

        # Convert local mapping to global subcluster indices.
        for local_idx, global_idx1 in enumerate(subclusters1_indices):
            subcluster_mapping[global_idx1] = subclusters2_indices[local_mapping[local_idx]]
    return subcluster_mapping
# ===== Visualization Functions =====
def plot_cluster_masks(image, eigenvector, cluster_order, hw=16):
    """
    Render one copy of ``image`` per cluster with everything else dimmed.

    Args:
        image: input accepted by ``image_inverse_transform`` (assumed (c, h, w)).
        eigenvector: tensor (1 + hw*hw, n_clusters). The first token is dropped
            before the labels are reshaped to (hw, hw), presumably a CLS token.
        cluster_order: iterable of cluster indices; one output image each.
        hw: side length of the token grid.

    Returns:
        list of 128x128 PIL images in ``cluster_order`` order; pixels outside
        the cluster are scaled by 0.1.
    """
    size = (128, 128)
    base_img = image_inverse_transform(image).resize(size, resample=Image.Resampling.NEAREST)
    base_pixels = np.array(base_img).astype(np.float32) / 255
    token_labels = eigenvector.argmax(-1)

    panels = []
    for cluster_idx in cluster_order:
        grid = (token_labels == cluster_idx).cpu().numpy()[1:].reshape(hw, hw)
        mask_img = Image.fromarray((grid * 255).astype(np.uint8)).resize(
            size, resample=Image.Resampling.NEAREST
        )
        weights = np.array(mask_img).astype(np.float32) / 255
        weights = np.repeat(weights[..., np.newaxis], 3, axis=-1)
        weights[weights == 0] = 0.1  # dim, rather than hide, the other pixels
        panel = (base_pixels * weights * 255).astype(np.uint8)
        panels.append(Image.fromarray(panel))
    return panels
def create_image_grid_row(image, eigenvector, cluster_order, discrete_colors,
                          hw=16, n_cols=10):
    """
    Lay out one image's cluster panels as rows of ``n_cols`` tiles.

    Every row starts with the resized source image and its discrete-colour
    NCut map (``discrete_colors`` minus its first token, reshaped to
    (hw, hw, 3)), followed by ``n_cols`` cluster panels. The last row is
    padded with black 128x128 tiles.

    Returns:
        list of rows, each a list of PIL images.
    """
    panels = plot_cluster_masks(image, eigenvector, cluster_order, hw)
    size = (128, 128)
    source_tile = image_inverse_transform(image).resize(
        size, resample=Image.Resampling.NEAREST
    )
    ncut_tile = Image.fromarray(discrete_colors[1:].reshape(hw, hw, 3)).resize(
        size, resample=Image.Resampling.NEAREST
    )

    # Number of blank tiles needed to complete the last row (0 if already full).
    shortfall = -len(panels) % n_cols
    if shortfall:
        blank = Image.fromarray(np.zeros((128, 128, 3), dtype=np.uint8))
        panels.extend([blank] * shortfall)

    leading = [source_tile, ncut_tile]
    return [
        leading + panels[start:start + n_cols]
        for start in range(0, len(panels), n_cols)
    ]
def create_multi_image_grid(images, eigenvectors, cluster_orders, discrete_colors,
                            hw=16, n_cols=10):
    """
    Build grid rows for every image and interleave them row by row.

    The output order is: row 0 of image 0, row 0 of image 1, ..., then row 1
    of each image, and so on. The number of rows is taken from the first image.
    """
    per_image_rows = [
        create_image_grid_row(image, eigvec, order, colors, hw, n_cols)
        for image, eigvec, order, colors in zip(
            images, eigenvectors, cluster_orders, discrete_colors
        )
    ]
    n_rows = len(per_image_rows[0])
    return [rows[row_idx] for row_idx in range(n_rows) for rows in per_image_rows]
def get_correspondence_plot(images, eigenvectors, cluster_orders, discrete_colors,
                            hw=16, n_cols=10):
    """
    Assemble the full correspondence figure as a single image grid.

    At most ``min(n_cols, n_clusters)`` cluster panels are placed per row.
    The rows from ``create_multi_image_grid`` are flattened and handed to
    ``image_grid``.
    """
    panels_per_row = min(n_cols, eigenvectors.shape[-1])
    rows = create_multi_image_grid(
        images, eigenvectors, cluster_orders, discrete_colors, hw, panels_per_row
    )
    tiles = [tile for row in rows for tile in row]
    return image_grid(tiles, len(rows), len(rows[0]))