VibeSpace / vibespace_model.py
huzey's picture
Add CPU fallback for HF custom feature stages
22afff9
"""
Neural Compression Model for Feature Space Learning
This module implements a compression model that learns to compress and decompress
image features while preserving their geometric and semantic properties using
normalized cuts (NCut).
"""
import gc
from collections import defaultdict
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from einops import rearrange
from omegaconf import DictConfig
import gradio as gr
from ncut_pytorch.ncuts.ncut_nystrom import _plain_ncut
from ncut_pytorch.utils.math import rbf_affinity
def compute_ncut_eigenvectors(features: torch.Tensor, n_eig: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Solve NCut on an RBF affinity built from ``features``.

    The ``gamma`` handed to ``rbf_affinity`` is the sum over columns of the
    variance of ``features`` taken along dim 0.

    Args:
        features: Tensor forwarded to ``rbf_affinity``
            (presumably (num_points, channels) -- TODO confirm against ncut_pytorch).
        n_eig: Number of eigenvectors requested from ``_plain_ncut``.

    Returns:
        The ``(eigenvectors, eigenvalues)`` pair produced by ``_plain_ncut``.
    """
    total_variance = features.var(dim=0).sum().item()
    affinity = rbf_affinity(features, gamma=total_variance)
    eigvecs, eigvals = _plain_ncut(affinity, n_eig)
    return eigvecs, eigvals
# ===== Neural Network Components =====
class MultiLayerPerceptron(nn.Module):
    """Plain GELU MLP: input projection, ``num_layers`` hidden blocks, output projection.

    The stack is ``Linear(input_dim, hidden_dim) -> GELU``, followed by
    ``num_layers`` repetitions of ``Linear(hidden_dim, hidden_dim) -> GELU``,
    and a final ``Linear(hidden_dim, output_dim)`` with no activation.
    """

    def __init__(self, input_dim: int, output_dim: int, num_layers: int = 4, hidden_dim: int = 4096):
        super().__init__()
        # Input width of every GELU-activated layer: the first consumes the raw
        # input, the remaining ``num_layers`` consume hidden activations.
        activated_in_widths = [input_dim] + [hidden_dim] * max(num_layers, 0)
        modules = []
        for width in activated_in_widths:
            modules += [nn.Linear(width, hidden_dim), nn.GELU()]
        # Final projection has no activation.
        modules.append(nn.Linear(hidden_dim, output_dim))
        self.mlp = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x``."""
        return self.mlp(x)
class SpatialPoolingAvgPool(nn.Module):
    """
    Average-pools the patch tokens of a token sequence on their square 2D grid.

    The first token of the sequence is kept untouched (it is handled as a CLS
    token); the remaining tokens are laid out as a square grid, pooled with
    ``nn.AvgPool2d(downsample_factor)``, and flattened back into a sequence.
    """

    def __init__(self, downsample_factor: int = 2):
        super().__init__()
        self.downsample_factor = downsample_factor
        self.avg_pool = nn.AvgPool2d(downsample_factor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Pool a (batch, seq_len, channels) or (seq_len, channels) token sequence.

        Returns a tensor with the same number of dimensions as the input: the
        first token followed by the pooled patch tokens.

        Raises:
            ValueError: if ``x`` is not 2D/3D, if the sequence is shorter than
                two tokens, or if ``seq_len - 1`` is not a perfect square.
        """
        if x.dim() not in (2, 3):
            raise ValueError(f"Expected input shape (B, L, C) or (L, C), got {x.shape}")

        # Work on a batched view; remember whether the batch axis is ours to remove.
        unbatched = x.dim() == 2
        if unbatched:
            x = x.unsqueeze(0)

        num_tokens = x.shape[1]
        if num_tokens < 2:
            raise ValueError("Sequence length must be at least 2 (1 CLS token + 1 patch)")

        # Patch tokens must tile a square grid.
        num_patches = num_tokens - 1
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(f"seq_len-1 must be perfect square. Got {num_patches}")

        cls_part = x[:, :1, :]      # (B, 1, C)
        patch_part = x[:, 1:, :]    # (B, H*W, C)

        # Sequence -> channel-first grid, pool, grid -> sequence.
        grid = rearrange(patch_part, 'b (h w) c -> b c h w', h=side, w=side)
        pooled_tokens = rearrange(self.avg_pool(grid), 'b c h w -> b (h w) c')

        merged = torch.cat([cls_part, pooled_tokens], dim=1)
        return merged.squeeze(0) if unbatched else merged
class MLPWithSpatialPooling(nn.Module):
    """Spatially pool a token sequence, then run a GELU MLP on every pooled token.

    ``pooling`` is a ``SpatialPoolingAvgPool``; ``network`` is
    ``Linear(input_dim, hidden_dim) -> GELU``, ``num_layers`` repetitions of
    ``Linear(hidden_dim, hidden_dim) -> GELU``, and a final
    ``Linear(hidden_dim, output_dim)``.
    """

    def __init__(self, input_dim: int, output_dim: int, num_layers: int = 4,
                 hidden_dim: int = 4096, downsample_factor: int = 2):
        super().__init__()
        self.pooling = SpatialPoolingAvgPool(downsample_factor)

        # Input projection plus ``num_layers`` hidden blocks, each followed by GELU.
        blocks = [nn.Linear(input_dim, hidden_dim), nn.GELU()]
        hidden_remaining = num_layers
        while hidden_remaining > 0:
            blocks.append(nn.Linear(hidden_dim, hidden_dim))
            blocks.append(nn.GELU())
            hidden_remaining -= 1
        # Output projection, no activation.
        blocks.append(nn.Linear(hidden_dim, output_dim))
        self.network = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` spatially and map the pooled tokens through ``network``."""
        return self.network(self.pooling(x))
# ===== Main Compression Model =====
class VibeSpaceModel(pl.LightningModule):
    """
    Neural compression model for learning compressed feature representations.

    The encoder maps input features to a lower-dimensional "vibe space"; the
    decoder spatially pools the compressed tokens and maps them to the output
    feature space. Training combines a reconstruction loss with NCut-based
    "flag" losses that tie the similarity structure of the compressed space to
    NCut eigenvectors of the input / reconstructed features.
    """

    def __init__(self, config: DictConfig, enable_gradio_progress: bool = False, downsample_factor: int = 2):
        """
        Args:
            config: Model/training configuration. ``in_dim``, ``vibe_dim``,
                ``out_dim``, ``n_layer`` and ``latent_dim`` are read here; the
                loss weights and NCut settings are read during training.
            enable_gradio_progress: When True, a ``gr.Progress`` tracker is
                created and updated from ``training_step``.
            downsample_factor: Pooling factor used by the decoder's spatial pooling.
        """
        super().__init__()
        self.config = config
        self.downsample_factor = downsample_factor
        self.encoder = MultiLayerPerceptron(
            config.in_dim, config.vibe_dim, config.n_layer, config.latent_dim
        )
        self.decoder = MLPWithSpatialPooling(
            config.vibe_dim, config.out_dim, config.n_layer,
            config.latent_dim, self.downsample_factor
        )
        # Per-loss history of scalar values, appended on every training step.
        self.loss_history = defaultdict(list)
        self.enable_gradio_progress = enable_gradio_progress
        if enable_gradio_progress:
            self.progress_tracker = gr.Progress()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into the vibe space and decode it back."""
        compressed = self.encoder(x)
        reconstructed = self.decoder(compressed)
        return reconstructed

    def training_step(self, batch, batch_idx):
        """
        Run one optimisation step.

        ``batch`` is the 4-tuple ``(positive_features, negative_features,
        target_features, negative_mask)``; ``negative_mask`` marks which rows of
        ``negative_features`` are real negatives rather than placeholders.
        Returns the weighted total loss.
        """
        # Update the Gradio progress bar every 10 steps once a recon loss exists.
        if (self.enable_gradio_progress and
                self.trainer.global_step % 10 == 0 and
                self.trainer.global_step > 0 and
                self.loss_history['recon']):
            progress = self.trainer.global_step / self.config.steps
            recent_loss = self.loss_history['recon'][-1]
            self.progress_tracker(progress, desc=f"Training Vibe Space, loss = {recent_loss:.4f}")

        positive_features, negative_features, target_features, negative_mask = batch
        negative_mask = negative_mask.bool()
        if not bool(negative_mask.any().item()):
            batch_negative_features = None
        elif bool(negative_mask.all().item()):
            batch_negative_features = negative_features
        else:
            # NOTE(review): a partial mask yields fewer negative rows than positive
            # rows, which the negative-filtered similarity cannot combine.
            batch_negative_features = negative_features[negative_mask]

        compressed_features = self.encoder(positive_features)
        reconstructed_features = self.decoder(compressed_features)
        total_loss = self._compute_total_loss(
            positive_features,
            batch_negative_features,
            target_features,
            compressed_features,
            reconstructed_features,
        )
        self.log("loss/total", total_loss, prog_bar=True)
        return total_loss

    def _compute_ncut_eigenvectors(self, features: torch.Tensor) -> torch.Tensor:
        """
        Compute NCut eigenvectors for ``features``.

        Inputs with three or more dims have their first two dims flattened; a
        1D input is treated as a single sample. If the result is empty or not
        2D, a ``(1, n_eig)`` zero tensor is returned instead.
        """
        flattened_features = features
        if flattened_features.dim() >= 3:
            flattened_features = flattened_features.flatten(0, 1)
        elif flattened_features.dim() == 1:
            # rbf_affinity expects at least 2D; treat a single vector as one sample.
            flattened_features = flattened_features.unsqueeze(0)

        if flattened_features.numel() > 0 and flattened_features.dim() == 2:
            eigenvectors, _ = compute_ncut_eigenvectors(flattened_features, self.config.n_eig)
            return eigenvectors
        # Nothing usable to decompose.
        return torch.zeros((1, self.config.n_eig), device=features.device)

    @staticmethod
    def _eigenvector_scales(start_n_eig: int, max_available: int, step_mult: int) -> List[int]:
        """
        List the eigenvector counts used for multi-scale similarity.

        Starts at ``start_n_eig`` clamped into ``[1, max_available]`` and
        multiplies by ``step_mult`` while the count stays within
        ``max_available``. Returns an empty list when ``max_available`` is 0,
        and a single scale when ``step_mult <= 1`` (both cases would otherwise
        never terminate).
        """
        scales: List[int] = []
        current = max(min(start_n_eig, max_available), 1)
        while current <= max_available:
            scales.append(current)
            if step_mult <= 1:
                break
            current *= step_mult
        return scales

    def _compute_multiscale_similarity(self, eigenvectors: torch.Tensor,
                                       start_n_eig: int = 4, step_mult: int = 2) -> torch.Tensor:
        """
        Average the row-normalised Gram matrices of the leading eigenvectors over several scales.

        ``eigenvectors`` is (num_points, n_eig). When ``config.single_scale_flag``
        is set, only the full set of eigenvectors is used. Returns a
        (num_points, num_points) matrix; all zeros if there are no eigenvectors.
        """
        num_points, max_available = eigenvectors.shape[0], eigenvectors.shape[1]
        first_scale = max_available if self.config.single_scale_flag else start_n_eig
        scales = self._eigenvector_scales(first_scale, max_available, step_mult)
        if not scales:
            return eigenvectors.new_zeros((num_points, num_points))

        total_similarity = None
        for n_eig in scales:
            eigvec_normalized = F.normalize(eigenvectors[:, :n_eig], dim=-1)
            scale_similarity = eigvec_normalized @ eigvec_normalized.T
            total_similarity = (
                scale_similarity if total_similarity is None
                else total_similarity + scale_similarity
            )
        return total_similarity / len(scales)

    def _compute_negative_filtered_similarity(
        self,
        eigenvectors_pos: torch.Tensor,
        eigenvectors_neg: torch.Tensor,
        beta: float,
    ) -> torch.Tensor:
        """
        Multi-scale similarity of positive eigenvectors after removing their negative component.

        At each scale the leading positive eigenvectors have ``beta`` times their
        projection onto the column-normalised negative eigenvectors subtracted,
        are row-normalised, and contribute their Gram matrix to the average.
        The matrix products require both inputs to have the same number of rows.
        Falls back to the unfiltered multi-scale similarity of the positives
        when either input has no eigenvectors.
        """
        max_available = min(eigenvectors_pos.shape[1], eigenvectors_neg.shape[1])
        if self.config.single_scale_flag:
            first_scale = max_available
        else:
            first_scale = self.config.get('start_n_eig', 4)
        scales = self._eigenvector_scales(first_scale, max_available, self.config.get('step_mult', 2))
        if not scales:
            return self._compute_multiscale_similarity(eigenvectors_pos)

        total_similarity = None
        for n_eig in scales:
            positive = eigenvectors_pos[:, :n_eig]
            negative_basis = F.normalize(eigenvectors_neg[:, :n_eig], dim=0)
            projection = torch.matmul(negative_basis.T, positive)
            filtered = positive - beta * torch.matmul(negative_basis, projection)
            filtered = F.normalize(filtered, dim=-1)
            scale_similarity = filtered @ filtered.T
            total_similarity = (
                scale_similarity if total_similarity is None
                else total_similarity + scale_similarity
            )
        return total_similarity / len(scales)

    def _compute_flag_decoder_loss(
        self,
        compressed_features: torch.Tensor,
        reconstructed_features: torch.Tensor,
        negative_input_features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Match the compressed-space Gram matrix to NCut similarity of the decoded features.

        compressed_features is (batch, length, channels)
        reconstructed_features is (batch, length, channels)

        Without negatives, ``config.n_negative_sample`` extra points are drawn
        uniformly from a randomly widened bounding box of the pooled compressed
        features and decoded alongside the real ones. With negatives and
        ``config.do_decoder_negative_flag`` set, the target similarity is
        filtered against the NCut eigenvectors of the decoded negatives.
        """
        pooled_compressed = self.decoder.pooling(compressed_features).flatten(0, 1)
        reconstructed_features = reconstructed_features.flatten(0, 1)
        has_negative = (
            negative_input_features is not None and negative_input_features.numel() > 0
        )

        # Bounding box of the compressed points, widened on each side by up to 25%
        # of the original per-dimension range. The range is computed once so both
        # sides are widened relative to the same extent.
        dim_mins = pooled_compressed.min(0).values
        dim_maxs = pooled_compressed.max(0).values
        dim_range = dim_maxs - dim_mins
        lower = dim_mins - 0.25 * dim_range * torch.rand_like(dim_range)
        upper = dim_maxs + 0.25 * dim_range * torch.rand_like(dim_range)

        # Extra sample points are only drawn when no negatives are available.
        num_samples = 0 if has_negative else self.config.n_negative_sample
        sample_points = torch.rand(num_samples, pooled_compressed.shape[1], device=pooled_compressed.device)
        sample_points = sample_points * (upper - lower) + lower

        # The points are already in pooled space, so only the MLP part of the decoder applies.
        sample_reconstructed = self.decoder.network(sample_points)
        all_compressed = torch.cat([pooled_compressed, sample_points], dim=0)
        all_reconstructed = torch.cat([reconstructed_features, sample_reconstructed], dim=0)

        similarity = all_compressed @ all_compressed.T
        eigenvectors_pos, _ = compute_ncut_eigenvectors(all_reconstructed, self.config.n_eig)

        if has_negative and self.config.get('do_decoder_negative_flag', False):
            negative_compressed = self.encoder(negative_input_features)
            negative_reconstructed = self.decoder(negative_compressed).flatten(0, 1)
            neg_eigenvectors, _ = compute_ncut_eigenvectors(negative_reconstructed, self.config.n_eig)
            beta = self.config.get('decoder_negative_beta', self.config.get('negative_beta', 1.0))
            eig_similarity = self._compute_negative_filtered_similarity(
                eigenvectors_pos, neg_eigenvectors, beta
            )
        else:
            eig_similarity = self._compute_multiscale_similarity(eigenvectors_pos)

        return F.smooth_l1_loss(eig_similarity, similarity)

    def _compute_flag_encoder_loss(self, input_features: torch.Tensor, compressed_features: torch.Tensor) -> torch.Tensor:
        """
        Match the compressed-space Gram matrix to NCut similarity of the inputs (no negatives).

        input_features is (batch, length, channels)
        compressed_features is (batch, length, channels)

        At most ``config.n_sample_eigsolve`` tokens are sampled, without
        replacement, from the flattened (batch * length) tokens; the same tokens
        are used on both sides of the loss.
        """
        flattened_inputs = input_features.flatten(0, 1)
        flattened_compressed = compressed_features.flatten(0, 1)

        # Sample over every flattened token; drawing indices from the batch size
        # would only ever reach the first few tokens of the first sample.
        num_tokens = flattened_inputs.shape[0]
        sample_indices = torch.randperm(num_tokens, device=flattened_inputs.device)
        sample_indices = sample_indices[:self.config.n_sample_eigsolve]

        gt_eigenvectors = self._compute_ncut_eigenvectors(flattened_inputs[sample_indices])
        gt_similarity = self._compute_multiscale_similarity(gt_eigenvectors)

        sampled_compressed = flattened_compressed[sample_indices]
        pred_similarity = sampled_compressed @ sampled_compressed.T
        return F.smooth_l1_loss(gt_similarity, pred_similarity)

    def _compute_total_loss(
        self,
        positive_features: torch.Tensor,
        negative_features: Optional[torch.Tensor],
        target_features: torch.Tensor,
        compressed_features: torch.Tensor,
        reconstructed_features: torch.Tensor,
    ) -> torch.Tensor:
        """
        Combine the enabled losses, each weighted by its config value.

        positive_features is (batch, length, channels)
        target_features is (batch, length, channels)
        compressed_features is (batch, length, channels)
        reconstructed_features is (batch, length, channels)

        Each enabled loss is logged and appended to ``loss_history``. The flag
        decoder loss is skipped until the trainer has reached
        ``config.flag_decoder_warmup_steps`` global steps (default 500).
        """
        total_loss = positive_features.new_tensor(0.0)
        has_negative_features = (
            negative_features is not None and negative_features.numel() > 0
        )

        # Flag encoder loss - guide the structure from encoder to compressed features
        if self.config.flag_encoder_loss > 0:
            if has_negative_features:
                gt_eigenvectors_pos = self._compute_ncut_eigenvectors(positive_features)
                gt_eigenvectors_neg = self._compute_ncut_eigenvectors(negative_features)
                gt_similarity = self._compute_negative_filtered_similarity(
                    gt_eigenvectors_pos,
                    gt_eigenvectors_neg,
                    self.config.get('negative_beta', 1.0),
                )
                flattened_compressed = compressed_features.flatten(0, 1)
                pred_similarity = flattened_compressed @ flattened_compressed.T
                flag_encoder_loss = F.smooth_l1_loss(gt_similarity, pred_similarity)
            else:
                flag_encoder_loss = self._compute_flag_encoder_loss(positive_features, compressed_features)
            self.log("loss/flag_encoder", flag_encoder_loss, prog_bar=True)
            total_loss = total_loss + flag_encoder_loss * self.config.flag_encoder_loss
            self.loss_history['flag_encoder'].append(flag_encoder_loss.item())

        # Flag decoder loss - guide the structure from compressed to decoded features
        if self.config.flag_decoder_loss > 0:
            warmup_steps = self.config.get('flag_decoder_warmup_steps', 500)
            if self.trainer.global_step >= warmup_steps:  # warmup period
                flag_decoder_loss = self._compute_flag_decoder_loss(
                    compressed_features,
                    reconstructed_features,
                    negative_features,
                )
                self.log("loss/flag_decoder", flag_decoder_loss, prog_bar=True)
                total_loss = total_loss + flag_decoder_loss * self.config.flag_decoder_loss
                self.loss_history['flag_decoder'].append(flag_decoder_loss.item())

        # Reconstruction loss
        if self.config.recon_loss > 0:
            recon_loss = F.smooth_l1_loss(target_features, reconstructed_features)
            self.log("loss/recon", recon_loss, prog_bar=True)
            total_loss = total_loss + recon_loss * self.config.recon_loss
            self.loss_history['recon'].append(recon_loss.item())

        return total_loss

    def configure_optimizers(self):
        """Optimise all parameters with NAdam at ``config.lr``."""
        return torch.optim.NAdam(self.parameters(), lr=self.config.lr)
# ===== Dataset and Training Utilities =====
class FeatureDataset(torch.utils.data.Dataset):
    """Dataset of (positive, negative, target, has_negative) feature tuples.

    Positives and targets are indexed in lockstep. When negative features are
    supplied, each item is paired with one negative drawn uniformly at random;
    otherwise the negative slot holds zeros shaped like the positive and the
    flag is False.
    """

    def __init__(
        self,
        positive_features: torch.Tensor,
        target_features: torch.Tensor,
        negative_features: Optional[torch.Tensor] = None,
    ):
        self.positive_features = positive_features
        self.target_features = target_features
        # An empty negative tensor is treated the same as "no negatives".
        usable = negative_features is not None and negative_features.numel() > 0
        self.negative_features = negative_features if usable else None

    def __len__(self) -> int:
        return len(self.positive_features)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        positive = self.positive_features[idx]
        target = self.target_features[idx]
        pool = self.negative_features
        if pool is not None:
            # Random negative, independent of ``idx``.
            pick = torch.randint(0, pool.shape[0], (1,)).item()
            return positive, pool[pick], target, torch.tensor(True, dtype=torch.bool)
        # Placeholder negative so every item has the same structure.
        return positive, torch.zeros_like(positive), target, torch.tensor(False, dtype=torch.bool)
def clear_gpu_memory():
    """Run the garbage collector, then release cached CUDA memory if CUDA is available.

    The collector runs first so that tensors kept alive only by reference
    cycles are freed before the CUDA caching allocator is asked to release its
    unused blocks; in the opposite order their memory would remain cached.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def train_vibe_space(model: VibeSpaceModel,
                     config: DictConfig,
                     input_features: torch.Tensor,
                     target_features: torch.Tensor,
                     negative_features: Optional[torch.Tensor] = None,
                     devices: Optional[List[int]] = None) -> pl.Trainer:
    """Fit ``model`` on the given features and return the Lightning trainer.

    Args:
        model: Model to train.
        config: Training configuration; ``config.steps`` is the maximum number
            of optimisation steps.
        input_features: Positive features fed to the encoder; their device
            type selects the accelerator.
        target_features: Reconstruction targets, indexed in lockstep with
            ``input_features``.
        negative_features: Optional negative features.
        devices: GPU indices used when training on CUDA. Defaults to ``[0]``.

    Returns:
        The ``pl.Trainer`` after ``fit`` has completed.
    """
    # Avoid a shared mutable default argument.
    if devices is None:
        devices = [0]

    clear_gpu_memory()
    dataset = FeatureDataset(input_features, target_features, negative_features)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

    # Train where the features live, falling back to CPU when that backend is unavailable.
    feature_device = input_features.device.type
    if feature_device == "cuda" and torch.cuda.is_available():
        accelerator = "gpu"
        trainer_devices = devices
    elif feature_device == "mps" and torch.backends.mps.is_available():
        accelerator = "mps"
        trainer_devices = 1
    else:
        accelerator = "cpu"
        trainer_devices = 1

    trainer = pl.Trainer(
        max_steps=config.steps,
        gradient_clip_val=1.0,
        accelerator=accelerator,
        devices=trainer_devices,
        enable_checkpointing=False,
        enable_progress_bar=True,
        logger=False  # Disable default logger
    )
    trainer.fit(model, dataloader)
    return trainer