Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from typing import List, Optional, Tuple | |
| import torch.nn.functional as F | |
| from omegaconf import OmegaConf | |
| from ipadapter_model import generate_images_from_clip_embeddings | |
| from ipadapter_model import load_ipadapter | |
| from intrinsic_dim import estimate_intrinsic_dimension | |
| from vibespace_model import VibeSpaceModel, train_vibe_space, clear_gpu_memory | |
| from dino_correspondence import kway_cluster_per_image, match_centers_two_images, get_cluster_center_features | |
| from extract_features import extract_dino_features, extract_clip_features, dino_image_transform, clip_image_transform | |
| import logging | |
| import gradio as gr | |
# Base configuration; every config loaded through load_config is layered on top of it.
DEFAULT_CONFIG_PATH = "./config.yaml"


def load_config(config_path: str):
    """Load a config file layered on top of the default config.

    The config at ``DEFAULT_CONFIG_PATH`` is loaded first, and the config at
    ``config_path`` is merged over it recursively. The override file therefore
    only needs to list the keys it changes, including keys inside nested
    sections.

    Args:
        config_path: Path to a YAML config whose values override the defaults.

    Returns:
        A new OmegaConf config holding the merged values.
    """
    cfg_base = OmegaConf.load(DEFAULT_CONFIG_PATH)
    cfg = OmegaConf.load(config_path)
    # OmegaConf.merge merges nested sections key by key. A plain
    # ``cfg_base.update(cfg)`` is shallow: it replaces whole nested sections
    # and silently drops default keys the override file does not repeat.
    return OmegaConf.merge(cfg_base, cfg)
