"""
HuggingFace Spaces entry point for diffviews.

This file is the main entry point for HF Spaces deployment.
It downloads required data and checkpoints on startup, then launches the Gradio app.

Requirements:
    Python 3.10+
    Gradio 6.0+

Environment variables:
    DIFFVIEWS_DATA_DIR: Override data directory (default: data)
    DIFFVIEWS_CHECKPOINT: Which checkpoints to download: "all", "none", or a
        comma-separated list of dmd2, edm (default: all)
    DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
    DIFFVIEWS_BRANCH: Git branch of the diffviews repository to clone
        (default: diffviews-gradio6-HFz-CFr2)
    DIFFVIEWS_PCA_COMPONENTS: PCA pre-reduction size used when regenerating UMAP;
        "0", "none", "off" or an empty value disables it (default: 50)
"""
| |
|
| | import os |
| | import subprocess |
| | import sys |
| | from pathlib import Path |
| |
|
| | |
# Where the diffviews package is fetched from. The branch can be overridden
# through the DIFFVIEWS_BRANCH environment variable.
_REPO_URL = "https://github.com/mckellcarter/diffviews.git"
_REPO_BRANCH = os.environ.get("DIFFVIEWS_BRANCH", "diffviews-gradio6-HFz-CFr2")
_REPO_DIR = "/tmp/diffviews"

# Remove any pip-installed copy of diffviews. pip is invoked through the
# running interpreter (`python -m pip`) so the uninstall targets this
# environment rather than whichever `pip` happens to be first on PATH.
# Best effort: output is captured and the exit status is ignored.
subprocess.run(
    ([sys.executable, "-m", "pip"] if sys.executable else ["pip"])
    + ["uninstall", "-y", "diffviews"],
    capture_output=True,
)

# Forget diffviews modules that are already imported so that later imports
# are resolved again against the updated sys.path.
for mod in list(sys.modules):
    if mod == "diffviews" or mod.startswith("diffviews."):
        del sys.modules[mod]

# Shallow-clone the requested branch once; an existing directory is reused
# as-is. A failing clone raises CalledProcessError (check=True).
if not os.path.exists(_REPO_DIR):
    print(f"Cloning diffviews from {_REPO_BRANCH}...")
    subprocess.run(
        ["git", "clone", "--depth=1", "-b", _REPO_BRANCH, _REPO_URL, _REPO_DIR],
        check=True,
    )
# Put the clone first on the import path.
sys.path.insert(0, _REPO_DIR)
| |
|
| | import spaces |
| |
|
| | |
# HuggingFace dataset repository used as the fallback data source.
DATA_REPO_ID = "mckell/diffviews_demo_data"

# File name each checkpoint is stored under locally, keyed by model name.
CHECKPOINT_FILENAMES = {
    "dmd2": "dmd2-imagenet-64-10step.pkl",
    "edm": "edm-imagenet-64x64-cond-adm.pkl",
}

# Direct download URL per model; each URL ends in that model's file name.
CHECKPOINT_URLS = {
    "dmd2": "https://huggingface.co/mckell/diffviews-dmd2-checkpoint/resolve/main/"
    + CHECKPOINT_FILENAMES["dmd2"],
    "edm": "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/"
    + CHECKPOINT_FILENAMES["edm"],
}
| |
|
| |
|
def download_data_r2(output_dir: Path) -> bool:
    """Request the data of every known model from the R2 store.

    Args:
        output_dir: Directory passed to the store for each model download.

    Returns:
        False when the store reports that it is not enabled (nothing is
        requested). True once a download has been requested for both models;
        the result of the individual downloads is not inspected.
    """
    from diffviews.data.r2_cache import R2DataStore

    r2 = R2DataStore()
    if r2.enabled:
        print("Downloading data from R2...")
        for model_name in ("dmd2", "edm"):
            r2.download_model_data(model_name, output_dir)
        return True
    return False
