data-archetype committed
Commit b32916f · verified · 1 parent: 3195457

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,109 @@
---
license: apache-2.0
tags:
- diffusion
- autoencoder
- image-reconstruction
- image-tokenizer
- pytorch
- fcdm
- semantic-alignment
library_name: capacitor_diffae
---

# data-archetype/semdisdiffae

**SemDisDiffAE** (**Sem**antically **Dis**entangled **Diff**usion **A**uto**E**ncoder)
— a fast image tokenizer with semantically structured 128-channel latents, built
on FCDM (Fully Convolutional Diffusion Model) blocks with a VP-parameterized
diagonal Gaussian posterior.

Trained with DINOv2 semantic alignment, this VAE was empirically found to match
the downstream diffusion convergence speed of other semantically aligned VAEs
such as Flux.2 and PS-VAE v2, while being much faster to encode and decode and
achieving very high reconstruction quality (38.6 dB mean PSNR on 2k images).

The architecture is purely convolutional, with no attention layers in the
encoder or decoder, enabling efficient inference at any resolution.

## Key Features

- **Fast**: ~3 ms/img encode, ~6 ms/img decode (1 step) on H200 — significantly
  faster than Flux.2 VAE
- **High fidelity**: 38.6 dB mean PSNR (2k images), exceeding Flux.2 VAE (37.0 dB)
- **Semantically structured latents**: DINOv2-aligned, producing latents with
  clear semantic segmentation visible in PCA projections
- **Comparable downstream convergence**: empirically matches the downstream
  diffusion training convergence speed of Flux.2 and PS-VAE v2
- **Pure convolutional**: no attention in encoder/decoder; compute scales
  linearly with spatial resolution
- **VP diffusion decoder**: single-step DDIM for PSNR-optimal reconstruction,
  multi-step sampling with PDG for perceptual sharpening

## Architecture

| Property | Value |
|----------|-------|
| Parameters | 88.8M |
| Patch size | 16 |
| Model dim | 896 |
| Encoder depth | 4 blocks |
| Decoder depth | 8 blocks (2+4+2 skip-concat) |
| Bottleneck | 128 channels |
| Compression | 16x spatial, 6.0x total |
| Posterior | Diagonal Gaussian (VP log-SNR) |
| Block type | FCDM (ConvNeXt + GRN + scale/gate AdaLN) |

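As a sanity check on the compression row: with patch size 16 and 3 input channels, each 16×16 patch carries 3 · 16 · 16 = 768 pixel values and maps to 128 latent channels, giving 16× per spatial side and 768 / 128 = 6.0× overall. A small illustrative script (just arithmetic on the table values, not part of the library):

```python
# Compression arithmetic for a patch-16, 128-channel bottleneck.
in_channels = 3
patch = 16
bottleneck = 128

pixels_per_patch = in_channels * patch * patch    # 3 * 16 * 16 = 768 values
spatial_compression = patch                       # 16x per spatial side
total_compression = pixels_per_patch / bottleneck # 768 / 128 = 6.0x

print(spatial_compression, total_compression)
```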
## Quick Start

```python
from capacitor_diffae import CapacitorDiffAE, CapacitorDiffAEInferenceConfig

model = CapacitorDiffAE.from_pretrained("data-archetype/semdisdiffae", device="cuda")

# Encode (returns posterior mode by default)
latents = model.encode(images)  # [B,3,H,W] in [-1,1] -> [B,128,H/16,W/16]

# Decode — PSNR-optimal (1 step, default)
recon = model.decode(latents, height=H, width=W)

# Decode — perceptual sharpening (10 steps + PDG)
cfg = CapacitorDiffAEInferenceConfig(num_steps=10, pdg=True, pdg_strength=2.0)
recon = model.decode(latents, height=H, width=W, inference_config=cfg)

# Full posterior access
posterior = model.encode_posterior(images)
z_sampled = posterior.sample()
```

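The posterior returned by `encode_posterior` is VP-parameterized: each element carries a mean and a log-SNR, with signal coefficient `alpha = sqrt(sigmoid(logsnr))` and noise coefficient `sigma = sqrt(sigmoid(-logsnr))`, so `alpha**2 + sigma**2 = 1` at every log-SNR. A dependency-free sketch of that identity (plain Python, not using the library):

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def vp_coefficients(logsnr: float) -> tuple[float, float]:
    """Signal/noise coefficients of the VP-parameterized posterior."""
    alpha = math.sqrt(sigmoid(logsnr))
    sigma = math.sqrt(sigmoid(-logsnr))
    return alpha, sigma

# alpha^2 + sigma^2 = sigmoid(x) + sigmoid(-x) = 1 for every log-SNR,
# so the posterior interpolates between pure signal and pure noise.
for logsnr in (-10.0, 0.0, 10.0):
    alpha, sigma = vp_coefficients(logsnr)
    print(f"{logsnr:+.1f}: alpha={alpha:.4f} sigma={sigma:.4f}")
```

At high log-SNR the sampled latent is almost exactly `mean`; at low log-SNR it is almost pure noise, which is what the posterior mode `alpha * mean` reflects.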
## Recommended Settings

| Use case | Steps | PDG | Notes |
|----------|-------|-----|-------|
| PSNR-optimal | 1 | off | Default, fastest |
| Perceptual | 10 | on (2.0) | Sharper, ~15x slower |

PDG is primarily useful for more compressed bottlenecks (32 or 64 channels)
and is rarely necessary for 128-channel models, where reconstruction quality
is already high.

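The relative cost of the two presets follows from the function-evaluation count: PDG runs a second, degraded decoder pass per step (2 NFE per step when enabled, per the inference config), so the 10-step perceptual preset costs roughly 20 decoder evaluations versus 1. A small illustrative accounting (not library code):

```python
def decoder_nfe(num_steps: int, pdg: bool) -> int:
    """Decoder forward passes: PDG adds a second (degraded) pass per step."""
    return num_steps * (2 if pdg else 1)

psnr_optimal = decoder_nfe(1, pdg=False)  # 1 evaluation
perceptual = decoder_nfe(10, pdg=True)    # 20 evaluations
print(psnr_optimal, perceptual)
```

The measured ~15x slowdown is somewhat below the naive 20x NFE ratio, presumably because fixed per-call overheads amortize over more steps.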
## Training

Trained with:
- Pixel-space VP diffusion reconstruction loss (x-prediction, SiD2 weighting)
- DINOv2-S semantic alignment (negative cosine, weight 0.01)
- VP posterior variance expansion (weight 1e-5)
- Latent scale regularization (weight 0.0001)
- AdamW optimizer, bf16 mixed precision, EMA decay 0.9995
- 251k steps on a single GPU

See the [technical report](technical_report_semantic.md) for full details.

## Dependencies

- PyTorch >= 2.0
- safetensors (for loading weights)

## License

Apache 2.0
_results_appendix_semantic.md ADDED
@@ -0,0 +1,72 @@
## 7. Results

Reconstruction quality evaluated on a curated set of test images covering photographs, book covers, and documents. Flux.1 VAE (patch 8, 16 channels) is included as a reference at the same 12x compression ratio as the c64 variant.

### 7.1 Interactive Viewer

**[Open full-resolution comparison viewer](https://huggingface.co/spaces/data-archetype/irdiffae-results)** — side-by-side reconstructions, RGB deltas, and latent PCA with adjustable image size.

### 7.2 Inference Settings

| Setting | Value |
|---------|-------|
| Sampler | ddim |
| Steps | 1 |
| Schedule | linear |
| Seed | 42 |
| PDG | no_path_drop |
| Batch size (timing) | 4 |

> All models run in bfloat16. Timings measured on an NVIDIA RTX Pro 6000 (Blackwell).

### 7.3 Global Metrics

| Metric | semdisdiffae (1 step) | Flux.2 VAE |
|--------|--------|--------|
| Avg PSNR (dB) | 35.78 | 34.16 |
| Avg encode (ms/image) | 2.5 | 46.1 |
| Avg decode (ms/image) | 5.5 | 91.8 |

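From the averaged timings above, the speedups over Flux.2 VAE work out to roughly 18x on encode and 17x on decode (simple arithmetic on the reported numbers, shown here for convenience):

```python
# Speedup of semdisdiffae over Flux.2 VAE, from the averaged timings above.
flux2_encode_ms, flux2_decode_ms = 46.1, 91.8
ours_encode_ms, ours_decode_ms = 2.5, 5.5

encode_speedup = flux2_encode_ms / ours_encode_ms  # ~18.4x
decode_speedup = flux2_decode_ms / ours_decode_ms  # ~16.7x
print(f"encode: {encode_speedup:.1f}x, decode: {decode_speedup:.1f}x")
```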
### 7.4 Per-Image PSNR (dB)

| Image | semdisdiffae (1 step) | Flux.2 VAE |
|-------|--------|--------|
| p640x1536:94623 | 35.44 | 33.50 |
| p640x1536:94624 | 31.33 | 30.03 |
| p640x1536:94625 | 35.05 | 33.98 |
| p640x1536:94626 | 33.21 | 31.53 |
| p640x1536:94627 | 32.54 | 30.53 |
| p640x1536:94628 | 29.80 | 28.88 |
| p960x1024:216264 | 46.37 | 45.39 |
| p960x1024:216265 | 29.70 | 27.80 |
| p960x1024:216266 | 47.15 | 46.20 |
| p960x1024:216267 | 40.99 | 39.23 |
| p960x1024:216268 | 38.47 | 36.13 |
| p960x1024:216269 | 32.74 | 30.24 |
| p960x1024:216270 | 36.23 | 34.18 |
| p960x1024:216271 | 44.41 | 42.18 |
| p704x1472:94699 | 43.80 | 41.79 |
| p704x1472:94700 | 32.83 | 32.08 |
| p704x1472:94701 | 39.00 | 37.90 |
| p704x1472:94702 | 34.52 | 32.50 |
| p704x1472:94703 | 32.81 | 31.35 |
| p704x1472:94704 | 33.38 | 31.84 |
| p704x1472:94705 | 39.70 | 37.44 |
| p704x1472:94706 | 35.12 | 33.66 |
| r256_p1344x704:15577 | 31.02 | 29.98 |
| r256_p1344x704:15578 | 32.38 | 30.79 |
| r256_p1344x704:15579 | 33.27 | 31.83 |
| r256_p1344x704:15580 | 37.84 | 36.03 |
| r256_p1344x704:15581 | 38.57 | 36.94 |
| r256_p1344x704:15582 | 33.41 | 32.10 |
| r256_p1344x704:15583 | 36.67 | 34.54 |
| r256_p1344x704:15584 | 33.23 | 31.76 |
| r256_p896x1152:144131 | 35.30 | 33.60 |
| r256_p896x1152:144132 | 36.99 | 35.32 |
| r256_p896x1152:144133 | 39.69 | 37.33 |
| r256_p896x1152:144134 | 36.01 | 34.47 |
| r256_p896x1152:144135 | 31.20 | 29.87 |
| r256_p896x1152:144136 | 37.51 | 35.68 |
| r256_p896x1152:144137 | 33.83 | 32.86 |
| r256_p896x1152:144138 | 27.39 | 25.63 |
| VAE_accuracy_test_image | 36.64 | 35.25 |
capacitor_diffae/__init__.py ADDED
@@ -0,0 +1,33 @@
"""CapacitorDiffAE: Standalone diffusion autoencoder with FCDM blocks.

Capacitor DiffAE — a fast diffusion autoencoder with a 128-channel spatial
bottleneck and a VP-parameterized diagonal Gaussian posterior. Built on FCDM
(Fully Convolutional Diffusion Model) blocks with GRN and scale+gate AdaLN.

Usage::

    from capacitor_diffae import CapacitorDiffAE, CapacitorDiffAEInferenceConfig

    model = CapacitorDiffAE.from_pretrained("path/to/weights", device="cuda")

    # Encode (returns posterior mode by default)
    latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]

    # Decode — PSNR-optimal (1 step, default)
    recon = model.decode(latents, height=H, width=W)

    # Decode — perceptual sharpness (10 steps + path-drop PDG)
    cfg = CapacitorDiffAEInferenceConfig(num_steps=10, pdg=True, pdg_strength=2.0)
    recon = model.decode(latents, height=H, width=W, inference_config=cfg)
"""

from .config import CapacitorDiffAEConfig, CapacitorDiffAEInferenceConfig
from .encoder import EncoderPosterior
from .model import CapacitorDiffAE

__all__ = [
    "CapacitorDiffAE",
    "CapacitorDiffAEConfig",
    "CapacitorDiffAEInferenceConfig",
    "EncoderPosterior",
]
capacitor_diffae/adaln.py ADDED
@@ -0,0 +1,50 @@
"""Scale+Gate AdaLN (2-way) for FCDM decoder blocks."""

from __future__ import annotations

from torch import Tensor, nn


class AdaLNScaleGateZeroProjector(nn.Module):
    """Packed 2-way AdaLN projection (SiLU -> Linear), zero-initialized.

    Outputs [B, 2*d_model] packed as (scale, gate).
    """

    def __init__(self, d_model: int, d_cond: int) -> None:
        super().__init__()
        self.d_model: int = int(d_model)
        self.d_cond: int = int(d_cond)
        self.act: nn.SiLU = nn.SiLU()
        self.proj: nn.Linear = nn.Linear(self.d_cond, 2 * self.d_model)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward_activated(self, act_cond: Tensor) -> Tensor:
        """Return packed modulation for a pre-activated conditioning vector."""
        return self.proj(act_cond)

    def forward(self, cond: Tensor) -> Tensor:
        """Return packed modulation [B, 2*d_model]."""
        return self.forward_activated(self.act(cond))


class AdaLNScaleGateZeroLowRankDelta(nn.Module):
    """Low-rank delta for 2-way AdaLN: down(d_cond -> rank) -> up(rank -> 2*d_model).

    Zero-initialized up projection preserves zero-output semantics at init.
    """

    def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
        super().__init__()
        self.d_model: int = int(d_model)
        self.d_cond: int = int(d_cond)
        self.rank: int = int(rank)
        self.down: nn.Linear = nn.Linear(self.d_cond, self.rank, bias=False)
        self.up: nn.Linear = nn.Linear(self.rank, 2 * self.d_model, bias=False)
        nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, act_cond: Tensor) -> Tensor:
        """Return packed delta modulation [B, 2*d_model]."""
        return self.up(self.down(act_cond))
capacitor_diffae/config.py ADDED
@@ -0,0 +1,62 @@
"""Frozen model architecture and user-tunable inference configuration."""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass(frozen=True)
class CapacitorDiffAEConfig:
    """Frozen model architecture config. Stored alongside weights as config.json."""

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    encoder_depth: int = 4
    decoder_depth: int = 8
    decoder_start_blocks: int = 2
    decoder_end_blocks: int = 2
    bottleneck_dim: int = 128
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128
    # Encoder posterior kind: "diagonal_gaussian" or "deterministic"
    bottleneck_posterior_kind: str = "diagonal_gaussian"
    # Post-bottleneck normalization: "channel_wise" or "disabled"
    bottleneck_norm_mode: str = "disabled"
    # VP diffusion schedule endpoints
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    # Pixel-space noise std for VP diffusion initialization
    pixel_noise_std: float = 0.558

    def save(self, path: str | Path) -> None:
        """Save config as JSON."""
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(asdict(self), indent=2) + "\n")

    @classmethod
    def load(cls, path: str | Path) -> CapacitorDiffAEConfig:
        """Load config from JSON."""
        data = json.loads(Path(path).read_text())
        return cls(**data)


@dataclass
class CapacitorDiffAEInferenceConfig:
    """User-tunable inference parameters with sensible defaults.

    PDG (Path-Drop Guidance) sharpens reconstructions by degrading conditioning
    in one pass and amplifying the difference. When enabled, uses 2 NFE per step.
    Recommended: ``pdg=True, pdg_strength=2.0, num_steps=10``.
    """

    num_steps: int = 1  # number of denoising steps (NFE)
    sampler: str = "ddim"  # "ddim" or "dpmpp_2m"
    schedule: str = "linear"  # "linear" or "cosine"
    pdg: bool = False  # enable PDG for perceptual sharpening
    pdg_strength: float = 2.0  # CFG-like strength when pdg=True
    seed: int | None = None
capacitor_diffae/decoder.py ADDED
@@ -0,0 +1,169 @@
"""Capacitor decoder: skip-concat topology with FCDM blocks and dual PDG.

No outer RMSNorms (use_other_outer_rms_norms=False during training):
norm_in, latent_norm, and norm_out are all absent.
"""

from __future__ import annotations

import torch
from torch import Tensor, nn

from .adaln import AdaLNScaleGateZeroLowRankDelta, AdaLNScaleGateZeroProjector
from .fcdm_block import FCDMBlock
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP


class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture (skip-concat, 2+4+2 default):
        Patchify x_t -> Fuse with upsampled z
        -> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
        -> Conv1x1 -> PixelShuffle

    Dual PDG at inference:
    - Path drop: replace middle block output with ``path_drop_mask_feature``.
    - Token mask: replace a fraction of upsampled latent tokens with
      ``latent_mask_feature`` before fusion.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        start_block_count: int,
        end_block_count: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
    ) -> None:
        super().__init__()
        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)

        # Input processing (no norm_in)
        self.patchify = Patchify(in_channels, patch_size, model_dim)

        # Latent conditioning path (no latent_norm)
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # 2-way AdaLN: shared base projector + per-block low-rank deltas
        self.adaln_base = AdaLNScaleGateZeroProjector(
            d_model=model_dim, d_cond=model_dim
        )
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNScaleGateZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Block layout: start + middle + end
        middle_count = depth - start_block_count - end_block_count
        self._middle_start_idx = start_block_count
        self._end_start_idx = start_block_count + middle_count

        def _make_blocks(count: int) -> nn.ModuleList:
            return nn.ModuleList(
                [
                    FCDMBlock(
                        model_dim,
                        mlp_ratio,
                        depthwise_kernel_size=depthwise_kernel_size,
                        use_external_adaln=True,
                    )
                    for _ in range(count)
                ]
            )

        self.start_blocks = _make_blocks(start_block_count)
        self.middle_blocks = _make_blocks(middle_count)
        self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
        self.end_blocks = _make_blocks(end_block_count)

        # Learned mask feature for path-drop PDG
        self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output head (no norm_out)
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta."""
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)
        delta_m = self.adaln_deltas[layer_idx](act)
        return base_m + delta_m

    def _run_blocks(
        self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
    ) -> Tensor:
        """Run a group of decoder blocks with per-block AdaLN modulation."""
        for local_idx, block in enumerate(blocks):
            adaln_m = self._adaln_m_for_layer(cond, layer_idx=start_index + local_idx)
            x = block(x, adaln_m=adaln_m)
        return x

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1].
            latents: Encoder latents [B, bottleneck_dim, h, w].
            drop_middle_blocks: Replace middle block output with mask feature (PDG).

        Returns:
            x0 prediction [B, C, H, W].
        """
        x_feat = self.patchify(x_t)
        z_up = self.latent_up(latents)

        fused = torch.cat([x_feat, z_up], dim=1)
        fused = self.fuse_in(fused)

        cond = self.time_embed(t.to(torch.float32).to(device=x_t.device))

        start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)

        if drop_middle_blocks:
            middle_out = self.path_drop_mask_feature.to(
                device=x_t.device, dtype=x_t.dtype
            ).expand_as(start_out)
        else:
            middle_out = self._run_blocks(
                self.middle_blocks,
                start_out,
                cond,
                start_index=self._middle_start_idx,
            )

        skip_fused = torch.cat([start_out, middle_out], dim=1)
        skip_fused = self.fuse_skip(skip_fused)

        end_out = self._run_blocks(
            self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
        )

        patches = self.out_proj(end_out)
        return self.unpatchify(patches)
capacitor_diffae/encoder.py ADDED
@@ -0,0 +1,129 @@
"""Capacitor encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.

No input RMSNorm (use_other_outer_rms_norms=False during training).
Post-bottleneck RMSNorm (affine=False) on the mean branch.
Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from .fcdm_block import FCDMBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify


@dataclass(frozen=True)
class EncoderPosterior:
    """VP-parameterized diagonal Gaussian posterior.

    mean: Clean signal branch mu [B, bottleneck_dim, h, w]
    logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w]
    """

    mean: Tensor
    logsnr: Tensor

    @property
    def alpha(self) -> Tensor:
        """VP signal coefficient: sqrt(sigmoid(logsnr))."""
        return torch.sigmoid(self.logsnr).sqrt()

    @property
    def sigma(self) -> Tensor:
        """VP noise coefficient: sqrt(sigmoid(-logsnr))."""
        return torch.sigmoid(-self.logsnr).sqrt()

    def mode(self) -> Tensor:
        """Posterior mode in token space: alpha * mean."""
        return self.alpha.to(dtype=self.mean.dtype) * self.mean

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
        """Sample from posterior: alpha * mean + sigma * eps."""
        # torch.randn_like does not accept a generator, so draw with torch.randn.
        eps = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.mean.device,
            dtype=self.mean.dtype,
        )
        alpha = self.alpha.to(dtype=self.mean.dtype)
        sigma = self.sigma.to(dtype=self.mean.dtype)
        return alpha * self.mean + sigma * eps


class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With diagonal_gaussian posterior, the to_bottleneck projection outputs
    2 * bottleneck_dim channels, split into mean and logsnr. The default
    encode() returns the posterior mode: alpha * RMSNorm(mean).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        bottleneck_posterior_kind: str = "diagonal_gaussian",
        bottleneck_norm_mode: str = "disabled",
    ) -> None:
        super().__init__()
        self.bottleneck_dim = int(bottleneck_dim)
        self.bottleneck_posterior_kind = bottleneck_posterior_kind
        self.bottleneck_norm_mode = bottleneck_norm_mode
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                FCDMBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        out_dim = (
            2 * bottleneck_dim
            if bottleneck_posterior_kind == "diagonal_gaussian"
            else bottleneck_dim
        )
        self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
        if bottleneck_norm_mode == "channel_wise":
            self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
        else:
            self.norm_out = nn.Identity()

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Only valid when bottleneck_posterior_kind == "diagonal_gaussian".
        """
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        projection = self.to_bottleneck(z)
        mean, logsnr = projection.chunk(2, dim=1)
        mean = self.norm_out(mean)
        return EncoderPosterior(mean=mean, logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic latents otherwise.
        """
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        projection = self.to_bottleneck(z)
        if self.bottleneck_posterior_kind == "diagonal_gaussian":
            mean, logsnr = projection.chunk(2, dim=1)
            mean = self.norm_out(mean)
            alpha = torch.sigmoid(logsnr).sqrt().to(dtype=mean.dtype)
            return alpha * mean
        z = self.norm_out(projection)
        return z
capacitor_diffae/fcdm_block.py ADDED
@@ -0,0 +1,103 @@
"""FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .norms import ChannelWiseRMSNorm


class GRN(nn.Module):
    """Global Response Normalization for NCHW tensors."""

    def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps: float = float(eps)
        c = int(channels)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
        g_fp32 = g.to(dtype=torch.float32)
        n = (g_fp32 / (g_fp32.mean(dim=1, keepdim=True) + self.eps)).to(dtype=x.dtype)
        gamma = self.gamma.to(device=x.device, dtype=x.dtype)
        beta = self.beta.to(device=x.device, dtype=x.dtype)
        return gamma * (x * n) + beta + x


class FCDMBlock(nn.Module):
    """ConvNeXt-style block with scale+gate AdaLN and GRN.

    Two modes:
    - Unconditioned (encoder): uses learned layer-scale for near-identity init.
    - External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate).
      The gate is applied raw (no tanh).
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
        layer_scale_init: float = 1e-3,
    ) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self.mlp_ratio: float = float(mlp_ratio)

        self.dwconv = nn.Conv2d(
            channels,
            channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,
            stride=1,
            groups=channels,
            bias=True,
        )
        self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
        hidden = max(int(float(channels) * float(mlp_ratio)), 1)
        self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
        self.grn = GRN(hidden, eps=1e-6)
        self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)

        if not use_external_adaln:
            self.layer_scale = nn.Parameter(
                torch.full((channels,), float(layer_scale_init))
            )
        else:
            self.register_parameter("layer_scale", None)

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        b, c, _, _ = x.shape

        if adaln_m is not None:
            m = adaln_m.to(device=x.device, dtype=x.dtype)
            scale, gate = m.chunk(2, dim=-1)
        else:
            scale = gate = None

        h = self.dwconv(x)
        h = self.norm(h)

        if scale is not None:
            h = h * (1.0 + scale.view(b, c, 1, 1))

        h = self.pwconv1(h)
        h = F.gelu(h)
        h = self.grn(h)
        h = self.pwconv2(h)

        if gate is not None:
            gate_view = gate.view(b, c, 1, 1)
        else:
            gate_view = self.layer_scale.view(1, c, 1, 1).to(  # type: ignore[union-attr]
                device=h.device, dtype=h.dtype
            )

        return x + gate_view * h
capacitor_diffae/model.py ADDED
@@ -0,0 +1,304 @@
"""CapacitorDiffAE: standalone HuggingFace-compatible diffusion autoencoder."""

from __future__ import annotations

from pathlib import Path

import torch
from torch import Tensor, nn

from .config import CapacitorDiffAEConfig, CapacitorDiffAEInferenceConfig
from .decoder import Decoder
from .encoder import Encoder, EncoderPosterior
from .samplers import run_ddim, run_dpmpp_2m
from .vp_diffusion import get_schedule, make_initial_state, sample_noise


def _resolve_model_dir(
    path_or_repo_id: str | Path,
    *,
    revision: str | None,
    cache_dir: str | Path | None,
) -> Path:
    """Resolve a local path or HuggingFace Hub repo ID to a local directory."""
    local = Path(path_or_repo_id)
    if local.is_dir():
        return local
    repo_id = str(path_or_repo_id)
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError(
            f"'{repo_id}' is not an existing local directory. "
            "To download from HuggingFace Hub, install huggingface_hub: "
            "pip install huggingface_hub"
        )
    cache_dir_str = str(cache_dir) if cache_dir is not None else None
    local_dir = snapshot_download(
        repo_id,
        revision=revision,
        cache_dir=cache_dir_str,
    )
    return Path(local_dir)


class CapacitorDiffAE(nn.Module):
    """Standalone Capacitor DiffAE model for HuggingFace distribution.

    A diffusion autoencoder built on FCDM (Fully Convolutional Diffusion Model)
    blocks. Encodes images to compact 128-channel spatial latents via a
    VP-parameterized diagonal Gaussian posterior, and decodes them back via
    iterative VP diffusion with a skip-concat decoder.

    Usage::

        model = CapacitorDiffAE.from_pretrained("path/to/weights")
        model = model.to("cuda", dtype=torch.bfloat16)

        # Encode (returns posterior mode by default)
        latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]

        # Decode (1 step by default — PSNR-optimal)
        recon = model.decode(latents, height=H, width=W)

        # Reconstruct (encode + 1-step decode)
        recon = model.reconstruct(images)
    """

    def __init__(self, config: CapacitorDiffAEConfig) -> None:
        super().__init__()
        self.config = config

        self.encoder = Encoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.encoder_depth,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
            bottleneck_posterior_kind=config.bottleneck_posterior_kind,
            bottleneck_norm_mode=config.bottleneck_norm_mode,
        )

        self.decoder = Decoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.decoder_depth,
            start_block_count=config.decoder_start_blocks,
            end_block_count=config.decoder_end_blocks,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
            adaln_low_rank_rank=config.adaln_low_rank_rank,
        )

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device = "cpu",
        revision: str | None = None,
        cache_dir: str | Path | None = None,
    ) -> CapacitorDiffAE:
        """Load a pretrained model from a local directory or HuggingFace Hub.

        The directory (or repo) should contain:
        - config.json: Model architecture config.
        - model.safetensors (preferred) or model.pt: Model weights.

        Args:
            path_or_repo_id: Local directory path or HuggingFace Hub repo ID.
            dtype: Load weights in this dtype (float32 or bfloat16).
            device: Target device.
            revision: Git revision for Hub downloads.
            cache_dir: Where to cache Hub downloads.

        Returns:
            Loaded model in eval mode.
        """
        model_dir = _resolve_model_dir(
            path_or_repo_id, revision=revision, cache_dir=cache_dir
        )
        config = CapacitorDiffAEConfig.load(model_dir / "config.json")
        model = cls(config)

        safetensors_path = model_dir / "model.safetensors"
        pt_path = model_dir / "model.pt"

        if safetensors_path.exists():
            try:
                from safetensors.torch import load_file

                state_dict = load_file(str(safetensors_path), device=str(device))
            except ImportError:
                raise ImportError(
                    "safetensors package required to load .safetensors files. "
                    "Install with: pip install safetensors"
                )
        elif pt_path.exists():
            state_dict = torch.load(
                str(pt_path), map_location=device, weights_only=True
            )
        else:
            raise FileNotFoundError(
                f"No model weights found in {model_dir}. "
                "Expected model.safetensors or model.pt."
            )

        model.load_state_dict(state_dict)
        model = model.to(dtype=dtype, device=torch.device(device))
        model.eval()
        return model

    def encode(self, images: Tensor) -> Tensor:
        """Encode images to latents (posterior mode).

        Args:
            images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.

        Returns:
            Latents [B, bottleneck_dim, H/patch, W/patch].
        """
        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32
        return self.encoder(images.to(dtype=model_dtype))

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).
174
+
175
+ Args:
176
+ images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
177
+
178
+ Returns:
179
+ EncoderPosterior with mean and logsnr tensors.
180
+ """
181
+ try:
182
+ model_dtype = next(self.parameters()).dtype
183
+ except StopIteration:
184
+ model_dtype = torch.float32
185
+ return self.encoder.encode_posterior(images.to(dtype=model_dtype))
186
+
187
+ @torch.no_grad()
188
+ def decode(
189
+ self,
190
+ latents: Tensor,
191
+ height: int,
192
+ width: int,
193
+ *,
194
+ inference_config: CapacitorDiffAEInferenceConfig | None = None,
195
+ ) -> Tensor:
196
+ """Decode latents to images via VP diffusion.
197
+
198
+ Args:
199
+ latents: [B, bottleneck_dim, h, w] encoder latents.
200
+ height: Output image height (divisible by patch_size).
201
+ width: Output image width (divisible by patch_size).
202
+ inference_config: Optional inference parameters.
203
+
204
+ Returns:
205
+ Reconstructed images [B, 3, H, W] in float32.
206
+ """
207
+ cfg = inference_config or CapacitorDiffAEInferenceConfig()
208
+ config = self.config
209
+ batch = int(latents.shape[0])
210
+ device = latents.device
211
+
212
+ try:
213
+ model_dtype = next(self.parameters()).dtype
214
+ except StopIteration:
215
+ model_dtype = torch.float32
216
+
217
+ if height % config.patch_size != 0 or width % config.patch_size != 0:
218
+ raise ValueError(
219
+ f"height={height} and width={width} must be divisible by "
220
+ f"patch_size={config.patch_size}"
221
+ )
222
+
223
+ shape = (batch, config.in_channels, height, width)
224
+ noise = sample_noise(
225
+ shape,
226
+ noise_std=config.pixel_noise_std,
227
+ seed=cfg.seed,
228
+ device=torch.device("cpu"),
229
+ dtype=torch.float32,
230
+ )
231
+
232
+ schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)
233
+ initial_state = make_initial_state(
234
+ noise=noise.to(device=device),
235
+ t_start=schedule[0:1],
236
+ logsnr_min=config.logsnr_min,
237
+ logsnr_max=config.logsnr_max,
238
+ )
239
+
240
+ device_type = "cuda" if device.type == "cuda" else "cpu"
241
+ with torch.autocast(device_type=device_type, enabled=False):
242
+ latents_in = latents.to(device=device)
243
+
244
+ def _forward_fn(
245
+ x_t: Tensor,
246
+ t: Tensor,
247
+ latents: Tensor,
248
+ *,
249
+ drop_middle_blocks: bool = False,
250
+ mask_latent_tokens: bool = False,
251
+ ) -> Tensor:
252
+ return self.decoder(
253
+ x_t.to(dtype=model_dtype),
254
+ t,
255
+ latents.to(dtype=model_dtype),
256
+ drop_middle_blocks=drop_middle_blocks,
257
+ )
258
+
259
+ pdg_mode = "path_drop" if cfg.pdg else "disabled"
260
+
261
+ if cfg.sampler == "ddim":
262
+ sampler_fn = run_ddim
263
+ elif cfg.sampler == "dpmpp_2m":
264
+ sampler_fn = run_dpmpp_2m
265
+ else:
266
+ raise ValueError(
267
+ f"Unsupported sampler: {cfg.sampler!r}. Use 'ddim' or 'dpmpp_2m'."
268
+ )
269
+
270
+ result = sampler_fn(
271
+ forward_fn=_forward_fn,
272
+ initial_state=initial_state,
273
+ schedule=schedule,
274
+ latents=latents_in,
275
+ logsnr_min=config.logsnr_min,
276
+ logsnr_max=config.logsnr_max,
277
+ pdg_mode=pdg_mode,
278
+ pdg_strength=cfg.pdg_strength,
279
+ device=device,
280
+ )
281
+
282
+ return result
283
+
284
+ @torch.no_grad()
285
+ def reconstruct(
286
+ self,
287
+ images: Tensor,
288
+ *,
289
+ inference_config: CapacitorDiffAEInferenceConfig | None = None,
290
+ ) -> Tensor:
291
+ """Encode then decode. Convenience wrapper.
292
+
293
+ Args:
294
+ images: [B, 3, H, W] in [-1, 1].
295
+ inference_config: Optional inference parameters.
296
+
297
+ Returns:
298
+ Reconstructed images [B, 3, H, W] in float32.
299
+ """
300
+ latents = self.encode(images)
301
+ _, _, h, w = images.shape
302
+ return self.decode(
303
+ latents, height=h, width=w, inference_config=inference_config
304
+ )
capacitor_diffae/norms.py ADDED
"""Channel-wise RMSNorm for NCHW tensors."""

from __future__ import annotations

import torch
from torch import Tensor, nn


class ChannelWiseRMSNorm(nn.Module):
    """Channel-wise RMSNorm with float32 reduction for numerical stability.

    Normalizes across channels per spatial position. Supports optional
    per-channel affine weight and bias.
    """

    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self._eps: float = float(eps)
        if affine:
            self.weight = nn.Parameter(torch.ones(self.channels))
            self.bias = nn.Parameter(torch.zeros(self.channels))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: Tensor) -> Tensor:
        if x.dim() < 2:
            return x
        # Float32 accumulation for stability
        ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
        inv_rms = torch.rsqrt(ms + self._eps)
        y = x * inv_rms
        # Broadcast shape for per-channel affine params: [1, C, 1, ...].
        # Computed before either branch so bias can be applied without weight.
        shape = (1, -1) + (1,) * (x.dim() - 2)
        if self.weight is not None:
            y = y * self.weight.view(shape).to(dtype=y.dtype)
        if self.bias is not None:
            y = y + self.bias.view(shape).to(dtype=y.dtype)
        return y.to(dtype=x.dtype)
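The normalization above is easy to check by hand: at one spatial position with channel values `x`, channel-wise RMSNorm divides by `sqrt(mean(x_c^2) + eps)`, so the normalized vector has (approximately) unit mean square. A minimal pure-Python sketch, independent of the torch module above:

```python
import math

def channel_rmsnorm(x, eps=1e-6):
    """RMS-normalize the channel values at a single spatial position."""
    ms = sum(v * v for v in x) / len(x)      # mean square over channels
    inv_rms = 1.0 / math.sqrt(ms + eps)
    return [v * inv_rms for v in x]

y = channel_rmsnorm([3.0, 4.0])
mean_sq = sum(v * v for v in y) / len(y)     # ~1.0 after normalization
```

Note that the ratio between channels is preserved (normalization only rescales), which is why the optional affine weight/bias are needed to re-shape the distribution.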
capacitor_diffae/samplers.py ADDED
"""DDIM and DPM++2M samplers for VP diffusion with dual PDG support."""

from __future__ import annotations

from typing import Protocol

import torch
from torch import Tensor

from .vp_diffusion import (
    alpha_sigma_from_logsnr,
    broadcast_time_like,
    shifted_cosine_interpolated_logsnr_from_t,
)


class DecoderForwardFn(Protocol):
    """Callable that predicts x0 from (x_t, t, latents) with dual PDG flags."""

    def __call__(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
        mask_latent_tokens: bool = False,
    ) -> Tensor: ...


def _reconstruct_eps_from_x0(
    *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
) -> Tensor:
    """Reconstruct eps_hat from (x_t, x0_hat) under VP parameterization.

    eps_hat = (x_t - alpha * x0_hat) / sigma. All float32.
    """
    alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
    sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
    x_t_f32 = x_t.to(torch.float32)
    x0_f32 = x0_hat.to(torch.float32)
    return (x_t_f32 - alpha_view * x0_f32) / sigma_view


def _ddim_step(
    *,
    x0_hat: Tensor,
    eps_hat: Tensor,
    alpha_next: Tensor,
    sigma_next: Tensor,
    ref: Tensor,
) -> Tensor:
    """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
    a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
    s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
    return a * x0_hat + s * eps_hat


def _predict_with_pdg(
    forward_fn: DecoderForwardFn,
    state: Tensor,
    t_vec: Tensor,
    latents: Tensor,
    *,
    pdg_mode: str,
    pdg_strength: float,
) -> Tensor:
    """Run decoder forward with optional PDG guidance.

    Args:
        forward_fn: Decoder forward function.
        state: Current noised state [B, C, H, W].
        t_vec: Timestep vector [B].
        latents: Encoder latents.
        pdg_mode: "disabled", "path_drop", or "token_mask".
        pdg_strength: CFG-like strength for PDG.

    Returns:
        x0_hat prediction in float32.
    """
    if pdg_mode == "path_drop":
        x0_uncond = forward_fn(state, t_vec, latents, drop_middle_blocks=True).to(
            torch.float32
        )
        x0_cond = forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
            torch.float32
        )
        return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
    elif pdg_mode == "token_mask":
        x0_uncond = forward_fn(state, t_vec, latents, mask_latent_tokens=True).to(
            torch.float32
        )
        x0_cond = forward_fn(state, t_vec, latents, mask_latent_tokens=False).to(
            torch.float32
        )
        return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
    else:
        return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
            torch.float32
        )


def run_ddim(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run DDIM sampling loop with dual PDG support.

    Args:
        forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
        initial_state: Starting noised state [B, C, H, W] in float32.
        schedule: Descending t-schedule [num_steps] in [0, 1].
        latents: Encoder latents [B, bottleneck_dim, h, w].
        logsnr_min, logsnr_max: VP schedule endpoints.
        log_change_high, log_change_low: Shifted-cosine schedule parameters.
        pdg_mode: "disabled", "path_drop", or "token_mask".
        pdg_strength: CFG-like strength for PDG.
        device: Target device.

    Returns:
        Denoised samples [B, C, H, W] in float32.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma for all schedule points
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        a_t = alpha_sched[i].expand(batch_size)
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction with optional PDG
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            state,
            t_vec,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        eps_hat = _reconstruct_eps_from_x0(
            x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t
        )
        state = _ddim_step(
            x0_hat=x0_hat,
            eps_hat=eps_hat,
            alpha_next=a_next,
            sigma_next=s_next,
            ref=state,
        )

    return state


def run_dpmpp_2m(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run DPM++2M sampling loop with dual PDG support.

    Multi-step solver using exponential integrator formulation in half-lambda space.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma, half-lambda for all schedule points
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
    half_lambda = 0.5 * lmb.to(torch.float32)

    x0_prev: Tensor | None = None

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction with optional PDG
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            state,
            t_vec,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        lam_t = half_lambda[i].expand(batch_size)
        lam_next = half_lambda[i + 1].expand(batch_size)
        h = (lam_next - lam_t).to(torch.float32)
        phi_1 = torch.expm1(-h)

        sigma_ratio = (s_next / s_t).to(torch.float32)

        if i == 0 or x0_prev is None:
            # First-order step
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - broadcast_time_like(a_next, state).to(torch.float32)
                * broadcast_time_like(phi_1, state).to(torch.float32)
                * x0_hat
            )
        else:
            # Second-order step
            lam_prev = half_lambda[i - 1].expand(batch_size)
            h_0 = (lam_t - lam_prev).to(torch.float32)
            r0 = h_0 / h
            d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
            common = broadcast_time_like(a_next, state).to(
                torch.float32
            ) * broadcast_time_like(phi_1, state).to(torch.float32)
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - common * x0_hat
                - 0.5 * common * d1_0
            )

        x0_prev = x0_hat

    return state
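As a sanity check on the two update rules above: for exact x0 predictions, the first-order DPM++ step (`sigma_next/sigma_t * x_t - alpha_next * expm1(-h) * x0_hat`, with `h` the half-logSNR increment) is algebraically identical to the DDIM step (`alpha_next * x0_hat + sigma_next * eps_hat`, with `eps_hat` reconstructed from x0). A scalar pure-Python sketch, with logSNR values chosen only for illustration:

```python
import math

def alpha_sigma(logsnr):
    # VP: alpha^2 = sigmoid(logsnr), sigma^2 = sigmoid(-logsnr)
    a = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    s = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return a, s

lam_t, lam_next = -2.0, 1.0      # logSNR increases as we denoise
a_t, s_t = alpha_sigma(lam_t)
a_n, s_n = alpha_sigma(lam_next)

x_t, x0_hat = 0.7, 0.2           # current state and the model's x0 prediction

# DDIM: reconstruct eps_hat at the current level, recombine at the next level
eps_hat = (x_t - a_t * x0_hat) / s_t
x_ddim = a_n * x0_hat + s_n * eps_hat

# First-order DPM++ step in half-logSNR space (phi_1 = expm1(-h))
h = 0.5 * lam_next - 0.5 * lam_t
x_dpm = (s_n / s_t) * x_t - a_n * math.expm1(-h) * x0_hat
```

The two differ only from the second step onward, where DPM++2M adds the `d1_0` correction term built from the previous x0 prediction.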
capacitor_diffae/straight_through_encoder.py ADDED
"""PixelUnshuffle-based patchifier (no residual conv path)."""

from __future__ import annotations

from torch import Tensor, nn


class Patchify(nn.Module):
    """PixelUnshuffle(patch) -> Conv2d 1x1 projection.

    Converts [B, C, H, W] images into [B, out_channels, H/patch, W/patch] features.
    """

    def __init__(self, in_channels: int, patch: int, out_channels: int) -> None:
        super().__init__()
        self.patch = int(patch)
        self.unshuffle = nn.PixelUnshuffle(self.patch)
        in_after = in_channels * (self.patch * self.patch)
        self.proj = nn.Conv2d(in_after, out_channels, kernel_size=1, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        if x.shape[2] % self.patch != 0 or x.shape[3] % self.patch != 0:
            raise ValueError(
                f"Input H={x.shape[2]} and W={x.shape[3]} must be divisible by patch={self.patch}"
            )
        y = self.unshuffle(x)
        return self.proj(y)
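The PixelUnshuffle rearrangement can be illustrated without torch: an `[H, W]` map with patch `p` becomes `p*p` maps of shape `[H/p, W/p]`, where output map `di*p + dj` collects the pixels at offset `(di, dj)` inside each patch. A hypothetical pure-Python sketch for a single input channel (the helper name and list-of-lists layout are illustrative, not part of the package):

```python
def pixel_unshuffle_2d(x, p):
    """x: H x W nested list -> list of p*p maps, each (H//p) x (W//p)."""
    h, w = len(x), len(x[0])
    assert h % p == 0 and w % p == 0, "H and W must be divisible by patch"
    out = []
    for di in range(p):
        for dj in range(p):
            # One output map per within-patch offset (di, dj)
            out.append(
                [[x[i * p + di][j * p + dj] for j in range(w // p)]
                 for i in range(h // p)]
            )
    return out

x = [[0, 1], [2, 3]]                # 2x2 input, patch 2
maps = pixel_unshuffle_2d(x, 2)     # four 1x1 maps: [[0]], [[1]], [[2]], [[3]]
```

The 1x1 convolution that follows then mixes the `in_channels * p*p` unshuffled channels into the model dimension without any further spatial interaction.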
capacitor_diffae/time_embed.py ADDED
"""Sinusoidal timestep embedding with MLP projection."""

from __future__ import annotations

import math

import torch
from torch import Tensor, nn


def _log_spaced_frequencies(
    half: int, max_period: float, *, device: torch.device | None = None
) -> Tensor:
    """Log-spaced frequencies for sinusoidal embedding."""
    return torch.exp(
        -math.log(max_period)
        * torch.arange(half, device=device, dtype=torch.float32)
        / max(float(half - 1), 1.0)
    )


def sinusoidal_time_embedding(
    t: Tensor,
    dim: int,
    *,
    max_period: float = 10000.0,
    scale: float | None = None,
    freqs: Tensor | None = None,
) -> Tensor:
    """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32."""
    t32 = t.to(torch.float32)
    if scale is not None:
        t32 = t32 * float(scale)
    half = dim // 2
    if freqs is not None:
        freqs = freqs.to(device=t32.device, dtype=torch.float32)
    else:
        freqs = _log_spaced_frequencies(half, max_period, device=t32.device)
    angles = t32[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class SinusoidalTimeEmbeddingMLP(nn.Module):
    """Sinusoidal time embedding followed by Linear -> SiLU -> Linear."""

    def __init__(
        self,
        dim: int,
        *,
        freq_dim: int = 256,
        hidden_mult: float = 1.0,
        time_scale: float = 1000.0,
        max_period: float = 10000.0,
    ) -> None:
        super().__init__()
        self.dim = int(dim)
        self.freq_dim = int(freq_dim)
        self.time_scale = float(time_scale)
        self.max_period = float(max_period)
        hidden_dim = max(int(round(int(dim) * float(hidden_mult))), 1)

        freqs = _log_spaced_frequencies(self.freq_dim // 2, self.max_period)
        self.register_buffer("freqs", freqs, persistent=True)

        self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
        self.act = nn.SiLU()
        self.proj_out = nn.Linear(hidden_dim, self.dim)

    def forward(self, t: Tensor) -> Tensor:
        freqs: Tensor = self.freqs  # type: ignore[assignment]
        emb_freq = sinusoidal_time_embedding(
            t.to(torch.float32),
            self.freq_dim,
            max_period=self.max_period,
            scale=self.time_scale,
            freqs=freqs,
        )
        dtype_in = self.proj_in.weight.dtype
        hidden = self.proj_in(emb_freq.to(dtype_in))
        hidden = self.act(hidden)
        if hidden.dtype != self.proj_out.weight.dtype:
            hidden = hidden.to(self.proj_out.weight.dtype)
        return self.proj_out(hidden)
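The frequency spacing above runs geometrically from 1 down to 1/max_period: `freqs[i] = exp(-log(P) * i / (half - 1))`, so `freqs[0] = 1` and `freqs[half-1] = 1/P`. A pure-Python sketch of the embedding for one (already time-scaled) timestep:

```python
import math

def sin_cos_embedding(t, dim, max_period=10000.0):
    """DDPM/DiT-style embedding: [sin(t*f_0..f_h), cos(t*f_0..f_h)]."""
    half = dim // 2
    freqs = [
        math.exp(-math.log(max_period) * i / max(half - 1, 1))
        for i in range(half)
    ]
    angles = [t * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

# time_scale=1000 is applied upstream in the module, so t=0.5 becomes 500
emb = sin_cos_embedding(0.5 * 1000.0, 256)
```

With `time_scale=1000`, continuous t in [0, 1] maps onto the [0, 1000] range the DDPM-style frequency bank was designed for.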
capacitor_diffae/vp_diffusion.py ADDED
"""VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""

from __future__ import annotations

import math

import torch
from torch import Tensor


def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
    """Compute (alpha, sigma) from logSNR in float32.

    VP constraint: alpha^2 + sigma^2 = 1.
    """
    lmb32 = lmb.to(dtype=torch.float32)
    alpha = torch.sqrt(torch.sigmoid(lmb32))
    sigma = torch.sqrt(torch.sigmoid(-lmb32))
    return alpha, sigma


def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
    """Broadcast [B] coefficient to match x for per-sample scaling."""
    view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
    return coeff.view(view_shape)


def _cosine_interpolated_params(
    logsnr_min: float, logsnr_max: float
) -> tuple[float, float]:
    """Compute (a, b) for cosine-interpolated logSNR schedule.

    logsnr(t) = -2 * log(tan(a*t + b))
    logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
    """
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return a, b


def cosine_interpolated_logsnr_from_t(
    t: Tensor, *, logsnr_min: float, logsnr_max: float
) -> Tensor:
    """Map t in [0,1] to logSNR via cosine-interpolated schedule. Always float32."""
    a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
    t32 = t.to(dtype=torch.float32)
    a_t = torch.tensor(a, device=t32.device, dtype=torch.float32)
    b_t = torch.tensor(b, device=t32.device, dtype=torch.float32)
    u = a_t * t32 + b_t
    return -2.0 * torch.log(torch.tan(u))


def shifted_cosine_interpolated_logsnr_from_t(
    t: Tensor,
    *,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.

    lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
    """
    base = cosine_interpolated_logsnr_from_t(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    t32 = t.to(dtype=torch.float32)
    high = base + float(log_change_high)
    low = base + float(log_change_low)
    return (1.0 - t32) * high + t32 * low


def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
    """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.

    ``num_steps`` is the number of function evaluations (NFE = decoder forward
    passes). Internally the schedule has ``num_steps + 1`` time points
    (including both endpoints).

    Args:
        schedule_type: "linear" or "cosine".
        num_steps: Number of decoder forward passes (NFE), >= 1.

    Returns:
        Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
    """
    # NOTE: the upstream training code (src/ode/time_schedules.py) uses a
    # different convention where num_steps counts schedule *points* (so NFE =
    # num_steps - 1). This export package corrects the off-by-one so that
    # num_steps means NFE directly. TODO: align the upstream convention.
    n = max(int(num_steps) + 1, 2)
    if schedule_type == "linear":
        base = torch.linspace(0.0, 1.0, n)
    elif schedule_type == "cosine":
        i = torch.arange(n, dtype=torch.float32)
        base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
    else:
        raise ValueError(
            f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
        )
    # Descending: high t (noisy) -> low t (clean)
    return torch.flip(base, dims=[0])


def make_initial_state(
    *,
    noise: Tensor,
    t_start: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """Construct VP initial state x_t0 = sigma_start * noise (since x0=0).

    All math in float32.
    """
    batch = int(noise.shape[0])
    lmb_start = shifted_cosine_interpolated_logsnr_from_t(
        t_start.expand(batch).to(dtype=torch.float32),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    _alpha_start, sigma_start = alpha_sigma_from_logsnr(lmb_start)
    sigma_view = broadcast_time_like(sigma_start, noise)
    return sigma_view * noise.to(dtype=torch.float32)


def sample_noise(
    shape: tuple[int, ...],
    *,
    noise_std: float = 1.0,
    seed: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
    if seed is None:
        noise = torch.randn(
            shape, device=device or torch.device("cpu"), dtype=torch.float32
        )
    else:
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed))
        noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
    noise = noise.mul(float(noise_std))
    target_device = device if device is not None else torch.device("cpu")
    return noise.to(device=target_device, dtype=dtype)
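Two invariants of the schedule math above are worth checking in pure Python: the cosine-interpolated map hits `logsnr_max` at t=0 and `logsnr_min` at t=1 (that is what the `(a, b)` solve guarantees), and any `(alpha, sigma)` pair derived from a logSNR satisfies the VP constraint `alpha^2 + sigma^2 = 1`:

```python
import math

def cosine_logsnr(t, logsnr_min=-10.0, logsnr_max=10.0):
    # logsnr(t) = -2 * log(tan(a*t + b)), with (a, b) fixing the endpoints
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return -2.0 * math.log(math.tan(a * t + b))

def alpha_sigma(logsnr):
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))   # sqrt(sigmoid(logsnr))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))    # sqrt(sigmoid(-logsnr))
    return alpha, sigma

hi = cosine_logsnr(0.0)   # ~ logsnr_max = +10
lo = cosine_logsnr(1.0)   # ~ logsnr_min = -10
a, s = alpha_sigma(cosine_logsnr(0.5))
```

These are the default endpoints from config.json (`logsnr_min = -10`, `logsnr_max = 10`); the shifted-cosine variant only adds a t-interpolated constant on top of this base curve.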
config.json ADDED
{
  "in_channels": 3,
  "patch_size": 16,
  "model_dim": 896,
  "encoder_depth": 4,
  "decoder_depth": 8,
  "decoder_start_blocks": 2,
  "decoder_end_blocks": 2,
  "bottleneck_dim": 128,
  "mlp_ratio": 4.0,
  "depthwise_kernel_size": 7,
  "adaln_low_rank_rank": 128,
  "bottleneck_posterior_kind": "diagonal_gaussian",
  "bottleneck_norm_mode": "disabled",
  "logsnr_min": -10.0,
  "logsnr_max": 10.0,
  "pixel_noise_std": 0.558
}
model.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:89c2d23ce3c925697b7d8b93daeb0769987b99984858989feabfcc9e8bc8b7fa
size 355100344
technical_report_semantic.md ADDED
1
+ # SemDisDiffAE — Technical Report
2
+
3
+ **SemDisDiffAE** (**Sem**antically **Dis**entangled **Diff**usion **A**uto**E**ncoder)
4
+ — a fast diffusion autoencoder with a 128-channel spatial bottleneck built on
5
+ FCDM (Fully Convolutional Diffusion Model) blocks. The encoder uses a
6
+ VP-parameterized diagonal Gaussian posterior (learned log-SNR output head),
7
+ and the decoder reconstructs via single-step VP diffusion.
8
+
9
+ This checkpoint is trained with DINOv2 semantic alignment and variance
10
+ expansion regularization. The name is a nod to DRA (Page et al., 2026) whose
11
+ disentangled representation alignment approach inspired the semantic alignment
12
+ method used here.
13
+
14
+ ## Contents
15
+
16
+ 1. [Architecture](#1-architecture)
17
+ - [FCDM Block](#11-fcdm-block) · [Encoder](#12-encoder) · [Decoder](#13-decoder) · [AdaLN](#14-adaln-shared-base--low-rank-deltas) · [PDG](#15-path-drop-guidance-pdg)
18
+ 2. [Decoder VP Diffusion Parameterization](#2-decoder-vp-diffusion-parameterization)
19
+ - [Forward Process](#21-forward-process) · [Log SNR](#22-log-signal-to-noise-ratio) · [Schedule](#23-cosine-interpolated-schedule) · [X-Prediction](#24-x-prediction-objective) · [Sampling](#25-sampling)
20
+ 3. [Stochastic Posterior](#3-stochastic-posterior)
21
+ - [VP Log-SNR Parameterization](#31-vp-log-snr-parameterization) · [Variance Expansion Loss](#32-variance-expansion-loss) · [Posterior Mode](#33-posterior-mode-for-inference)
22
+ 4. [Semantic Alignment](#4-semantic-alignment)
23
+ 5. [Design Choices](#5-design-choices)
24
+ - [Convolutional Architecture](#51-convolutional-architecture) · [Single-Stride Encoder](#52-single-stride-encoder) · [Diffusion Decoding](#53-diffusion-decoding) · [Skip Connection and PDG](#54-skip-connection-and-path-drop-guidance)
25
+ 6. [Training](#6-training)
26
+ - [Loss Functions](#61-loss-functions) · [Optimizer](#62-optimizer-and-hyperparameters) · [Data](#63-data)
27
+ 7. [Model Configuration](#7-model-configuration)
28
+ 8. [Inference](#8-inference)
29
+ 9. [Results](#9-results)
30
+
31
+ **References:**
32
+
33
- **FCDM** — Kwon et al., *Reviving ConvNeXt for Efficient Convolutional Diffusion Models*, [arXiv:2603.09408](https://arxiv.org/abs/2603.09408), 2026.
- **SiD2** — Hoogeboom et al., *Simpler Diffusion (SiD2): 1.5 FID on ImageNet512 with pixel-space diffusion*, [arXiv:2410.19324](https://arxiv.org/abs/2410.19324), ICLR 2025.
- **DiTo** — Yin et al., *Diffusion Autoencoders are Scalable Image Tokenizers*, [arXiv:2501.18593](https://arxiv.org/abs/2501.18593), 2025.
- **DiCo** — Ai et al., *DiCo: Revitalizing ConvNets for Scalable and Efficient Diffusion Modeling*, [arXiv:2505.11196](https://arxiv.org/abs/2505.11196), 2025.
- **ConvNeXt V2** — Woo et al., *ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders*, [arXiv:2301.00808](https://arxiv.org/abs/2301.00808), CVPR 2023.
- **Z-image** — Cai et al., *Z-Image: An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer*, [arXiv:2511.22699](https://arxiv.org/abs/2511.22699), 2025.
- **SPRINT** — Park et al., *Sprint: Sparse-Dense Residual Fusion for Efficient Diffusion Transformers*, [arXiv:2510.21986](https://arxiv.org/abs/2510.21986), 2025.
- **DINOv2** — Oquab et al., *DINOv2: Learning Robust Visual Features without Supervision*, [arXiv:2304.07193](https://arxiv.org/abs/2304.07193), 2023. Register variant: Darcet et al., *Vision Transformers Need Registers*, [arXiv:2309.16588](https://arxiv.org/abs/2309.16588), ICLR 2024.
- **iREPA** — Singh et al., *What matters for Representation Alignment: Global Information or Spatial Structure?*, [arXiv:2512.10794](https://arxiv.org/abs/2512.10794), 2025.
- **DRA** — Page et al., *Boosting Latent Diffusion Models via Disentangled Representation Alignment*, [arXiv:2601.05823](https://arxiv.org/abs/2601.05823), 2026.
- **VEL** — Li et al., *Taming Sampling Perturbations with Variance Expansion Loss for Latent Diffusion Models*, [arXiv:2603.21085](https://arxiv.org/abs/2603.21085), 2026.
- **iRDiffAE** — [data-archetype/irdiffae-v1](https://huggingface.co/data-archetype/irdiffae-v1) — predecessor model using DiCo blocks.

---

## 1. Architecture

### 1.1 FCDM Block

SemDisDiffAE uses **FCDM blocks** — ConvNeXt-style convolutional blocks
adapted for diffusion models (Kwon et al., 2026). Each block follows a single
unified residual path:

```
x ──► DWConv 7×7 ──► RMSNorm ──► [Scale] ──► Conv 1×1 ──► GELU ──► GRN ──► Conv 1×1 ──► [Gate] ──► + ──► out
│                                                                                                  ▲
└──────────────────────────────────────────────────────────────────────────────────────────────────┘
```

This differs from DiCo blocks (used in the predecessor
[iRDiffAE](https://huggingface.co/data-archetype/irdiffae-v1)), which use two
separate residual paths (conv + MLP) with Compact Channel Attention (CCA).
FCDM consolidates these into one path, replacing CCA with Global Response
Normalization (GRN).
Key components:

- **Depthwise convolution** (7×7, groups=channels): spatial mixing without
  cross-channel interaction. The depthwise conv output feeds directly into
  RMSNorm (no intermediate activation).

- **RMSNorm** (non-affine, per-channel): normalizes activations before the
  pointwise MLP, replacing LayerNorm used in standard ConvNeXt.

- **Global Response Normalization (GRN)** (from ConvNeXt V2, Woo et al. 2023):
  applied between the two pointwise convolutions. GRN computes per-channel L2
  norms across the spatial dimensions and normalizes by the cross-channel mean:

  ```
  g = ||x||_2 over (H, W)
  n = g / mean(g over channels)
  GRN(x) = gamma * (x * n) + beta + x
  ```

  This encourages feature diversity across channels and prevents channel
  collapse during training.

- **Scale+Gate modulation** (decoder only): FCDM blocks use a 2-way modulation
  `(scale, gate)` from the timestep embedding, in contrast to DiCo's 4-way
  `(shift_conv, gate_conv, shift_mlp, gate_mlp)`. Scale is applied after
  RMSNorm: `h = h * (1 + scale)`. Gate is applied to the residual:
  `out = x + gate * h`. The gate is used **raw** (no tanh activation), giving
  unbounded gating — this differs from DiCo, which applies tanh to constrain
  the gate to [-1, 1].

- **Layer Scale** (encoder only): for unconditioned encoder blocks, a learnable
  per-channel scale (initialized to 1e-3) gates the residual for near-identity
  initialization, following ConvNeXt.
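The components above can be sketched in PyTorch. This is an illustrative reimplementation, not the library's actual block class — the names `FCDMBlock`, `GRN`, and `rms_norm` and their defaults are assumptions from the description:

```python
import torch
from torch import nn


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Non-affine RMSNorm over the channel dimension (dim=1), channels-first.
    return x * torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)


class GRN(nn.Module):
    """Global Response Normalization (ConvNeXt V2), channels-first tensors."""
    def __init__(self, dim: int):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        g = x.pow(2).sum(dim=(2, 3), keepdim=True).sqrt()  # per-channel L2 over (H, W)
        n = g / (g.mean(dim=1, keepdim=True) + 1e-6)       # normalize by cross-channel mean
        return self.gamma * (x * n) + self.beta + x


class FCDMBlock(nn.Module):
    """Single residual path: DWConv 7x7 -> RMSNorm -> [scale] -> MLP with GRN -> [gate]."""
    def __init__(self, dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.dwconv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.fc1 = nn.Conv2d(dim, hidden, 1)
        self.grn = GRN(hidden)
        self.fc2 = nn.Conv2d(hidden, dim, 1)
        # Encoder-style layer scale (1e-3) for near-identity initialization.
        self.layer_scale = nn.Parameter(1e-3 * torch.ones(dim))

    def forward(self, x, scale=None, gate=None):
        h = rms_norm(self.dwconv(x))
        if scale is not None:                              # decoder path: AdaLN scale
            h = h * (1 + scale[:, :, None, None])
        h = self.fc2(self.grn(torch.nn.functional.gelu(self.fc1(h))))
        if gate is not None:                               # decoder path: raw, unbounded gate
            return x + gate[:, :, None, None] * h
        return x + self.layer_scale[None, :, None, None] * h
```

With `scale` and `gate` supplied the block behaves as a conditioned decoder block; without them it falls back to the layer-scale-gated encoder form.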

### 1.2 Encoder

The encoder uses a single spatial stride (via PixelUnshuffle at the input)
followed by FCDM blocks at constant spatial resolution, then a bottleneck
projection that outputs both the posterior mean and per-element log-SNR:

```
Image [B, 3, H, W]
  ──► PixelUnshuffle(p=16) + Conv 1×1 (3·16² → 896)        [Patchify]
  ──► 4 × FCDMBlock (unconditioned, layer-scale gated)
  ──► Conv 1×1 (896 → 256)                                 [Bottleneck projection]
  ──► Split → mean [B, 128, h, w] + logsnr [B, 128, h, w]
  ──► α(logsnr) · mean                                     [Posterior mode]
```

The single-stride design ensures all encoder blocks see the full spatial
resolution and full channel width simultaneously. The information bottleneck
is imposed only at the very end, where a single linear projection selects
which channels to retain. See Section 5.2 for the rationale.

**Note:** This checkpoint uses `bottleneck_norm_mode=disabled`, so no
post-bottleneck RMSNorm is applied to the mean branch. The posterior mode
output is simply `α · μ` where `α = √σ(λ)`.

### 1.3 Decoder

The decoder predicts x̂₀ from noisy input x_t, conditioned on encoder
latents z and timestep t:

```
Noised image x_t [B, 3, H, W]
  ──► PixelUnshuffle(p=16) + Conv 1×1 (3·16² → 896)        [Patchify]
  ──► Concatenate with Conv 1×1(latents, 128 → 896)        [Latent fusion]
  ──► Conv 1×1 (2·896 → 896)
  ──► 2 × FCDMBlock (AdaLN conditioned)                    [Start blocks]
  ──► 4 × FCDMBlock (AdaLN conditioned)                    [Middle blocks]
  ──► Concat(start_out, middle_out) + Conv 1×1             [Skip fusion]
  ──► 2 × FCDMBlock (AdaLN conditioned)                    [End blocks]
  ──► Conv 1×1 (896 → 3·16²) + PixelShuffle(16)            [Unpatchify]
  ──► x̂₀ prediction [B, 3, H, W]
```

The skip-concat topology with 2+4+2 blocks is inspired by SPRINT's
sparse-dense residual fusion (Park et al., 2025). See Section 5.4 for the
design rationale.

### 1.4 AdaLN: Shared Base + Low-Rank Deltas

Timestep conditioning follows the Z-image style AdaLN
([Cai et al., 2025](https://arxiv.org/abs/2511.22699)): a shared base
projection plus a low-rank delta per layer.

A single base projector is shared across all 8 decoder layers, and each
layer adds a low-rank correction:

```
m_i = Base(SiLU(cond)) + Δ_i(SiLU(cond))
```

where `Base: ℝ^D → ℝ^{2D}` is a linear projection (zero-initialized) and
`Δ_i: ℝ^D → ℝ^r → ℝ^{2D}` is a low-rank factorization with rank r = 128
(zero-initialized up-projection).

The packed modulation `m_i ∈ ℝ^{B × 2D}` is split into `(scale, gate)` which
modulate the FCDM block (no shift term):

```
ĥ = RMSNorm(x) · (1 + scale)
x ← x + gate · f(ĥ)
```
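A minimal sketch of the shared-base plus low-rank-delta scheme (class and attribute names here are illustrative, not the library's):

```python
import torch
from torch import nn
import torch.nn.functional as F


class SharedAdaLN(nn.Module):
    """Shared base projection + per-layer low-rank delta producing (scale, gate)."""
    def __init__(self, dim: int, n_layers: int, rank: int = 128):
        super().__init__()
        self.base = nn.Linear(dim, 2 * dim)             # shared across all layers
        nn.init.zeros_(self.base.weight)
        nn.init.zeros_(self.base.bias)
        self.down = nn.ModuleList(nn.Linear(dim, rank) for _ in range(n_layers))
        self.up = nn.ModuleList(nn.Linear(rank, 2 * dim) for _ in range(n_layers))
        for u in self.up:                               # zero-init the up-projection
            nn.init.zeros_(u.weight)
            nn.init.zeros_(u.bias)

    def forward(self, cond: torch.Tensor, layer: int):
        c = F.silu(cond)
        m = self.base(c) + self.up[layer](self.down[layer](c))  # m_i ∈ ℝ^{B×2D}
        scale, gate = m.chunk(2, dim=-1)
        return scale, gate
```

Because both the base and the up-projections start at zero, every block starts with zero scale and zero gate, i.e. a near-identity decoder at initialization.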

### 1.5 Path-Drop Guidance (PDG)

At inference, optional PDG sharpens reconstructions by exploiting the
skip-concat structure — a classifier-free guidance analogue that does not
require training with conditioning dropout:

1. **Conditional pass:** run all blocks normally → x̂₀^cond
2. **Unconditional pass:** replace the middle block output with a learned
   mask feature m ∈ ℝ^{1×D×1×1} (initialized to zero), effectively dropping
   the deep processing path → x̂₀^uncond
3. **Guided prediction:** x̂₀ = x̂₀^uncond + s · (x̂₀^cond - x̂₀^uncond)

where s is the guidance strength.
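The combination in step 3 is a plain linear extrapolation; as a sketch (function name is illustrative):

```python
import torch


def path_drop_guidance(x0_cond: torch.Tensor, x0_uncond: torch.Tensor, s: float) -> torch.Tensor:
    """Step 3 above: x̂₀ = x̂₀_uncond + s · (x̂₀_cond − x̂₀_uncond).
    s = 1 recovers the conditional prediction; s > 1 extrapolates past it."""
    return x0_uncond + s * (x0_cond - x0_uncond)
```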

For PSNR-optimal reconstruction, PDG is disabled (1 NFE). For perceptual
sharpening, use 10 steps with PDG strength 2.0. Note that PDG is primarily
useful for more compressed bottlenecks (e.g. 32 or 64 channels) and is
rarely necessary for 128-channel models, where reconstruction quality is
already high.

---

## 2. Decoder VP Diffusion Parameterization

The decoder uses the variance-preserving (VP) diffusion framework from
SiD2 with an x-prediction objective.

### 2.1 Forward Process

Given a clean image x₀, the forward process constructs a noisy sample at
continuous time t ∈ [0, 1]:

```
x_t = α_t · x₀ + σ_t · ε,    ε ~ N(0, s²I)
```

where s = 0.558 is the pixel-space noise standard deviation (estimated from
the dataset image distribution) and the VP constraint α²_t + σ²_t = 1 holds.

### 2.2 Log Signal-to-Noise Ratio

The schedule is parameterized through the log signal-to-noise ratio:

```
λ_t = log(α²_t / σ²_t)
```

which decreases monotonically from λ_max at t = 0 to λ_min at t = 1 (pure
noise). From λ_t we recover α_t and σ_t via the sigmoid function σ(·):

```
α_t = √σ(λ_t),    σ_t = √σ(-λ_t)
```

### 2.3 Cosine-Interpolated Schedule

Following SiD2, the logSNR schedule uses cosine interpolation:

```
λ(t) = -2 log tan(a·t + b)
```

where a and b are computed to satisfy the boundary conditions
λ(0) = λ_max = 10 and λ(1) = λ_min = -10.
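The boundary conditions pin a and b down in closed form: tan(b) = e^(−λ_max/2) and tan(a + b) = e^(−λ_min/2). A self-contained sketch (constant and function names are illustrative):

```python
import math

LAMBDA_MAX, LAMBDA_MIN = 10.0, -10.0
B = math.atan(math.exp(-LAMBDA_MAX / 2))        # λ(0) = λ_max ⇒ tan(b) = e^(−λ_max/2)
A = math.atan(math.exp(-LAMBDA_MIN / 2)) - B    # λ(1) = λ_min ⇒ tan(a + b) = e^(−λ_min/2)


def logsnr(t: float) -> float:
    """Cosine-interpolated schedule: λ(t) = −2 log tan(a·t + b)."""
    return -2.0 * math.log(math.tan(A * t + B))


def alpha_sigma(lam: float) -> tuple[float, float]:
    """α = √σ(λ), σ = √σ(−λ); the VP constraint α² + σ² = 1 holds by construction."""
    sig = 1.0 / (1.0 + math.exp(-lam))
    return math.sqrt(sig), math.sqrt(1.0 - sig)
```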

### 2.4 X-Prediction Objective

The model predicts the clean image x̂₀ = f_θ(x_t, t, z) conditioned on
encoder latents z.

**Schedule-invariant loss.** Following SiD2, the training loss is defined as
an integral over logSNR λ, making it invariant to the choice of noise schedule.
Since timesteps are sampled uniformly t ~ U(0,1), the change of variable
introduces a Jacobian factor:

```
L = E_{t ~ U(0,1)} [ (-dλ/dt) · w(λ(t)) · ||x₀ - x̂₀||² ]
```

**Sigmoid weighting.** The weighting function uses a sigmoid centered at bias
b = -2.0, converting from ε-prediction to x-prediction form:

```
weight(t) = -(1/2) · (dλ/dt) · e^b · σ(λ(t) - b)
```
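Putting the Jacobian and the sigmoid weight together, a sketch (the schedule is restated from §2.3, and a central finite difference stands in for the analytic dλ/dt a training loop would use):

```python
import math

LAMBDA_MAX, LAMBDA_MIN, BIAS = 10.0, -10.0, -2.0
B = math.atan(math.exp(-LAMBDA_MAX / 2))
A = math.atan(math.exp(-LAMBDA_MIN / 2)) - B


def logsnr(t: float) -> float:
    # Cosine-interpolated schedule from §2.3.
    return -2.0 * math.log(math.tan(A * t + B))


def sid2_weight(t: float, eps: float = 1e-5) -> float:
    """weight(t) = −(1/2) · dλ/dt · e^b · σ(λ(t) − b) with sigmoid bias b = −2.
    λ is decreasing in t, so −dλ/dt > 0 and the weight is positive."""
    dlam_dt = (logsnr(t + eps) - logsnr(t - eps)) / (2 * eps)  # central difference
    lam = logsnr(t)
    return -0.5 * dlam_dt * math.exp(BIAS) / (1.0 + math.exp(-(lam - BIAS)))
```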

### 2.5 Sampling

Decoding uses DDIM by default. With 1 NFE (default), the model runs a single
evaluation at t_start ≈ 1 (near pure noise) and directly outputs the x₀
prediction. This is equivalent to a denoising autoencoder that maps
`σ_start · noise → x̂₀` conditioned on encoder latents.

DPM++2M is also supported as an alternative sampler, using a half-lambda
exponential integrator for faster convergence with more steps.

---

## 3. Stochastic Posterior

### 3.1 VP Log-SNR Parameterization

Instead of a KL-divergence penalty on a Gaussian encoder, SemDisDiffAE
parameterizes the bottleneck posterior with a VP-style noise interpolation,
used as an alternative to the traditional VAE KL penalty.

The encoder outputs two sets of 128 channels:

- **μ** — the clean signal (posterior mean)
- **λ** — per-element log signal-to-noise ratio

The posterior distribution is:

```
z = α(λ) · μ + σ(λ) · ε,    ε ~ N(0, I)
```

where α = √σ(λ) and σ = √σ(-λ) (sigmoid parameterization). This is
equivalent to a Gaussian with mean α·μ and variance σ².

Using a VP interpolation rather than simple additive noise decouples token
scale from stochasticity. With additive noise (`z = μ + σε`), the encoder
faces gradient pressure to scale latents up to counter the noise — the SNR
depends on the magnitude of μ. The VP formulation (`z = α·μ + σ·ε` with
`α² + σ² = 1`) removes this coupling: the noise level is controlled
entirely by the predicted log-SNR, independent of the latent magnitude.
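A sketch of this posterior in code (function name is illustrative; the library exposes the real thing via `encode_posterior()`):

```python
import torch


def vp_posterior(mu: torch.Tensor, logsnr: torch.Tensor, sample: bool = True) -> torch.Tensor:
    """z = α(λ)·μ + σ(λ)·ε with α = √sigmoid(λ), σ = √sigmoid(−λ).
    sample=False returns the posterior mode α·μ used at inference."""
    alpha = torch.sigmoid(logsnr).sqrt()
    sigma = torch.sigmoid(-logsnr).sqrt()
    if not sample:
        return alpha * mu
    return alpha * mu + sigma * torch.randn_like(mu)
```

Note that `alpha**2 + sigma**2 = 1` elementwise regardless of the scale of μ, which is exactly the decoupling described above.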

### 3.2 Variance Expansion Loss

To prevent posterior collapse (where the encoder learns to set σ → 0 and
ignore the stochastic component entirely), we adopt a **variance expansion
loss** inspired by VEL (Li et al., 2026,
[arXiv:2603.21085](https://arxiv.org/abs/2603.21085)):

```
L_var = -mean(log(σ² + δ))
```

where σ² is the posterior variance derived from the predicted log-SNR and
δ is a small epsilon (1e-6) for numerical stability. This loss encourages
non-zero posterior variance by penalizing small σ².
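As a sketch, with σ² = sigmoid(−λ) following the VP parameterization of §3.1 (function name is illustrative):

```python
import torch


def variance_expansion_loss(logsnr: torch.Tensor, delta: float = 1e-6) -> torch.Tensor:
    """L_var = −mean(log(σ² + δ)), with σ² = sigmoid(−λ) per §3.1.
    Larger log-SNR ⇒ smaller σ² ⇒ larger penalty, discouraging collapse."""
    var = torch.sigmoid(-logsnr)
    return -(var + delta).log().mean()
```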

VEL proposes the form `1/(σ² + δ)` for variance expansion. We found this to
be too aggressive — the `1/σ²` gradient pushes variance up very rapidly,
leading to excessive high-frequency noise in the latent space. We use the
`-log(σ² + δ)` form instead, which provides a gentler, logarithmic penalty
that stabilizes training.

**For this checkpoint:** the variance expansion loss is active with weight
**1e-5**.

> **Key finding: latent spectral structure matters for downstream diffusion.**
>
> Reconstruction quality is not very sensitive to the posterior noise level —
> good PSNR is achievable even with log-SNR as low as -2. However, the
> posterior noise level has a strong effect on the **spatial frequency
> content** of the latent space. When variance expansion is too aggressive,
> the latent space develops excessive high-frequency content; when it is
> too weak or absent, latents become overly smooth.
>
> We found empirically that downstream diffusion models converge best when
> the latent space has a **radial power spectral density (PSD) decay
> exponent of approximately 1.5** — deviating significantly in either
> direction (too smooth or too high-frequency) consistently yields worse
> downstream training convergence. We monitor this metric during training
> validation to guide the variance expansion weight.
>
> The weight of 1e-5 for this checkpoint was chosen to target this spectral
> sweet spot.

### 3.3 Posterior Mode for Inference

At inference, the encoder returns the **posterior mode**: `z = α(λ) · μ`. For
this checkpoint, the posterior log-SNR is very high (posterior variance is
negligible), so sampling and mode are nearly identical.

The `encode_posterior()` method is available for users who need the full
posterior distribution.

---

## 4. Semantic Alignment

This checkpoint uses **semantic alignment** to encourage semantically
structured latent representations. The approach is inspired by DRA
(Page et al., 2026, [arXiv:2601.05823](https://arxiv.org/abs/2601.05823)),
which aligns autoencoder latents with frozen vision encoder features. Our
implementation differs in the projection architecture and noise schedule.

### 4.1 Teacher

A frozen DINOv2-S with registers
(timm: `vit_small_patch16_dinov3.lvd_1689m`, 384-dim patch tokens) provides
the target spatial semantic features.

### 4.2 Projection Head

The student projection head maps noisy encoder latents to the teacher's
token space. It consists of:

```
Noisy latents z_noisy ∈ ℝ^{B×128×h×w}
  ──► Conv 1×1 (128 → 384)               [Channel projection]
  ──► Flatten to tokens [B, T, 384]
  ──► DiT transformer block              [Single block, 6 heads × 64 dim]
      (self-attention with axial RoPE 2D + AdaLN conditioned on τ)
  ──► RMSNorm
  ──► student tokens ∈ ℝ^{B×T×384}
```

The DiT block uses standard multi-head self-attention with 2D axial
rotary position embeddings (RoPE) and AdaLN-Zero timestep conditioning.
This gives the projection head global spatial reasoning — important for
matching the teacher's self-attention-based representations — while the
main encoder/decoder remain purely convolutional.

### 4.3 Noisy Alignment

Unlike standard representation alignment, which operates on clean latents,
we align **noisy** latent versions. The noise level τ is sampled from a
Beta(2,2) distribution (concentrated around τ=0.5) using flow matching
linear interpolation:

```
z_noisy = (1 - τ) · z + τ · ε,    ε ~ N(0, I),    τ ~ Beta(2, 2)
```

The projection head receives both the noisy latents and the noise level τ
(via its AdaLN conditioning). This trains the head to extract semantic
information even from partially corrupted latents, improving robustness
for downstream diffusion models, which operate on noised latent inputs.
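The corruption step can be sketched as follows (function name is illustrative):

```python
import torch


def noise_latents_for_alignment(z: torch.Tensor):
    """z_noisy = (1 − τ)·z + τ·ε, with a per-sample τ ~ Beta(2, 2)."""
    tau = torch.distributions.Beta(2.0, 2.0).sample((z.shape[0],)).to(z.dtype)
    tau = tau.view(-1, *([1] * (z.dim() - 1)))  # broadcast to [B, 1, 1, 1]
    eps = torch.randn_like(z)
    return (1 - tau) * z + tau * eps, tau
```

The returned τ is what the projection head receives through its AdaLN conditioning.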

### 4.4 Training Details

The alignment loss is the mean negative cosine similarity between student
and teacher tokens, weighted at **0.01** throughout training. The student
projection head operates on all 128 bottleneck channels, unlike the
predecessor iRDiffAE which aligned only the first 64 of 128 channels.
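The alignment loss itself is a one-liner; a sketch, assuming student and teacher tokens shaped [B, T, 384]:

```python
import torch
import torch.nn.functional as F


def alignment_loss(student: torch.Tensor, teacher: torch.Tensor) -> torch.Tensor:
    """Mean negative cosine: per-token (1 − cos), averaged over tokens and batch."""
    cos = F.cosine_similarity(student, teacher, dim=-1)  # [B, T]
    return (1.0 - cos).mean()
```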

Note that the projection head is a training-only component — it is not
included in the exported model weights.

---

## 5. Design Choices

### 5.1 Convolutional Architecture

SemDisDiffAE uses a fully convolutional architecture rather than a vision
transformer. For an autoencoder whose goal is faithful pixel-level
reconstruction (not global semantic understanding), convolutions offer:

- **Resolution generalization.** Convolutions operate on local patches and
  generalize naturally to arbitrary image dimensions without interpolating
  position embeddings or suffering attention distribution shift.
- **Translation invariance.** Weight sharing across spatial positions is well
  matched to reconstruction, where the same local patterns (edges, textures)
  conditioned on the low-frequency latent recur throughout the image.
- **Locality.** Reconstruction quality depends on preserving fine spatial
  detail. Convolutions are inherently local operators, avoiding the quadratic
  cost of global attention while focusing computation where it matters most.

### 5.2 Single-Stride Encoder with Final Bottleneck

The encoder uses a single spatial stride (PixelUnshuffle at the input)
followed by blocks at constant spatial resolution, then a final 1×1 convolution
to project to the bottleneck. This differs from classical VAE encoders that use
progressive downsampling with channel expansion at each stage.

The single-stride design ensures that all encoder blocks see the full spatial
resolution and full channel width simultaneously. The information bottleneck is
imposed only at the very end, where a single linear projection selects which
channels to retain.

### 5.3 Diffusion Decoding

The main advantage of diffusion decoding over the standard GAN + LPIPS
approach is **simplicity and speed of experimentation**. The training
objective is a straightforward weighted MSE — no discriminator, no LPIPS
perceptual loss, no delicate adversarial balancing. This makes it very fast
to train and easy to iterate on — typically a few hours on a single GPU is
sufficient. This checkpoint was trained for 251k steps. By contrast,
GAN + LPIPS-based VAEs require many days of large-GPU time and are
notoriously difficult to stabilize from scratch.

This simplicity enables rapid experimentation with latent-space shaping to
make the latent space as diffusion-friendly as possible, while still
achieving excellent reconstruction quality.

### 5.4 Skip Connection and Path-Drop Guidance

The decoder's start → middle → skip-fuse → end architecture is inspired by
SPRINT's sparse-dense residual fusion (Park et al., 2025). The design serves
three purposes:

1. **Regularization.** The skip path ensures that even if the middle blocks
   are dropped or poorly conditioned, the end blocks still receive meaningful
   features from the start blocks.
2. **High-frequency preservation.** The start blocks (which see the input most
   directly) pass fine detail through the skip to the end blocks.
3. **Path-Drop Guidance.** At inference, replacing the middle block output
   with a learned mask feature creates an "unconditional" prediction that
   preserves the skip path but drops the deep processing. Interpolating
   between conditional and unconditional predictions (as in classifier-free
   guidance) sharpens the output without requiring training-time dropout.

---

## 6. Training

### 6.1 Loss Functions

The total training loss is:

```
L_total = L_recon + 0.01 · L_semantic + 0.0001 · L_scale + 1e-5 · L_var
```

| Loss | Weight | Description |
|------|--------|-------------|
| **Reconstruction** (L_recon) | 1.0 | SiD2 sigmoid-weighted x-prediction MSE (bias b = -2.0). Per-pixel `(x̂₀ - x₀)²` averaged over (C, H, W) per sample, multiplied by the SiD2 per-sample weight `w(t) = -½ · dλ/dt · e^b · σ(λ-b)`, then averaged over the batch |
| **Semantic alignment** (L_semantic) | 0.01 | Per-token `(1 - cosine(student, teacher))` averaged over all tokens and batch (see §4) |
| **Latent scale penalty** (L_scale) | 0.0001 | Per-channel variance `var_c` estimated over the batch and spatial dims (B, H, W), then `(log(var_c + ε) - log(target))²` averaged over channels. Target variance = 1.0 |
| **Posterior variance expansion** (L_var) | 1e-5 | Per-element `-log(σ² + δ)` where σ² is the posterior variance derived from the predicted log-SNR, averaged over all dims (B, C, H, W). See §3.2 |

**Note on loss scales:** The decoder reconstruction loss has a small
effective magnitude due to the SiD2 VP x-prediction weighting (the Jacobian
dλ/dt and sigmoid weighting compress the per-sample loss scale). As a
result, all auxiliary loss weights must be kept correspondingly small to
avoid dominating the reconstruction objective.

### 6.2 Optimizer and Hyperparameters

| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW (β₁=0.9, β₂=0.99) |
| Learning rate | 1e-4 (constant after warmup) |
| Weight decay | 0.0 |
| Warmup steps | 2,000 |
| Gradient clip | 1.0 (max norm) |
| Precision | AMP bfloat16 (FP32 master weights, TF32 matmul) |
| EMA decay | 0.9995 (updated every step) |
| Batch size | 128 |
| Timestep sampling | Uniform with SiD2 logSNR shift -1.0 |
| Compilation | `torch.compile` enabled |
| Training steps | 251k |
| Hardware | Single GPU |

Convergence is fast — training is stopped when the training loss starts
plateauing, which typically occurs within a few hours on a single GPU.

### 6.3 Data

Training uses ~5M images at various resolutions: mostly photographs, with
a significant proportion of illustrations and text-heavy images (documents,
screenshots, book covers, diagrams) to encourage crisp line and edge
reconstruction. Images are loaded via two strategies in a 50/50 mix:

- **Full-image downsampling:** images are bucketed by aspect ratio and
  downsampled to ~256² resolution (preserving aspect ratio).
- **Random 256×256 crops:** deterministic patches extracted from images
  stored at ≥512px resolution.

This mixed strategy exposes the model to both global scene composition (via
downsampled full images) and fine local detail (via crops from higher-resolution
sources).

---

## 7. Model Configuration

| Parameter | Value |
|-----------|-------|
| Patch size | 16 |
| Model dimension | 896 |
| Encoder depth | 4 blocks |
| Decoder depth | 8 blocks (2 start + 4 middle + 2 end) |
| Bottleneck dimension | 128 channels |
| Spatial compression | 16× (H/16 × W/16) |
| Total compression | 6.0× (3·256 / 128) |
| MLP ratio | 4.0 |
| Depthwise kernel | 7×7 |
| AdaLN per-block delta rank | 128 |
| Block type | FCDM (ConvNeXt + GRN + scale/gate AdaLN) |
| Posterior | Diagonal Gaussian (VP log-SNR), variance expansion weight 1e-5 |
| Bottleneck norm | Disabled |
| λ_min, λ_max | -10, +10 |
| Sigmoid bias b | -2.0 |
| Pixel noise std s | 0.558 |
| Parameters | 88.8M |

---

## 8. Inference

### Recommended Settings

| Use case | Steps (NFE) | PDG | Sampler | Notes |
|----------|-------------|-----|---------|-------|
| **PSNR-optimal** | 1 | off | DDIM | Default. Fastest. |
| **Perceptual** | 10 | on (2.0) | DDIM | Sharper details, ~15× slower (PDG skips middle blocks) |

### Usage

```python
from capacitor_diffae import CapacitorDiffAE, CapacitorDiffAEInferenceConfig

# Load model
model = CapacitorDiffAE.from_pretrained("data-archetype/semdisdiffae", device="cuda")

# Encode (returns posterior mode by default)
latents = model.encode(images)  # [B,3,H,W] → [B,128,H/16,W/16]

# Decode (1 step)
recon = model.decode(latents, height=H, width=W)

# Full posterior access
posterior = model.encode_posterior(images)
print(posterior.mean.shape, posterior.logsnr.shape)
z_sampled = posterior.sample()
```

---

## 9. Results

Reconstruction quality is evaluated on a curated set of test images covering
photographs, book covers, and documents, with Flux.2 VAE included as a
reference.

### 9.1 Interactive Viewer

**[Open full-resolution comparison viewer](https://huggingface.co/spaces/data-archetype/irdiffae-results)** — side-by-side reconstructions, RGB deltas, and latent PCA with adjustable image size.

### 9.2 Inference Settings

| Setting | Value |
|---------|-------|
| Sampler | DDIM |
| Steps | 1 |
| Schedule | linear |
| Seed | 42 |
| PDG | off |
| Batch size (timing) | 4 |

> All models run in bfloat16. Timings measured on an NVIDIA RTX Pro 6000 (Blackwell).

### 9.3 Global Metrics

| Metric | semdisdiffae (1 step) | Flux.2 VAE |
|--------|--------|--------|
| Avg PSNR (dB) | 35.78 | 34.16 |
| Avg encode (ms/image) | 2.5 | 46.1 |
| Avg decode (ms/image) | 5.5 | 91.8 |

### 9.4 Per-Image PSNR (dB)

| Image | semdisdiffae (1 step) | Flux.2 VAE |
|-------|--------|--------|
| p640x1536:94623 | 35.44 | 33.50 |
| p640x1536:94624 | 31.33 | 30.03 |
| p640x1536:94625 | 35.05 | 33.98 |
| p640x1536:94626 | 33.21 | 31.53 |
| p640x1536:94627 | 32.54 | 30.53 |
| p640x1536:94628 | 29.80 | 28.88 |
| p960x1024:216264 | 46.37 | 45.39 |
| p960x1024:216265 | 29.70 | 27.80 |
| p960x1024:216266 | 47.15 | 46.20 |
| p960x1024:216267 | 40.99 | 39.23 |
| p960x1024:216268 | 38.47 | 36.13 |
| p960x1024:216269 | 32.74 | 30.24 |
| p960x1024:216270 | 36.23 | 34.18 |
| p960x1024:216271 | 44.41 | 42.18 |
| p704x1472:94699 | 43.80 | 41.79 |
| p704x1472:94700 | 32.83 | 32.08 |
| p704x1472:94701 | 39.00 | 37.90 |
| p704x1472:94702 | 34.52 | 32.50 |
| p704x1472:94703 | 32.81 | 31.35 |
| p704x1472:94704 | 33.38 | 31.84 |
| p704x1472:94705 | 39.70 | 37.44 |
| p704x1472:94706 | 35.12 | 33.66 |
| r256_p1344x704:15577 | 31.02 | 29.98 |
| r256_p1344x704:15578 | 32.38 | 30.79 |
| r256_p1344x704:15579 | 33.27 | 31.83 |
| r256_p1344x704:15580 | 37.84 | 36.03 |
| r256_p1344x704:15581 | 38.57 | 36.94 |
| r256_p1344x704:15582 | 33.41 | 32.10 |
| r256_p1344x704:15583 | 36.67 | 34.54 |
| r256_p1344x704:15584 | 33.23 | 31.76 |
| r256_p896x1152:144131 | 35.30 | 33.60 |
| r256_p896x1152:144132 | 36.99 | 35.32 |
| r256_p896x1152:144133 | 39.69 | 37.33 |
| r256_p896x1152:144134 | 36.01 | 34.47 |
| r256_p896x1152:144135 | 31.20 | 29.87 |
| r256_p896x1152:144136 | 37.51 | 35.68 |
| r256_p896x1152:144137 | 33.83 | 32.86 |
| r256_p896x1152:144138 | 27.39 | 25.63 |
| VAE_accuracy_test_image | 36.64 | 35.25 |