def run_vibe_blend_safe(image1, image2, extra_images, negative_images, config_path,
                        interpolation_weights: List[float], n_clusters: int = 25,
                        max_attempts: int = 5):
    """Train a Vibe Space model and generate blend images, retrying failures.

    Training and generation are each attempted up to ``max_attempts`` times,
    and GPU memory is cleared after every failed attempt. This tolerates
    transient failures, such as running out of memory or an all-black
    generated image. Unlike an unbounded retry loop, it cannot hang forever
    on an error that repeats on every attempt.

    Args:
        image1: First image to blend from.
        image2: Second image to blend towards.
        extra_images: Additional positive training images, or None for none.
        negative_images: Negative training images (may be None).
        config_path: Config file passed to training.
        interpolation_weights: One interpolation weight per generated image.
        n_clusters: Number of clusters for correspondence matching.
        max_attempts: Maximum attempts for each stage (training, then
            generation). Must be at least 1.

    Returns:
        The list of generated images.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
        Exception: The last error of a stage whose attempts were all used up.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    # An empty optional gallery may arrive as None; unpacking None would raise
    # TypeError on every attempt.
    positive_images = [image1, image2, *(extra_images or [])]

    model, _trainer = _call_with_retries(
        lambda: run_vibe_space_training(
            positive_images=positive_images,
            negative_images=negative_images,
            config_path=config_path,
        ),
        "Error training model",
        max_attempts,
    )
    return _call_with_retries(
        lambda: generate_blend_images(
            image1,
            image2,
            model,
            interpolation_weights,
            n_clusters=n_clusters,
        ),
        "Error generating images",
        max_attempts,
    )


def _call_with_retries(fn, error_prefix: str, max_attempts: int):
    """Return ``fn()``, retrying up to ``max_attempts`` times and re-raising the last error."""
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception as e:  # retry boundary: any failure is logged and retried
            logging.error(f"{error_prefix}: {e}")
            clear_gpu_memory()
            if attempt == max_attempts:
                raise
def run_vibe_blend_not_safe(image1, image2, extra_images, negative_images, config_path,
                            interpolation_weights: List[float], n_clusters: int = 20):
    """Train a Vibe Space model and generate blend images, without retries.

    Any exception from training or generation propagates to the caller.

    Args:
        image1: First image to blend from.
        image2: Second image to blend towards.
        extra_images: Additional positive training images, or None for none.
        negative_images: Negative training images (may be None).
        config_path: Config file passed to training.
        interpolation_weights: One interpolation weight per generated image.
        n_clusters: Number of clusters for correspondence matching.

    Returns:
        The list of generated images.
    """
    # An empty optional gallery may arrive as None; unpacking None would raise TypeError.
    model, _ = run_vibe_space_training(
        positive_images=[image1, image2, *(extra_images or [])],
        negative_images=negative_images,
        config_path=config_path,
    )
    return generate_blend_images(
        image1,
        image2,
        model,
        interpolation_weights,
        n_clusters=n_clusters,
    )
def run_vibe_space_training(positive_images: List[Image.Image],
                            negative_images: Optional[List[Image.Image]],
                            config_path: str = DEFAULT_CONFIG_PATH) -> Tuple[VibeSpaceModel, object]:
    """Train a Vibe Space compression model from input images.

    The steps are:

    1. Extract DINO and CLIP features from the positive images, and DINO
       features from any negative images.
    2. Estimate the intrinsic dimensionality of the positive DINO features
       and use it as ``config.vibe_dim``.
    3. Train a ``VibeSpaceModel`` on those features.

    Args:
        positive_images: Images that define the vibe space. ``None`` entries
            are ignored, and at least one real image is required. When more
            than two remain, the configured number of training steps is doubled.
        negative_images: Images passed to training as negatives. May be
            ``None`` or empty; ``None`` entries are ignored.
        config_path: Config file layered over the default config.

    Returns:
        ``(model, trainer)``: the trained model and the value returned by
        ``train_vibe_space``.

    Raises:
        ValueError: If no non-None positive image is provided.
    """
    # Load and configure training parameters (a fresh config per call, so the
    # in-place edits below never accumulate across runs).
    config = load_config(config_path)

    positive_images = [img for img in positive_images if img is not None]
    negative_images = [img for img in negative_images or [] if img is not None]
    if len(positive_images) == 0:
        raise ValueError("No valid positive images provided for Vibe Space training")
    has_negative_images = len(negative_images) > 0

    # Transform images for feature extraction
    dino_input_images = torch.stack([dino_image_transform(image) for image in positive_images])
    clip_input_images = torch.stack([clip_image_transform(image) for image in positive_images])
    if has_negative_images:
        negative_dino_input_images = torch.stack([dino_image_transform(image) for image in negative_images])
    else:
        negative_dino_input_images = None

    # Extract features using the pre-trained backbones
    dino_image_embeds = extract_dino_features(dino_input_images)
    clip_image_embeds = extract_clip_features(clip_input_images)
    if has_negative_images:
        negative_dino_embeds = extract_dino_features(negative_dino_input_images)
    else:
        negative_dino_embeds = None

    # Size the compressed space from the DINO features. They are flattened to
    # a 2-D (rows, feature_dim) matrix first.
    flattened_features = dino_image_embeds.flatten(end_dim=-2)
    estimated_dim = estimate_intrinsic_dimension(flattened_features)
    # An estimate below 1 would truncate to a zero-width latent space.
    config.vibe_dim = max(1, int(estimated_dim))

    if len(positive_images) > 2:
        # Extra images provide more to fit, so train for twice as many steps.
        config.steps = config.steps * 2

    # Create and train the model
    model = VibeSpaceModel(config, enable_gradio_progress=True)
    trainer = train_vibe_space(
        model,
        config,
        dino_image_embeds,
        clip_image_embeds,
        negative_dino_embeds,
    )
    return model, trainer
def _compute_direction_from_two_images(image_embeds: torch.Tensor,
                                       eigenvectors: torch.Tensor | List[torch.Tensor],
                                       a_to_b_mapping: np.ndarray,
                                       use_unit_norm: bool = False) -> torch.Tensor:
    """Build a per-token field of directions from image A's clusters to image B's.

    Tokens are hard-assigned to clusters by taking the argmax over the last
    axis of each image's cluster eigenvectors.

    Args:
        image_embeds: Embeddings of both images; index 0 is image A and
            index 1 is image B.
        eigenvectors: Per-image soft cluster assignments (index 0 for A,
            index 1 for B). The last axis indexes clusters.
        a_to_b_mapping: Entry ``i`` is the B cluster matched to A cluster ``i``.
        use_unit_norm: If True, normalize each direction to unit length.

    Returns:
        A tensor shaped like ``image_embeds[0]``. Each A token holds the
        direction ``center_B[a_to_b_mapping[c]] - center_A[c]`` of its
        cluster ``c``. Clusters with no tokens contribute nothing, so the
        field stays zero wherever no cluster wrote to it.
    """
    labels_a = eigenvectors[0].argmax(-1).cpu()
    centers_a = get_cluster_center_features(image_embeds[0], labels_a, eigenvectors[0].shape[-1])
    labels_b = eigenvectors[1].argmax(-1).cpu()
    centers_b = get_cluster_center_features(image_embeds[1], labels_b, eigenvectors[1].shape[-1])

    # One direction per A cluster, pointing at the centre of its matched B cluster.
    deltas = [centers_b[b_idx] - centers_a[a_idx] for a_idx, b_idx in enumerate(a_to_b_mapping)]
    if use_unit_norm:
        deltas = [F.normalize(delta, dim=-1) for delta in deltas]
    per_cluster_direction = torch.stack(deltas)

    # Paint every A token with the direction of the cluster it belongs to.
    token_labels = eigenvectors[0].argmax(-1).cpu()
    field = torch.zeros_like(image_embeds[0])
    for cluster_id in range(eigenvectors[0].shape[-1]):
        members = token_labels == cluster_id
        if members.any():
            field[members] = per_cluster_direction[cluster_id]
    return field
def generate_blend_images(image1: Image.Image,
                          image2: Image.Image,
                          model: VibeSpaceModel,
                          interpolation_weights: List[float],
                          n_clusters: int = 20,
                          seed: Optional[int] = None,
                          ) -> List[Image.Image]:
    """Interpolate between two images using the trained compression model.

    The procedure is:

    1. Both images are encoded into the compressed space.
    2. Their DINO tokens are clustered, and A's clusters are matched to B's
       with the Hungarian method.
    3. For each weight ``w``, image A's compressed embedding is moved by
       ``w`` times the per-token direction field.
    4. The result is decoded and rendered with the IP-Adapter.

    Args:
        image1: First input image (interpolation starts from here).
        image2: Second input image (directions point towards it).
        model: Trained compression model (``encoder``/``decoder``).
        interpolation_weights: One weight per generated image.
        n_clusters: Number of clusters for correspondence matching.
        seed: Random seed passed to image generation.

    Returns:
        The generated images, in the order of ``interpolation_weights``.

    Raises:
        ValueError: If a generated image is entirely black (all pixel values 0).
    """
    clear_gpu_memory()

    # Prepare images and extract features
    images = torch.stack([dino_image_transform(img) for img in [image1, image2]])
    dino_image_embeds = extract_dino_features(images)
    # Inference only: no autograd graph is needed for the model forward pass.
    with torch.no_grad():
        compressed_image_embeds = model.encoder(dino_image_embeds)

    cluster_eigenvectors = kway_cluster_per_image(dino_image_embeds, n_clusters=n_clusters, gamma=None)
    a_to_b_mapping = match_centers_two_images(
        dino_image_embeds[0], dino_image_embeds[1],
        cluster_eigenvectors[0], cluster_eigenvectors[1],
        match_method='hungarian'
    )
    direction_field = _compute_direction_from_two_images(
        compressed_image_embeds, cluster_eigenvectors, a_to_b_mapping, use_unit_norm=False
    )

    # Generate interpolated images
    ip_model = load_ipadapter()
    try:
        progress_tracker = gr.Progress()
        generated_images = []
        n_weights = len(interpolation_weights)
        for i, weight in enumerate(interpolation_weights):
            progress_tracker(i / n_weights, desc=f"Generating images, α = {weight:.2f}")
            interpolated_embedding = compressed_image_embeds[0] + direction_field * weight
            with torch.no_grad():
                decompressed_embedding = model.decoder(interpolated_embedding)
            batch_images = generate_images_from_clip_embeddings(
                ip_model, decompressed_embedding, num_samples=1, seed=seed
            )
            if np.all(np.array(batch_images[0]) == 0):
                raise ValueError("Generated image is all black")
            generated_images.extend(batch_images)
    finally:
        # Release the IP-Adapter even when generation fails. Otherwise the
        # exception's traceback frame keeps it alive while a retrying caller
        # tries to free GPU memory.
        del ip_model
        clear_gpu_memory()
    return generated_images