SD/E2E-SD VAE to DINOv2 Bridge

Get DINOv2 patch features from SD VAE latents.

| Item | Value |
|---|---|
| Input | SD/E2E-SD VAE latent |
| Input shape | [B, 4, 32, 32] |
| Output | DINOv2 patch-token features |
| Output shape | [B, 64, 768] |
| Patch grid | 8 × 8 |
| CLS token | Not included |
| DINO target family | DINOv2-base-style, 768-dim |
| Bridge body | Adapter + Transformer bridge |
| Current training dataset | kingsidharth/zangei-dit-stage-1-250k |
| Training rows | ~220k |
| Main checkpoint | checkpoints/best.pt |
| Latest checkpoint | checkpoints/latest.pt |

Architecture

Our design cleanly separates the modality-specific layers from the spatial processing body:

  • Latent Adapter: a lightweight, VAE-specific convolutional stem that maps 4-channel VAE latents up to the bridge's working width.
  • Bridge Backbone: a standard transformer body (width 768, depth 8) that remains VAE-agnostic.
  • Token Head: a linear projection that maps transformer outputs to the expected DINO patch targets (e.g., 64 tokens of 768 dims).

Note: This decoupled design means for future models like FLUX, you can swap out just the Latent Adapter (to handle 16-channel latents) while freezing/reusing the learned bridge backbone.
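The adapter / backbone / head split can be sketched in PyTorch as below. This is a minimal, hypothetical re-creation (`LatentAdapterSketch`, `BridgeSketch`), not the released `DinoBridgeV3` class, and it will not load `best.pt`. The two stride-2 convolutions that take the 32 × 32 latent grid down to the 8 × 8 token grid, the learned positional embedding, and the pre-norm encoder layers are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class LatentAdapterSketch(nn.Module):
    """VAE-specific stem: [B, in_ch, 32, 32] -> [B, out_ch, 8, 8] (assumed 4x downsampling)."""

    def __init__(self, in_ch=4, mid_ch=256, out_ch=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, stride=2, padding=1),   # 32 -> 16
            nn.GELU(),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=2, padding=1),  # 16 -> 8
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)


class BridgeSketch(nn.Module):
    """Latent adapter -> VAE-agnostic transformer body -> linear token head."""

    def __init__(self, in_ch=4, target_tokens=64, target_dim=768, mid_ch=256,
                 adapter_out=512, width=768, depth=8, heads=12, mlp_ratio=4.0, dropout=0.02):
        super().__init__()
        self.adapter = LatentAdapterSketch(in_ch, mid_ch, adapter_out)
        self.proj_in = nn.Linear(adapter_out, width)                    # adapter channels -> width
        self.pos = nn.Parameter(torch.zeros(1, target_tokens, width))  # learned token positions
        layer = nn.TransformerEncoderLayer(
            d_model=width, nhead=heads, dim_feedforward=int(width * mlp_ratio),
            dropout=dropout, activation="gelu", batch_first=True, norm_first=True,
        )
        self.body = nn.TransformerEncoder(layer, num_layers=depth, enable_nested_tensor=False)
        self.norm = nn.LayerNorm(width)
        self.head = nn.Linear(width, target_dim)                        # token head

    def forward(self, z):
        h = self.adapter(z)                    # [B, C, 8, 8]
        h = h.flatten(2).transpose(1, 2)       # [B, 64, C] in row-major patch order
        assert h.shape[1] == self.pos.shape[1], "adapter grid must match target_tokens"
        h = self.proj_in(h) + self.pos
        return self.head(self.norm(self.body(h)))  # [B, 64, target_dim]
```

In this sketch, supporting 16-channel latents means constructing the adapter with `in_ch=16` while `body` and `head` keep the same shapes. A latent grid that is not 32 × 32 would also need a different stride or token count.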

Loss Function

The training utilizes a composite, geometry-aware loss function (bridge_loss) designed to prioritize structural and directional alignment over raw magnitude matching:

  • Cosine Loss (Weight: 1.0): 1.0 - cosine_similarity(pred, target). The primary driver, focusing heavily on matching the semantic direction of the DINOv2 features.
  • MSE Norm (Weight: 0.25): Standard MSE applied after L2-normalizing the predictions and targets.
  • MSE Raw (Weight: 0.05): Standard MSE applied to the raw values. Keeps the scale grounded without letting magnitude differences dominate the gradients.
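A minimal sketch of the three weighted terms follows. It assumes cosine similarity is taken per patch token along the feature dimension and that every term is mean-reduced. `bridge_loss_sketch` is a hypothetical stand-in, not the notebook's `bridge_loss`.

```python
import torch
import torch.nn.functional as F


def bridge_loss_sketch(pred, target, w_cos=1.0, w_norm=0.25, w_raw=0.05):
    """Composite loss over patch tokens; pred and target are [B, 64, 768]."""
    # Cosine term: 1 - per-token cosine similarity, averaged over batch and tokens.
    cos = F.cosine_similarity(pred, target, dim=-1)  # [B, 64]
    loss_cos = (1.0 - cos).mean()
    # MSE after L2-normalizing each token: compares direction only.
    loss_norm = F.mse_loss(F.normalize(pred, dim=-1), F.normalize(target, dim=-1))
    # MSE on raw values: small weight keeps the feature scale grounded.
    loss_raw = F.mse_loss(pred, target)
    total = w_cos * loss_cos + w_norm * loss_norm + w_raw * loss_raw
    parts = {"cos": loss_cos.item(), "mse_norm": loss_norm.item(), "mse_raw": loss_raw.item()}
    return total, parts
```

Scaling `pred` by 2 leaves the first two terms at zero and changes only the raw MSE. This is why that term carries the smallest weight: a magnitude mismatch is penalized, but it cannot dominate the directional terms.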

Training History

| Epoch | Val loss ↓ | Val cosine ↑ | Val NMSE ↓ | Retrieval@1 ↑ | Retrieval@5 ↑ | Retrieval@10 ↑ |
|---|---|---|---|---|---|---|
| 1 | 0.401512 | 0.624136 | 0.610722 | 0.751818 | 0.905000 | 0.940455 |
| 2 | 0.335588 | 0.686514 | 0.526457 | 0.910000 | 0.979545 | 0.987727 |
| 3 | 0.303365 | 0.716936 | 0.483648 | 0.960000 | 0.990000 | 0.995000 |
| 4 | 0.282880 | 0.736246 | 0.455694 | 0.975455 | 0.995000 | 0.996818 |
| 5 | 0.268303 | 0.749953 | 0.434974 | 0.987273 | 0.996364 | 0.998182 |
| 6 | 0.258487 | 0.759174 | 0.420825 | 0.988182 | 0.996364 | 0.998636 |
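The card does not define the "Val cosine" and "Val NMSE" columns. The sketch below uses one common reading: mean per-patch cosine, and squared error divided by target energy. The NMSE formula and the helper name `patch_metrics` are assumptions, so the notebook's numbers may have been computed differently.

```python
import torch
import torch.nn.functional as F


def patch_metrics(pred, target):
    """Mean per-patch cosine and a normalized MSE for [B, N, D] features."""
    pred, target = pred.float(), target.float()
    cosine = F.cosine_similarity(pred, target, dim=-1).mean()
    # Assumed NMSE definition: total squared error / total target energy (0 is perfect).
    nmse = ((pred - target) ** 2).sum() / (target ** 2).sum().clamp_min(1e-12)
    return cosine.item(), nmse.item()
```

Under this definition, a prediction of all zeros scores NMSE = 1.0. That gives a reference point for the logged values (0.61 at epoch 1, 0.42 at epoch 6).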

Current Quality Read

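The bridge is learning. Validation loss, cosine, NMSE, Retrieval@1, and Retrieval@10 improved every epoch, and Retrieval@5 held at 0.996364 between epochs 5 and 6.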

Strong signal: Retrieval@1 reached ~98.8% on the held-out validation subset.

This means the predicted features preserve enough image identity / semantic structure to retrieve the matching true DINO target among validation candidates.
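Retrieval@k as described here can be sketched as follows. Each prediction is compared against every validation target, and a sample counts as a hit if its own target ranks in the top k. Pooling the 64 patch tokens into one vector by averaging is an assumption; flattening all 64 × 768 values would be an equally plausible choice. `retrieval_at_k` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def retrieval_at_k(pred, target, ks=(1, 5, 10)):
    """pred/target: [M, N, D]. Returns {k: fraction of samples whose own target is in the top k}."""
    # Assumption: one global vector per image by mean-pooling patch tokens.
    p = F.normalize(pred.float().mean(dim=1), dim=-1)    # [M, D]
    t = F.normalize(target.float().mean(dim=1), dim=-1)  # [M, D]
    sim = p @ t.T                                        # [M, M] cosine similarities
    ranked = sim.argsort(dim=1, descending=True)         # candidate indices, best first
    truth = torch.arange(sim.shape[0]).unsqueeze(1)      # correct candidate per row
    out = {}
    for k in ks:
        k_eff = min(k, sim.shape[0])
        hits = (ranked[:, :k_eff] == truth).any(dim=1)
        out[k] = hits.float().mean().item()
    return out
```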

However, the mean patch-level validation cosine is still only ~0.759.

So the bridge is not yet a full DINO replacement. It is already useful for ranking and retrieval-style proxy supervision, but it should be improved before being treated as a high-fidelity DINO teacher.

Suggested targets before production use as a serious DINO proxy:

  • val cosine: 0.85+
  • val NMSE: <0.25
  • retrieval@1: remains above 0.95 on a larger external evaluation set
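These targets can be written as a simple pass/fail check, sketched below with the epoch-6 numbers from the training history. `PROXY_TARGETS` and `check_proxy_targets` are hypothetical names. Note that the 0.988182 used here comes from the held-out validation subset, not from the larger external evaluation the retrieval target calls for.

```python
import operator

# Suggested thresholds from this card: metric -> (comparison, threshold).
PROXY_TARGETS = {
    "val_cosine": (operator.ge, 0.85),
    "val_nmse": (operator.lt, 0.25),
    "retrieval@1": (operator.gt, 0.95),
}


def check_proxy_targets(metrics):
    """Return metric -> True if the suggested target is met."""
    return {name: cmp(metrics[name], thr) for name, (cmp, thr) in PROXY_TARGETS.items()}


# Epoch-6 validation numbers from the training history table.
epoch6 = {"val_cosine": 0.759174, "val_nmse": 0.420825, "retrieval@1": 0.988182}
print(check_proxy_targets(epoch6))
# {'val_cosine': False, 'val_nmse': False, 'retrieval@1': True}
```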

How to Use

  • Prepare Data: pre-pack your SD latents ([N, 4, 32, 32]) and DINOv2 features ([N, 64, 768]) into memory-mappable .npy files. Target the 8 × 8 patch grid and exclude the DINO CLS token.
  • Configure: update paths and training knobs in the CFG dataclass (Cell 3), which serves as the single source of truth for the run.
  • Run All: the notebook handles package installation, wandb logging, dataset splitting, and mixed-precision (AMP) training automatically.
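The data-preparation step might look like the minimal sketch below. The float16 storage dtype and the names `pack_arrays` and `LatentDinoPairs` are hypothetical, and the notebook's own loader may differ. Memory mapping matters at this scale: assuming float16, the feature file alone is about 21.6 GB for 220k rows (220,000 × 64 × 768 × 2 bytes). `np.load(..., mmap_mode="r")` lets the dataset read rows from disk on demand instead of loading everything into RAM.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


def pack_arrays(latents_path, feats_path, n):
    """Pre-allocate .npy files: [N, 4, 32, 32] latents and [N, 64, 768] patch features (no CLS)."""
    lat = np.lib.format.open_memmap(latents_path, mode="w+", dtype=np.float16, shape=(n, 4, 32, 32))
    feat = np.lib.format.open_memmap(feats_path, mode="w+", dtype=np.float16, shape=(n, 64, 768))
    return lat, feat  # fill rows in chunks, then call .flush() on both


class LatentDinoPairs(Dataset):
    """Reads (latent, DINO patch tokens) pairs lazily from memory-mapped .npy files."""

    def __init__(self, latents_path, feats_path):
        self.lat = np.load(latents_path, mmap_mode="r")
        self.feat = np.load(feats_path, mmap_mode="r")
        assert self.lat.shape[0] == self.feat.shape[0], "row counts must match"
        assert self.lat.shape[1:] == (4, 32, 32), "expected SD latents [N, 4, 32, 32]"
        assert self.feat.shape[1:] == (64, 768), "expected 8x8 patch tokens, no CLS"

    def __len__(self):
        return self.lat.shape[0]

    def __getitem__(self, i):
        # np.array(...) copies the row out of the read-only memmap and upcasts to float32.
        latent = torch.from_numpy(np.array(self.lat[i], dtype=np.float32))
        target = torch.from_numpy(np.array(self.feat[i], dtype=np.float32))
        return latent, target
```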

Basic

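import torch

# DinoBridgeV3 is defined in the training notebook; import or paste that class first.
device = "cuda" if torch.cuda.is_available() else "cpu"

# weights_only=False allows non-tensor objects in the checkpoint dict; load trusted files only.
ckpt = torch.load("checkpoints/best.pt", map_location="cpu", weights_only=False)

# Training checkpoints store weights under "model"; fall back to a bare state dict.
state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt

model = DinoBridgeV3(
    in_ch=4,            # SD VAE latent channels
    target_tokens=64,   # 8 x 8 DINOv2 patch grid, CLS excluded
    target_dim=768,
    adapter_mid_channels=256,
    adapter_out_channels=512,
    adapter_depth=2,
    width=768,
    depth=8,
    heads=12,
    mlp_ratio=4.0,
    dropout=0.02,
)

model.load_state_dict(state_dict, strict=True)
model = model.to(device).eval()  # falls back to CPU when no GPU is present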
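Once loaded, the model can be applied to many latents at once. The sketch below batches inference and reshapes the 64 tokens back into an 8 × 8 spatial grid. `predict_dino_tokens` and `tokens_to_grid` are hypothetical helpers, and the row-major (raster) patch order is an assumption: it is the usual ViT convention but is not confirmed for this bridge.

```python
import torch


@torch.no_grad()
def predict_dino_tokens(model, latents, batch_size=64, device="cpu"):
    """Run a loaded bridge over [N, 4, 32, 32] latents; returns [N, 64, 768] on CPU."""
    model.eval()
    outs = []
    for start in range(0, latents.shape[0], batch_size):
        batch = latents[start:start + batch_size].to(device, dtype=torch.float32)
        outs.append(model(batch).float().cpu())
    return torch.cat(outs, dim=0)


def tokens_to_grid(tokens, grid=8):
    """[B, 64, D] -> [B, D, 8, 8], assuming row-major patch order."""
    b, n, d = tokens.shape
    assert n == grid * grid, f"expected {grid * grid} tokens, got {n}"
    return tokens.transpose(1, 2).reshape(b, d, grid, grid)
```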

Advanced

import torch
import torch.nn.functional as F

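# Run inside the training notebook: OUT_DIR, cfg, device, in_ch, TARGET_TOKENS,
# TARGET_DIM, val_ds and DinoBridgeV3 all come from earlier cells.
best_ckpt_path = OUT_DIR / "best.pt"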
print(f"Loading checkpoint from: {best_ckpt_path}")

inference_model = DinoBridgeV3(
    in_ch=in_ch,
    target_tokens=TARGET_TOKENS,
    target_dim=TARGET_DIM,
    adapter_mid_channels=cfg.adapter_mid_channels,
    adapter_out_channels=cfg.adapter_out_channels,
    adapter_depth=cfg.adapter_depth,
    width=cfg.model_width,
    depth=cfg.model_depth,
    heads=cfg.model_heads,
    mlp_ratio=cfg.mlp_ratio,
    dropout=cfg.dropout,
).to(device)


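# Stop here rather than running inference with randomly initialized weights.
if not best_ckpt_path.exists():
    raise FileNotFoundError(f"{best_ckpt_path} not found. Make sure you have completed at least one training epoch.")
ckpt = torch.load(best_ckpt_path, map_location=device, weights_only=False)  # trusted file only
inference_model.load_state_dict(ckpt["model"])
print("Weights loaded successfully.")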

inference_model.eval()

# In practice, the latent comes from your SD VAE encoder, preprocessed the same way as the training latents.
sample_latent, sample_target = val_ds[0]
sample_latent = sample_latent.unsqueeze(0).to(device).float()  # add batch dim: [1, 4, 32, 32]

device_type = torch.device(device).type  # "cuda" or "cpu"; works for str or torch.device
with torch.no_grad():
    with torch.autocast(device_type=device_type, dtype=torch.float16, enabled=(device_type == "cuda" and cfg.amp)):
        pred_dino_features = inference_model(sample_latent)

print("\n--- Inference Results ---")
print("Input latent shape:", sample_latent.shape)
print("Predicted DINO features shape:", pred_dino_features.shape)
print("Ground truth DINO features shape:", sample_target.shape)

# Quick comparison to ground truth: mean per-patch cosine similarity
pred = pred_dino_features[0].float()       # [64, 768]
target = sample_target.to(device).float()  # [64, 768]
sim = F.cosine_similarity(pred, target, dim=-1).mean().item()
print(f"Average Cosine Similarity for this sample: {sim:.4f}")
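A single sample's cosine is noisy. The sketch below averages per-patch cosine over a whole validation loader instead. `mean_val_cosine` is a hypothetical helper. Its result is only comparable to the logged "Val cosine" if that metric is also a patch-averaged cosine on the same split, which is assumed here rather than confirmed.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def mean_val_cosine(model, loader, device="cpu"):
    """Average per-patch cosine over every (latent, target) batch in `loader`."""
    model.eval()
    total, count = 0.0, 0
    for latents, targets in loader:
        preds = model(latents.to(device).float()).float()
        cos = F.cosine_similarity(preds, targets.to(device).float(), dim=-1)  # [B, 64]
        total += cos.sum().item()
        count += cos.numel()
    return total / max(count, 1)
```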

Model card summary

| Section | Value |
|---|---|
| Repo | kingsidharth/sd_vae_2_dino_v2_bridge |
| Task | VAE latent → DINOv2 feature prediction |
| Input | [B, 4, 32, 32] SD/E2E-SD latent |
| Output | [B, 64, 768] DINOv2 patch tokens |
| Best checkpoint | checkpoints/best.pt |
| Safe checkpoint after interrupt | checkpoints/epoch_006.pt |
| Latest checkpoint caveat | latest.pt may be incomplete if interrupted during final save |
| Best logged val cosine | 0.759174 |
| Best logged val NMSE | 0.420825 |
| Best logged Retrieval@1 | 0.988182 |
| Best logged Retrieval@5 | 0.996364 |
| Best logged Retrieval@10 | 0.998636 |

The training log shows validation improving consistently from epoch 1 to epoch 6: cosine rose from 0.624136 to 0.759174, NMSE fell from 0.610722 to 0.420825, and Retrieval@1 rose from 0.751818 to 0.988182. The run was interrupted during the final latest.pt save after epoch_006.pt had already been saved, so best.pt and epoch_006.pt are the safer checkpoints.
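Because latest.pt may be truncated, a loader can try checkpoints in order of trust and skip any file that is missing or fails to deserialize. The sketch below does this with `load_first_good_checkpoint`, a hypothetical helper. The candidate order (best.pt, then epoch_006.pt, then latest.pt) follows the guidance above.

```python
from pathlib import Path

import torch


def load_first_good_checkpoint(candidates):
    """Try checkpoint files in order; skip missing or unreadable ones (e.g. a truncated latest.pt)."""
    errors = {}
    for path in map(Path, candidates):
        if not path.is_file():
            errors[str(path)] = "missing"
            continue
        try:
            # weights_only=False unpickles arbitrary objects: use only with trusted files.
            ckpt = torch.load(path, map_location="cpu", weights_only=False)
        except Exception as exc:  # the exception type for a corrupt file varies by torch version
            errors[str(path)] = repr(exc)
            continue
        state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
        return path, state
    raise RuntimeError(f"no loadable checkpoint: {errors}")


# Example order of trust for this run:
# path, state = load_first_good_checkpoint(
#     ["checkpoints/best.pt", "checkpoints/epoch_006.pt", "checkpoints/latest.pt"])
```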