| |
|
| |
|
def download_data_hf(output_dir: Path) -> None:
    """Fetch the demo dataset from the HuggingFace Hub (fallback source).

    Args:
        output_dir: Local directory the snapshot of DATA_REPO_ID ("main"
            revision, dataset repository) is downloaded into.
    """
    from huggingface_hub import snapshot_download

    print(f"Downloading data from {DATA_REPO_ID} (HF fallback)...")
    request = {
        "repo_id": DATA_REPO_ID,
        "repo_type": "dataset",
        "local_dir": output_dir,
        "revision": "main",
    }
    snapshot_download(**request)
    print(f"Data downloaded to {output_dir}")
| |
|
| |
|
def download_data(output_dir: Path) -> None:
    """Download the data set, preferring R2 and falling back to HuggingFace.

    Args:
        output_dir: Directory the data is downloaded into.
    """
    print(f"Output directory: {output_dir.absolute()}")
    served_by_r2 = download_data_r2(output_dir)
    if served_by_r2:
        return
    # R2 is not enabled: use the HuggingFace Hub instead.
    download_data_hf(output_dir)
| |
|
| |
|
def download_checkpoint(output_dir: Path, model: str) -> None:
    """Make sure the checkpoint of *model* exists below *output_dir*.

    The file is stored at ``<output_dir>/<model>/checkpoints/<filename>``.
    Sources are tried in order: an already present file, the R2 store, then
    the public URL from CHECKPOINT_URLS. A failed URL download is reported on
    stdout and otherwise ignored (best effort), so this function does not
    raise for download errors.

    Args:
        output_dir: Root data directory.
        model: Model key; unknown keys are reported and skipped.
    """
    if model not in CHECKPOINT_URLS:
        print(f"Unknown model: {model}")
        return

    ckpt_dir = output_dir / model / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    filename = CHECKPOINT_FILENAMES[model]
    filepath = ckpt_dir / filename

    if filepath.exists():
        print(f"Checkpoint exists: {filepath}")
        return

    # First choice: the R2 store.
    from diffviews.data.r2_cache import R2DataStore

    store = R2DataStore()
    r2_key = f"data/{model}/checkpoints/{filename}"
    if store.enabled and store.download_file(r2_key, filepath):
        print(f"Checkpoint downloaded from R2: {filepath} ({filepath.stat().st_size / 1e6:.1f} MB)")
        return

    # Fallback: direct download from the public URL.
    import urllib.request

    url = CHECKPOINT_URLS[model]
    print(f"Downloading {model} checkpoint from URL (~1GB)...")
    print(f" URL: {url}")

    # Download into a sibling ".part" file and rename only on success.
    # Writing straight to `filepath` would leave a truncated file behind when
    # the transfer is interrupted, and the `filepath.exists()` check above
    # would then accept that broken checkpoint on the next start.
    partial_path = filepath.with_name(filename + ".part")
    try:
        urllib.request.urlretrieve(url, partial_path)
        partial_path.replace(filepath)
        print(f" Done ({filepath.stat().st_size / 1e6:.1f} MB)")
    except Exception as e:
        print(f" Error downloading checkpoint: {e}")
        print(" Generation will be disabled without checkpoint")
        # Remove whatever was written so far; cleanup itself is best effort.
        try:
            partial_path.unlink(missing_ok=True)
        except OSError:
            pass
| |
|
| |
|
| | def get_pca_components() -> int | None: |
| | """Read PCA pre-reduction setting from env. None = disabled.""" |
| | val = os.environ.get("DIFFVIEWS_PCA_COMPONENTS", "50") |
| | if val.lower() in ("0", "none", "off", ""): |
| | return None |
| | return int(val) |
| |
|
| |
|
def regenerate_umap(data_dir: Path, model: str) -> bool:
    """Recompute the UMAP embedding of *model* and save it again.

    Activations are loaded from ``<model>/activations/imagenet_real``, UMAP
    is recomputed, and the result is written through ``save_embeddings``
    using the path of the first CSV found in ``<model>/embeddings``. The
    docstring of the original states the purpose: a pickle created in a
    different environment may not be numba compatible.

    Args:
        data_dir: Root data directory.
        model: Model sub-directory name.

    Returns:
        True when the embedding was recomputed and saved. False when the
        activations, metadata or embeddings CSV are missing, or when
        loading, computing or saving raised an exception (which is printed).
    """
    from diffviews.processing.umap import (
        load_dataset_activations,
        compute_umap,
        save_embeddings,
    )
    import json

    root = data_dir / model
    activation_dir = root / "activations" / "imagenet_real"
    metadata_path = root / "metadata" / "imagenet_real" / "dataset_info.json"
    embeddings_dir = root / "embeddings"

    # Both the activations and their metadata are required as input.
    if not (activation_dir.exists() and metadata_path.exists()):
        print(f" Skipping UMAP regeneration for {model}: missing activations")
        return False

    # The existing CSV decides where the regenerated files are written.
    csv_path = next(iter(embeddings_dir.glob("*.csv")), None)
    if csv_path is None:
        print(f" Skipping UMAP regeneration for {model}: no embeddings CSV")
        return False

    params_file = csv_path.with_suffix(".json")
    pickle_file = csv_path.with_suffix(".pkl")

    # Parameters stored next to the CSV replace the built-in defaults.
    if params_file.exists():
        with open(params_file, "r") as handle:
            umap_params = json.load(handle)
    else:
        umap_params = {
            "n_neighbors": 15,
            "min_dist": 0.1,
            "layers": ["encoder_bottleneck", "midblock"],
        }

    pca_components = get_pca_components()
    print(f" Regenerating UMAP for {model}...")
    n_neighbors = umap_params.get("n_neighbors", 15)
    min_dist = umap_params.get("min_dist", 0.1)
    print(f" Params: n_neighbors={n_neighbors}, min_dist={min_dist}, pca={pca_components}")

    try:
        activations, metadata_df = load_dataset_activations(activation_dir, metadata_path)
        print(f" Loaded {activations.shape[0]} activations ({activations.shape[1]} dims)")

        embeddings, reducer, scaler, pca_reducer = compute_umap(
            activations,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            normalize=True,
            pca_components=pca_components,
        )

        save_embeddings(
            embeddings, metadata_df, csv_path, umap_params, reducer, scaler, pca_reducer
        )
        print(f" UMAP pickle regenerated: {pickle_file}")
    except Exception as exc:
        # Best effort: report the problem and let the caller continue.
        print(f" Error regenerating UMAP: {exc}")
        return False
    return True
| |
|
| |
|
def check_umap_compatibility(data_dir: Path, model: str) -> bool:
    """Check whether the stored UMAP pickle of *model* can transform data here.

    The first ``*.pkl`` in ``<data_dir>/<model>/embeddings`` is loaded and a
    single random probe row is pushed through the stored pipeline.

    Args:
        data_dir: Root data directory.
        model: Model sub-directory name.

    Returns:
        True when there is no pickle, when the pickle holds no reducer, or
        when the probe transform succeeds. False when loading or transforming
        raises (the error is printed).
    """
    embeddings_dir = data_dir / model / "embeddings"
    pkl_files = list(embeddings_dir.glob("*.pkl"))

    if not pkl_files:
        return True

    pkl_path = pkl_files[0]

    try:
        import pickle

        # SECURITY NOTE: unpickling executes code embedded in the file. The
        # pickle comes from the downloaded data set, so this is only safe
        # while that source is trusted.
        with open(pkl_path, "rb") as f:
            umap_data = pickle.load(f)

        reducer = umap_data.get("reducer")
        if reducer is None:
            return True

        import numpy as np

        scaler = umap_data.get("scaler")
        # NOTE(review): assumes an optional PCA stage is stored under
        # "pca_reducer" and applied between scaler and reducer - confirm
        # against save_embeddings. When the key is absent this is a no-op.
        pca_reducer = umap_data.get("pca_reducer")

        # Size the probe to what the first stage was fitted on. A fixed
        # width of 100 is rejected by a fitted transformer whenever the real
        # feature count differs, which would report a false incompatibility.
        # 100 remains the fallback when the stage does not expose its width.
        stages = [s for s in (scaler, pca_reducer, reducer) if s is not None]
        n_features = getattr(stages[0], "n_features_in_", None) or 100
        probe = np.random.randn(1, int(n_features)).astype(np.float32)

        if scaler is not None:
            probe = scaler.transform(probe)
        if pca_reducer is not None:
            probe = pca_reducer.transform(probe)

        # The actual compatibility test: this call fails if the pickled
        # reducer cannot run in the current environment.
        _ = reducer.transform(probe)
        return True

    except Exception as e:
        print(f" UMAP compatibility check failed for {model}: {e}")
        return False
| |
|
| |
|
def ensure_data_ready(data_dir: Path, checkpoints: list) -> bool:
    """Download missing data and checkpoints, then regenerate the UMAP pickles.

    A model counts as present when it has a ``config.json``, an embeddings
    directory with at least one CSV, and at least one ``sample_*.png`` under
    ``images/imagenet_real``. The data download only runs when no model is
    present at all.

    Args:
        data_dir: Root data directory.
        checkpoints: Model names whose checkpoints should be fetched.

    Returns:
        Always True.
    """
    known_models = ("dmd2", "edm")
    print(f"Checking for existing data in {data_dir.absolute()}...")

    # Pass 1: find the models whose data is already on disk.
    present = []
    for name in known_models:
        root = data_dir / name
        emb_dir = root / "embeddings"
        if not ((root / "config.json").exists() and emb_dir.exists()):
            continue

        csvs = list(emb_dir.glob("*.csv"))
        img_dir = root / "images" / "imagenet_real"
        pngs = list(img_dir.glob("sample_*.png")) if img_dir.exists() else []

        if csvs and pngs:
            present.append(name)
            print(f" Found {name}: {len(csvs)} csv, {len(pngs)} images")

    if present:
        print(f"Data already present: {present}")
    else:
        print("Data not found, downloading...")
        download_data(data_dir)

    # Pass 2: checkpoints for the requested models.
    for name in checkpoints:
        download_checkpoint(data_dir, name)

    # Pass 3: rebuild the UMAP pickle of every model that has embeddings.
    print("\nRegenerating UMAP pickles for numba compatibility...")
    for name in known_models:
        root = data_dir / name
        if not root.exists():
            print(f" {name}: model dir not found, skipping")
            continue

        emb_dir = root / "embeddings"
        has_csv = emb_dir.exists() and any(emb_dir.glob("*.csv"))
        if not has_csv:
            print(f" {name}: no embeddings found, skipping")
            continue

        print(f" {name}: regenerating UMAP...")
        regenerate_umap(data_dir, name)

    return True
| |
|
| |
|
def get_device() -> str:
    """Pick the device to run on.

    Returns:
        The value of DIFFVIEWS_DEVICE when it is set and non-empty (torch is
        not imported in that case). Otherwise "cuda" if available, then "mps"
        if this torch build exposes that backend and it is available, else
        "cpu".
    """
    forced = os.environ.get("DIFFVIEWS_DEVICE")
    if forced:
        return forced

    import torch

    if torch.cuda.is_available():
        return "cuda"

    # Not every torch build has an mps backend attribute.
    mps_ready = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return "mps" if mps_ready else "cpu"
| |
|
| |
|
@spaces.GPU(duration=120)
def generate_on_gpu(
    model_name, all_neighbors, class_label,
    n_steps, m_steps, s_max, s_min, guidance, noise_mode,
    extract_layers, can_project
):
    """Run one masked generation while holding the visualizer's generation lock.

    Must live in app_file for ZeroGPU detection.

    The adapter and the activation dict are obtained from the module-level
    visualizer; every activation is installed as a mask and hooks are
    registered for those layers. The hooks are always removed again, also
    when generation raises.

    Returns:
        Whatever ``generate_with_mask_multistep`` returns, or None when the
        adapter or the activation dict is unavailable.
    """
    from diffviews.visualization.app import _app_visualizer as visualizer
    from diffviews.core.masking import ActivationMasker
    from diffviews.core.generator import generate_with_mask_multistep

    with visualizer._generation_lock:
        adapter = visualizer.load_adapter(model_name)
        if adapter is None:
            return None

        layer_activations = visualizer.prepare_activation_dict(model_name, all_neighbors)
        if layer_activations is None:
            return None

        masker = ActivationMasker(adapter)
        for layer, activation in layer_activations.items():
            masker.set_mask(layer, activation)
        masker.register_hooks([*layer_activations.keys()])

        try:
            # UI values are converted here, inside the try, so the hooks are
            # removed even if a conversion fails.
            options = {
                "class_label": class_label,
                "num_steps": int(n_steps),
                "mask_steps": int(m_steps),
                "sigma_max": float(s_max),
                "sigma_min": float(s_min),
                "guidance_scale": float(guidance),
                # Strip the " noise" suffix; a missing mode falls back to
                # "stochastic".
                "noise_mode": (noise_mode or "stochastic noise").replace(" noise", ""),
                "num_samples": 1,
                "device": visualizer.device,
                "extract_layers": extract_layers if can_project else None,
                "return_trajectory": can_project,
                "return_intermediates": True,
                "return_noised_inputs": True,
            }
            return generate_with_mask_multistep(adapter, masker, **options)
        finally:
            masker.remove_hooks()
| |
|
| |
|
@spaces.GPU(duration=180)
def extract_layer_on_gpu(model_name, layer_name, batch_size=32):
    """Delegate layer-activation extraction to the module-level visualizer.

    Must live in app_file for ZeroGPU detection.

    Returns:
        The result of ``visualizer.extract_layer_activations`` for the given
        model, layer and batch size.
    """
    from diffviews.visualization.app import _app_visualizer as active_visualizer

    activations = active_visualizer.extract_layer_activations(
        model_name, layer_name, batch_size
    )
    return activations
| |
|
| |
|
def _setup():
    """Prepare data, build the visualizer and return the queued Gradio app.

    Configuration is read from DIFFVIEWS_DATA_DIR (default "data") and
    DIFFVIEWS_CHECKPOINT ("all", "none", or a comma-separated list; default
    "all"). The device comes from ``get_device``.

    Returns:
        The app created by ``create_gradio_app`` with its queue limited to
        20 entries.
    """
    data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
    checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "all")
    device = get_device()

    if checkpoint_config == "none":
        checkpoints = []
    elif checkpoint_config == "all":
        checkpoints = list(CHECKPOINT_URLS.keys())
    else:
        # Comma-separated list; blank entries are dropped.
        checkpoints = [
            name for name in map(str.strip, checkpoint_config.split(",")) if name
        ]

    rule = "=" * 50
    banner = (
        rule,
        "DiffViews - Diffusion Activation Visualizer",
        rule,
        f"Data directory: {data_dir.absolute()}",
        f"Device: {device}",
        f"Checkpoints: {checkpoints}",
        rule,
    )
    for line in banner:
        print(line)

    ensure_data_ready(data_dir, checkpoints)

    import diffviews.visualization.app as viz_mod
    from diffviews.visualization.app import (
        GradioVisualizer,
        create_gradio_app,
    )

    # Hand the GPU-decorated functions defined in this file to the
    # visualization module.
    viz_mod._generate_on_gpu = generate_on_gpu
    viz_mod._extract_layer_on_gpu = extract_layer_on_gpu

    print("\nInitializing visualizer...")
    visualizer = GradioVisualizer(data_dir=data_dir, device=device)

    print("Creating Gradio app...")
    gradio_app = create_gradio_app(visualizer)
    gradio_app.queue(max_size=20)
    return gradio_app
| |
|
| |
|
| | |
| | |
# Built at import time, so `demo` exists as a module-level name as soon as
# this file is imported (presumably what the hosting runtime looks up - TODO
# confirm).
demo = _setup()

# Direct execution: serve on all interfaces, port 7860, without a share link.
if __name__ == "__main__":
    import gradio as gr
    from diffviews.visualization.app import CUSTOM_CSS, PLOTLY_HANDLER_JS

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "theme": gr.themes.Soft(),
        "css": CUSTOM_CSS,
        "js": PLOTLY_HANDLER_JS,
    }
    demo.launch(**launch_options)
| |
|