data-archetype committed
Commit e40f02a · verified · 1 Parent(s): d2b29ca

Upload capacitor_decoder private initial export

README.md ADDED
@@ -0,0 +1,126 @@
---
license: apache-2.0
tags:
- diffusion
- autoencoder
- image-reconstruction
- decoder-only
- flux-compatible
- pytorch
---

# data-archetype/capacitor_decoder

**Capacitor decoder**: a faster, lighter FLUX.2-compatible latent decoder built
on the [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae)
architecture.

## Decode Speed

| Resolution | Speedup vs FLUX.2 | Peak VRAM Reduction | capacitor_decoder (ms/image) | FLUX.2 VAE (ms/image) | capacitor_decoder peak VRAM | FLUX.2 peak VRAM |
|---:|---:|---:|---:|---:|---:|---:|
| `512x512` | `1.85x` | `59.3%` | `11.40` | `21.14` | `391.6 MiB` | `961.9 MiB` |
| `1024x1024` | `3.28x` | `79.1%` | `26.31` | `86.24` | `601.4 MiB` | `2876.4 MiB` |
| `2048x2048` | `4.70x` | `86.4%` | `86.29` | `405.84` | `1437.4 MiB` | `10531.4 MiB` |

These measurements are decode-only. Each image is first encoded once with the
same FLUX.2 encoder, latents are cached in memory, and then both decoders are
timed over the same cached latent set.

## 2k PSNR Benchmark

| Model | Mean PSNR (dB) | Std (dB) | Median (dB) | Min (dB) | P5 (dB) | P95 (dB) | Max (dB) |
|---|---:|---:|---:|---:|---:|---:|---:|
| FLUX.2 VAE | 36.2849 | 4.5332 | 36.0728 | 22.7343 | 28.8853 | 43.6343 | 47.3836 |
| capacitor_decoder | 36.3399 | 4.4980 | 36.2891 | 23.2770 | 29.0597 | 43.6603 | 47.4312 |

| Delta vs FLUX.2 | Mean (dB) | Std (dB) | Median (dB) | Min (dB) | P5 (dB) | P95 (dB) | Max (dB) |
|---|---:|---:|---:|---:|---:|---:|---:|
| capacitor_decoder - FLUX.2 | 0.0550 | 0.5308 | 0.0618 | -1.9683 | -0.8108 | 0.8859 | 2.8071 |

Evaluated on `2000` validation images: roughly `2/3` photographs and `1/3`
book covers. Each image is encoded once with FLUX.2 and reused for both
decoders.

## Usage

```python
import torch
from diffusers.models import AutoencoderKLFlux2

from capacitor_decoder import CapacitorDecoder, CapacitorDecoderInferenceConfig


def flux2_patchify_and_whiten(
    latents: torch.Tensor,
    vae: AutoencoderKLFlux2,
) -> torch.Tensor:
    b, c, h, w = latents.shape
    if h % 2 != 0 or w % 2 != 0:
        raise ValueError(f"Expected even FLUX.2 latent grid, got H={h}, W={w}")
    z = latents.reshape(b, c, h // 2, 2, w // 2, 2)
    z = z.permute(0, 1, 3, 5, 2, 4).reshape(b, c * 4, h // 2, w // 2)
    mean = vae.bn.running_mean.view(1, -1, 1, 1).to(device=z.device, dtype=torch.float32)
    var = vae.bn.running_var.view(1, -1, 1, 1).to(device=z.device, dtype=torch.float32)
    std = torch.sqrt(var + float(vae.config.batch_norm_eps))
    return (z.to(torch.float32) - mean) / std


device = "cuda"
flux2 = AutoencoderKLFlux2.from_pretrained(
    "BiliSakura/VAEs",
    subfolder="FLUX2-VAE",
    torch_dtype=torch.bfloat16,
).to(device)
decoder = CapacitorDecoder.from_pretrained(
    "data-archetype/capacitor_decoder",
    device=device,
    dtype=torch.bfloat16,
)

image = ...  # [1, 3, H, W] in [-1, 1], with H and W divisible by 16

with torch.inference_mode():
    posterior = flux2.encode(image.to(device=device, dtype=torch.bfloat16))
    latent_mean = posterior.latent_dist.mean

    # Default path: match the usual FLUX.2 convention.
    # Whiten here, then let capacitor_decoder unwhiten internally before decode.
    latents = flux2_patchify_and_whiten(latent_mean, flux2)
    recon = decoder.decode(
        latents,
        height=int(image.shape[-2]),
        width=int(image.shape[-1]),
        inference_config=CapacitorDecoderInferenceConfig(num_steps=1),
    )
```

Whitening and dewhitening are optional, but they **must** stay consistent. The
default above matches the usual FLUX.2 pipeline behavior. If your upstream path
already gives you raw patchified decoder-space latents instead, skip whitening
upstream and call `decode(..., latents_are_flux2_whitened=False)`.
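One way to confirm a custom upstream path stays consistent is to verify that whitening and the decoder's internal dewhitening are exact inverses. A minimal per-channel scalar sketch of that contract (illustrative values, not the real FLUX.2 BN statistics):

```python
import math

def whiten(z: float, mean: float, var: float, eps: float = 1e-5) -> float:
    """Per-channel normalization as in flux2_patchify_and_whiten: (z - mean) / sqrt(var + eps)."""
    return (z - mean) / math.sqrt(var + eps)

def unwhiten(z_w: float, mean: float, var: float, eps: float = 1e-5) -> float:
    """What the decoder applies internally when latents_are_flux2_whitened=True."""
    return z_w * math.sqrt(var + eps) + mean

# Round trip must be exact (up to float error) for reconstructions to be correct.
z, mean, var = 0.7, 0.1, 2.0
assert abs(unwhiten(whiten(z, mean, var), mean, var) - z) < 1e-9
```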

## Details

- Default input contract: FLUX.2 patchified latents with FLUX.2 BN whitening still applied.
- Default decoder behavior: unwhiten with saved FLUX.2 BN running stats, then decode.
- Optional raw-latent mode: disable whitening upstream and call `decode(..., latents_are_flux2_whitened=False)`.
- Reused decoder architecture: [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae)
- [Technical report](technical_report_capacitor_decoder.md)
- [SemDisDiffAE technical report](https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md)
- [Results viewer](https://huggingface.co/spaces/data-archetype/capacitor_decoder-results)

## Citation

```bibtex
@misc{capacitor_decoder,
  title = {Capacitor Decoder: A Faster, Lighter FLUX.2-Compatible Latent Decoder},
  author = {data-archetype},
  email = {data-archetype@proton.me},
  year = {2026},
  month = apr,
  url = {https://huggingface.co/data-archetype/capacitor_decoder},
}
```
capacitor_decoder/__init__.py ADDED
@@ -0,0 +1,30 @@
"""capacitor_decoder: standalone decoder-only Capacitor export.

The package exposes a VP-diffusion decoder that consumes FLUX.2 latents.

Usage::

    from capacitor_decoder import CapacitorDecoder, CapacitorDecoderInferenceConfig

    model = CapacitorDecoder.from_pretrained("path/to/export", device="cuda")

    # Default: expects FLUX.2 pipeline-style encoded latents
    recon = model.decode(latents, height=H, width=W)

    # Bypass unwhitening when latents are already raw decoder-space tokens
    recon = model.decode(
        latents,
        height=H,
        width=W,
        latents_are_flux2_whitened=False,
    )
"""

from .config import CapacitorDecoderConfig, CapacitorDecoderInferenceConfig
from .model import CapacitorDecoder

__all__ = [
    "CapacitorDecoder",
    "CapacitorDecoderConfig",
    "CapacitorDecoderInferenceConfig",
]
capacitor_decoder/adaln.py ADDED
@@ -0,0 +1,50 @@
"""Scale+Gate AdaLN (2-way) for FCDM decoder blocks."""

from __future__ import annotations

from torch import Tensor, nn


class AdaLNScaleGateZeroProjector(nn.Module):
    """Packed 2-way AdaLN projection (SiLU -> Linear), zero-initialized.

    Outputs [B, 2*d_model] packed as (scale, gate).
    """

    def __init__(self, d_model: int, d_cond: int) -> None:
        super().__init__()
        self.d_model: int = int(d_model)
        self.d_cond: int = int(d_cond)
        self.act: nn.SiLU = nn.SiLU()
        self.proj: nn.Linear = nn.Linear(self.d_cond, 2 * self.d_model)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward_activated(self, act_cond: Tensor) -> Tensor:
        """Return packed modulation for a pre-activated conditioning vector."""
        return self.proj(act_cond)

    def forward(self, cond: Tensor) -> Tensor:
        """Return packed modulation [B, 2*d_model]."""
        return self.forward_activated(self.act(cond))


class AdaLNScaleGateZeroLowRankDelta(nn.Module):
    """Low-rank delta for 2-way AdaLN: down(d_cond -> rank) -> up(rank -> 2*d_model).

    Zero-initialized up projection preserves zero-output semantics at init.
    """

    def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
        super().__init__()
        self.d_model: int = int(d_model)
        self.d_cond: int = int(d_cond)
        self.rank: int = int(rank)
        self.down: nn.Linear = nn.Linear(self.d_cond, self.rank, bias=False)
        self.up: nn.Linear = nn.Linear(self.rank, 2 * self.d_model, bias=False)
        nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, act_cond: Tensor) -> Tensor:
        """Return packed delta modulation [B, 2*d_model]."""
        return self.up(self.down(act_cond))
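The zero-initialization has a concrete consequence worth spelling out: at init both the base projector and the low-rank delta emit an all-zero packed modulation, so every (scale, gate) pair starts at zero and each decoder block begins as an identity on its residual branch. A small sketch replicating that contract with illustrative dimensions:

```python
import torch
from torch import nn

# Zero-init base projector: SiLU -> Linear with zeroed weight and bias.
d_model, d_cond, rank = 8, 8, 4
proj = nn.Linear(d_cond, 2 * d_model)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)

# Low-rank delta: normal-init down projection, zero-init up projection.
down = nn.Linear(d_cond, rank, bias=False)
up = nn.Linear(rank, 2 * d_model, bias=False)
nn.init.normal_(down.weight, std=0.02)
nn.init.zeros_(up.weight)

cond = torch.randn(3, d_cond)
act = nn.functional.silu(cond)
m = proj(act) + up(down(act))  # packed [B, 2*d_model] modulation
scale, gate = m.chunk(2, dim=-1)
assert torch.all(scale == 0) and torch.all(gate == 0)  # identity at init
```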
capacitor_decoder/config.py ADDED
@@ -0,0 +1,65 @@
"""Frozen architecture and user-tunable inference config for capacitor_decoder."""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass(frozen=True)
class CapacitorDecoderConfig:
    """Frozen architecture config stored alongside exported weights."""

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    decoder_depth: int = 8
    decoder_start_blocks: int = 2
    decoder_end_blocks: int = 2
    bottleneck_dim: int = 128
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    pixel_noise_std: float = 0.558
    flux2_batch_norm_eps: float = 1e-5

    @property
    def latent_channels(self) -> int:
        """Channel width of the exported latent space."""
        return self.bottleneck_dim

    @property
    def effective_patch_size(self) -> int:
        """Effective spatial stride from image to latent grid."""
        return self.patch_size

    def save(self, path: str | Path) -> None:
        """Save config as JSON."""
        output_path = Path(path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(json.dumps(asdict(self), indent=2) + "\n")

    @classmethod
    def load(cls, path: str | Path) -> CapacitorDecoderConfig:
        """Load config from JSON."""
        data = json.loads(Path(path).read_text())
        return cls(**data)


@dataclass
class CapacitorDecoderInferenceConfig:
    """User-tunable decoder inference parameters."""

    num_steps: int = 1
    sampler: str = "ddim"
    schedule: str = "linear"
    pdg: bool = False
    pdg_strength: float = 2.0
    seed: int | None = None
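The save/load pair is a plain dataclass-to-JSON round trip. A self-contained sketch with a hypothetical stand-in config (the real `CapacitorDecoderConfig` has more fields, same mechanism):

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path

# Stand-in for the exported config's JSON round trip: asdict -> json -> **kwargs
# restores an equal frozen config. Field values mirror the defaults above.
@dataclass(frozen=True)
class MiniConfig:
    patch_size: int = 16
    bottleneck_dim: int = 128
    pixel_noise_std: float = 0.558

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    cfg = MiniConfig()
    path.write_text(json.dumps(asdict(cfg), indent=2) + "\n")
    loaded = MiniConfig(**json.loads(path.read_text()))

assert loaded == cfg  # frozen dataclasses compare field-by-field
```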
capacitor_decoder/decoder.py ADDED
@@ -0,0 +1,168 @@
"""FCDM DiffAE decoder: skip-concat topology with FCDM blocks and path-drop PDG.

No outer RMSNorms (use_other_outer_rms_norms=False during training):
norm_in, latent_norm, and norm_out are all absent.
"""

from __future__ import annotations

import torch
from torch import Tensor, nn

from .adaln import AdaLNScaleGateZeroLowRankDelta, AdaLNScaleGateZeroProjector
from .fcdm_block import FCDMBlock
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP


class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture (skip-concat, 2+4+2 default):
        Patchify x_t -> Fuse with upsampled z
        -> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
        -> Conv1x1 -> PixelShuffle

    Path-Drop Guidance (PDG) at inference:
        - Replace middle block output with ``path_drop_mask_feature`` to create
          an unconditional prediction, then extrapolate.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        start_block_count: int,
        end_block_count: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
    ) -> None:
        super().__init__()
        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)

        # Input processing (no norm_in)
        self.patchify = Patchify(in_channels, patch_size, model_dim)

        # Latent conditioning path (no latent_norm)
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # 2-way AdaLN: shared base projector + per-block low-rank deltas
        self.adaln_base = AdaLNScaleGateZeroProjector(
            d_model=model_dim, d_cond=model_dim
        )
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNScaleGateZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Block layout: start + middle + end
        middle_count = depth - start_block_count - end_block_count
        self._middle_start_idx = start_block_count
        self._end_start_idx = start_block_count + middle_count

        def _make_blocks(count: int) -> nn.ModuleList:
            return nn.ModuleList(
                [
                    FCDMBlock(
                        model_dim,
                        mlp_ratio,
                        depthwise_kernel_size=depthwise_kernel_size,
                        use_external_adaln=True,
                    )
                    for _ in range(count)
                ]
            )

        self.start_blocks = _make_blocks(start_block_count)
        self.middle_blocks = _make_blocks(middle_count)
        self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
        self.end_blocks = _make_blocks(end_block_count)

        # Learned mask feature for path-drop PDG
        self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output head (no norm_out)
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta."""
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)
        delta_m = self.adaln_deltas[layer_idx](act)
        return base_m + delta_m

    def _run_blocks(
        self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
    ) -> Tensor:
        """Run a group of decoder blocks with per-block AdaLN modulation."""
        for local_idx, block in enumerate(blocks):
            adaln_m = self._adaln_m_for_layer(cond, layer_idx=start_index + local_idx)
            x = block(x, adaln_m=adaln_m)
        return x

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1].
            latents: Encoder latents [B, bottleneck_dim, h, w].
            drop_middle_blocks: Replace middle block output with mask feature (PDG).

        Returns:
            x0 prediction [B, C, H, W].
        """
        x_feat = self.patchify(x_t)
        z_up = self.latent_up(latents)

        fused = torch.cat([x_feat, z_up], dim=1)
        fused = self.fuse_in(fused)

        cond = self.time_embed(t.to(torch.float32).to(device=x_t.device))

        start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)

        if drop_middle_blocks:
            middle_out = self.path_drop_mask_feature.to(
                device=x_t.device, dtype=x_t.dtype
            ).expand_as(start_out)
        else:
            middle_out = self._run_blocks(
                self.middle_blocks,
                start_out,
                cond,
                start_index=self._middle_start_idx,
            )

        skip_fused = torch.cat([start_out, middle_out], dim=1)
        skip_fused = self.fuse_skip(skip_fused)

        end_out = self._run_blocks(
            self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
        )

        patches = self.out_proj(end_out)
        return self.unpatchify(patches)
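The output head's shape contract can be checked in isolation: `out_proj` expands features to `in_channels * patch_size**2` channels and `PixelShuffle(patch_size)` rearranges them into an image `patch_size`x larger per side. A sketch with small illustrative sizes (the exported model uses model_dim=896, patch_size=16):

```python
import torch
from torch import nn

# Hypothetical small dimensions standing in for the real export.
in_channels, patch_size, model_dim = 3, 4, 32
out_proj = nn.Conv2d(model_dim, in_channels * patch_size**2, kernel_size=1)
unpatchify = nn.PixelShuffle(patch_size)

feat = torch.randn(2, model_dim, 8, 8)  # features on the latent grid
img = unpatchify(out_proj(feat))        # [B, C*p^2, h, w] -> [B, C, h*p, w*p]
assert img.shape == (2, 3, 32, 32)
```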
capacitor_decoder/fcdm_block.py ADDED
@@ -0,0 +1,103 @@
"""FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .norms import ChannelWiseRMSNorm


class GRN(nn.Module):
    """Global Response Normalization for NCHW tensors."""

    def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps: float = float(eps)
        c = int(channels)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
        g_fp32 = g.to(dtype=torch.float32)
        n = (g_fp32 / (g_fp32.mean(dim=1, keepdim=True) + self.eps)).to(dtype=x.dtype)
        gamma = self.gamma.to(device=x.device, dtype=x.dtype)
        beta = self.beta.to(device=x.device, dtype=x.dtype)
        return gamma * (x * n) + beta + x


class FCDMBlock(nn.Module):
    """ConvNeXt-style block with scale+gate AdaLN and GRN.

    Two modes:
    - Unconditioned (encoder): uses learned layer-scale for near-identity init.
    - External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate).
      The gate is applied raw (no tanh).
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
        layer_scale_init: float = 1e-3,
    ) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self.mlp_ratio: float = float(mlp_ratio)

        self.dwconv = nn.Conv2d(
            channels,
            channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,
            stride=1,
            groups=channels,
            bias=True,
        )
        self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
        hidden = max(int(float(channels) * float(mlp_ratio)), 1)
        self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
        self.grn = GRN(hidden, eps=1e-6)
        self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)

        if not use_external_adaln:
            self.layer_scale = nn.Parameter(
                torch.full((channels,), float(layer_scale_init))
            )
        else:
            self.register_parameter("layer_scale", None)

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        b, c, _, _ = x.shape

        if adaln_m is not None:
            m = adaln_m.to(device=x.device, dtype=x.dtype)
            scale, gate = m.chunk(2, dim=-1)
        else:
            scale = gate = None

        h = self.dwconv(x)
        h = self.norm(h)

        if scale is not None:
            h = h * (1.0 + scale.view(b, c, 1, 1))

        h = self.pwconv1(h)
        h = F.gelu(h)
        h = self.grn(h)
        h = self.pwconv2(h)

        if gate is not None:
            gate_view = gate.view(b, c, 1, 1)
        else:
            gate_view = self.layer_scale.view(1, c, 1, 1).to(  # type: ignore[union-attr]
                device=h.device, dtype=h.dtype
            )

        return x + gate_view * h
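Because `gamma` and `beta` are zero-initialized, the GRN is an exact identity at init: `gamma * (x * n) + beta + x` collapses to `x` regardless of the response statistic `n`. A quick standalone check of that residual form (same arithmetic as the forward above, illustrative shapes):

```python
import torch

# Replicate GRN's forward with zero-initialized gamma/beta.
x = torch.randn(2, 4, 5, 5)
g = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
n = g / (g.mean(dim=1, keepdim=True) + 1e-6)  # per-channel response ratio
gamma = torch.zeros(1, 4, 1, 1)
beta = torch.zeros(1, 4, 1, 1)
out = gamma * (x * n) + beta + x
assert torch.equal(out, x)  # exact identity at init
```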
capacitor_decoder/model.py ADDED
@@ -0,0 +1,244 @@
"""Standalone decoder-only Capacitor model for HuggingFace distribution."""

from __future__ import annotations

from pathlib import Path

import torch
from torch import Tensor, nn

from .config import CapacitorDecoderConfig, CapacitorDecoderInferenceConfig
from .decoder import Decoder
from .samplers import run_ddim, run_dpmpp_2m
from .vp_diffusion import get_schedule, make_initial_state, sample_noise


def _resolve_model_dir(
    path_or_repo_id: str | Path,
    *,
    revision: str | None,
    cache_dir: str | Path | None,
) -> Path:
    """Resolve a local path or HuggingFace Hub repo ID to a local directory."""
    local = Path(path_or_repo_id)
    if local.is_dir():
        return local
    repo_id = str(path_or_repo_id)
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        raise ImportError(
            f"'{repo_id}' is not an existing local directory. "
            "To download from HuggingFace Hub, install huggingface_hub: "
            "pip install huggingface_hub"
        ) from exc
    cache_dir_str = str(cache_dir) if cache_dir is not None else None
    local_dir = snapshot_download(
        repo_id,
        revision=revision,
        cache_dir=cache_dir_str,
    )
    return Path(local_dir)


class CapacitorDecoder(nn.Module):
    """Decoder-only Capacitor export that reconstructs from FLUX.2 latents.

    Default input convention matches the public FLUX.2 diffusers pipeline:
    patchified latents whitened by the VAE BN running statistics. The decoder
    therefore unwhitens by default before applying the VP-diffusion decode path.
    """

    def __init__(self, config: CapacitorDecoderConfig) -> None:
        super().__init__()
        self.config = config
        self.register_buffer(
            "flux2_running_mean",
            torch.zeros((config.latent_channels,), dtype=torch.float32),
        )
        self.register_buffer(
            "flux2_running_var",
            torch.ones((config.latent_channels,), dtype=torch.float32),
        )
        self.decoder = Decoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.decoder_depth,
            start_block_count=config.decoder_start_blocks,
            end_block_count=config.decoder_end_blocks,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
            adaln_low_rank_rank=config.adaln_low_rank_rank,
        )

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device = "cpu",
        revision: str | None = None,
        cache_dir: str | Path | None = None,
    ) -> CapacitorDecoder:
        """Load a pretrained model from a local directory or HuggingFace Hub."""
        model_dir = _resolve_model_dir(
            path_or_repo_id, revision=revision, cache_dir=cache_dir
        )
        config = CapacitorDecoderConfig.load(model_dir / "config.json")
        model = cls(config)

        safetensors_path = model_dir / "model.safetensors"
        pt_path = model_dir / "model.pt"
        if safetensors_path.exists():
            try:
                from safetensors.torch import load_file
            except ImportError as exc:
                raise ImportError(
                    "safetensors package required to load .safetensors files. "
                    "Install with: pip install safetensors"
                ) from exc
            state_dict = load_file(str(safetensors_path), device=str(device))
        elif pt_path.exists():
            state_dict = torch.load(str(pt_path), map_location=device, weights_only=True)
        else:
            raise FileNotFoundError(
                f"No model weights found in {model_dir}. "
                "Expected model.safetensors or model.pt."
            )

        model.load_state_dict(state_dict)
        model = model.to(dtype=dtype, device=torch.device(device))
        model.eval()
        return model

    def flux2_norm_stats(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[Tensor, Tensor]:
        """Return FLUX.2 latent BN running stats as rank-4 tensors."""
        mean = self.flux2_running_mean.view(1, -1, 1, 1).to(device=device, dtype=dtype)
        var = self.flux2_running_var.view(1, -1, 1, 1).to(device=device, dtype=dtype)
        std = torch.sqrt(var + float(self.config.flux2_batch_norm_eps))
        return mean, std

    def unwhiten_flux2_latents(self, latents: Tensor) -> Tensor:
        """Undo FLUX.2 BN whitening on patchified latents."""
        z_fp32 = latents.to(torch.float32)
        mean, std = self.flux2_norm_stats(device=z_fp32.device, dtype=torch.float32)
        return z_fp32 * std + mean

    @torch.no_grad()
    def decode(
        self,
        latents: Tensor,
        height: int,
        width: int,
        *,
        latents_are_flux2_whitened: bool = True,
        inference_config: CapacitorDecoderInferenceConfig | None = None,
    ) -> Tensor:
        """Decode FLUX.2 latents to images via VP diffusion.

        Args:
            latents: [B, 128, H/16, W/16] patchified FLUX.2 latents.
            height: Output image height.
            width: Output image width.
            latents_are_flux2_whitened: When True, interpret the input as
                pipeline-style FLUX.2 encoded latents and unwhiten them using
                the saved FLUX.2 BN running stats before decode. Set to False
                when the input is already raw decoder-space latents.
            inference_config: Optional inference parameters.
        """
        cfg = inference_config or CapacitorDecoderInferenceConfig()
        config = self.config
        if int(latents.shape[1]) != int(config.latent_channels):
            raise ValueError(
                "Latent channel count must match exported decoder input width: "
                f"got C={int(latents.shape[1])}, expected={int(config.latent_channels)}"
            )
        if height % int(config.effective_patch_size) != 0 or width % int(
            config.effective_patch_size
        ) != 0:
            raise ValueError(
                f"height={height} and width={width} must be divisible by "
                f"effective_patch_size={int(config.effective_patch_size)}"
            )

        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32

        latents_fp32 = (
            self.unwhiten_flux2_latents(latents)
            if bool(latents_are_flux2_whitened)
            else latents.to(torch.float32)
        )
        latents_in = latents_fp32.to(dtype=model_dtype)

        shape = (int(latents.shape[0]), config.in_channels, height, width)
        noise = sample_noise(
            shape,
            noise_std=config.pixel_noise_std,
            seed=cfg.seed,
            device=torch.device("cpu"),
            dtype=torch.float32,
        )

        schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=latents.device)
        initial_state = make_initial_state(
            noise=noise.to(device=latents.device),
            t_start=schedule[0:1],
            logsnr_min=config.logsnr_min,
            logsnr_max=config.logsnr_max,
        )
        device_type = "cuda" if latents.device.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):

            def _forward_fn(
                x_t: Tensor,
                t: Tensor,
                condition_latents: Tensor,
                *,
                drop_middle_blocks: bool = False,
                mask_latent_tokens: bool = False,
            ) -> Tensor:
                del mask_latent_tokens
                return self.decoder(
                    x_t.to(dtype=model_dtype),
                    t,
                    condition_latents.to(dtype=model_dtype),
                    drop_middle_blocks=drop_middle_blocks,
                )

            pdg_mode = "path_drop" if bool(cfg.pdg) else "disabled"
            if cfg.sampler == "ddim":
                sampler_fn = run_ddim
            elif cfg.sampler == "dpmpp_2m":
                sampler_fn = run_dpmpp_2m
            else:
                raise ValueError(
                    f"Unsupported sampler: {cfg.sampler!r}. Use 'ddim' or 'dpmpp_2m'."
                )
            result = sampler_fn(
                forward_fn=_forward_fn,
                initial_state=initial_state,
                schedule=schedule,
                latents=latents_in,
                logsnr_min=config.logsnr_min,
                logsnr_max=config.logsnr_max,
                pdg_mode=pdg_mode,
                pdg_strength=float(cfg.pdg_strength),
                device=latents.device,
            )
        return result.to(torch.float32).clamp(-1.0, 1.0)
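The two validation checks in `decode()` pin down a simple shape contract. With the exported defaults (128 latent channels, effective patch size 16), a 1024x1024 output requires a `[B, 128, 64, 64]` latent tensor; the arithmetic is:

```python
# Shape-contract arithmetic for decode() with the exported defaults.
latent_channels, patch = 128, 16
height = width = 1024
assert height % patch == 0 and width % patch == 0  # divisibility check
grid = (latent_channels, height // patch, width // patch)
assert grid == (128, 64, 64)  # expected latent shape per image
```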
capacitor_decoder/norms.py ADDED
@@ -0,0 +1,39 @@
"""Channel-wise RMSNorm for NCHW tensors."""

from __future__ import annotations

import torch
from torch import Tensor, nn


class ChannelWiseRMSNorm(nn.Module):
    """Channel-wise RMSNorm with float32 reduction for numerical stability.

    Normalizes across channels per spatial position. Supports optional
    per-channel affine weight and bias.
    """

    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self._eps: float = float(eps)
        if affine:
            self.weight = nn.Parameter(torch.ones(self.channels))
            self.bias = nn.Parameter(torch.zeros(self.channels))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: Tensor) -> Tensor:
        if x.dim() < 2:
            return x
        # Float32 accumulation for stability
        ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
        inv_rms = torch.rsqrt(ms + self._eps)
        y = x * inv_rms.to(dtype=x.dtype)
        if self.weight is not None:
            shape = (1, -1) + (1,) * (x.dim() - 2)
            y = y * self.weight.view(shape).to(dtype=x.dtype)
        if self.bias is not None:
            y = y + self.bias.view(shape).to(dtype=x.dtype)
        return y
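With `affine=False` (as the FCDM blocks use it), this normalization leaves every spatial position with approximately unit RMS across channels, up to the `eps` term. A direct property check replicating the same arithmetic on an illustrative tensor:

```python
import torch

# Replicate the affine-free forward path and check the unit-RMS property.
x = torch.randn(2, 8, 3, 3)
ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
y = x * torch.rsqrt(ms + 1e-6).to(dtype=x.dtype)

# RMS across channels at each spatial position should be ~1.
rms = torch.sqrt(torch.mean(torch.square(y), dim=1))
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
```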
capacitor_decoder/samplers.py ADDED
@@ -0,0 +1,255 @@
"""DDIM and DPM++2M samplers for VP diffusion with path-drop PDG support."""

from __future__ import annotations

from typing import Protocol

import torch
from torch import Tensor

from .vp_diffusion import (
    alpha_sigma_from_logsnr,
    broadcast_time_like,
    shifted_cosine_interpolated_logsnr_from_t,
)


class DecoderForwardFn(Protocol):
    """Callable that predicts x0 from (x_t, t, latents) with path-drop PDG flag."""

    def __call__(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
        mask_latent_tokens: bool = False,
    ) -> Tensor: ...


def _reconstruct_eps_from_x0(
    *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
) -> Tensor:
    """Reconstruct eps_hat from (x_t, x0_hat) under the VP parameterization.

    eps_hat = (x_t - alpha * x0_hat) / sigma. All math in float32.
    """
    alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
    sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
    x_t_f32 = x_t.to(torch.float32)
    x0_f32 = x0_hat.to(torch.float32)
    return (x_t_f32 - alpha_view * x0_f32) / sigma_view


def _ddim_step(
    *,
    x0_hat: Tensor,
    eps_hat: Tensor,
    alpha_next: Tensor,
    sigma_next: Tensor,
    ref: Tensor,
) -> Tensor:
    """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
    a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
    s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
    return a * x0_hat + s * eps_hat


def _predict_with_pdg(
    forward_fn: DecoderForwardFn,
    state: Tensor,
    t_vec: Tensor,
    latents: Tensor,
    *,
    pdg_mode: str,
    pdg_strength: float,
) -> Tensor:
    """Run the decoder forward pass with optional PDG guidance.

    Args:
        forward_fn: Decoder forward function.
        state: Current noised state [B, C, H, W].
        t_vec: Timestep vector [B].
        latents: Encoder latents.
        pdg_mode: "disabled" or "path_drop".
        pdg_strength: CFG-like strength for PDG.

    Returns:
        x0_hat prediction in float32.
    """
    if pdg_mode == "path_drop":
        x0_uncond = forward_fn(state, t_vec, latents, drop_middle_blocks=True).to(
            torch.float32
        )
        x0_cond = forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
            torch.float32
        )
        return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
    return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
        torch.float32
    )


def run_ddim(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run the DDIM sampling loop with path-drop PDG support.

    Args:
        forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
        initial_state: Starting noised state [B, C, H, W] in float32.
        schedule: Descending t-schedule in [0, 1] with NFE + 1 points.
        latents: Encoder latents [B, bottleneck_dim, h, w].
        logsnr_min, logsnr_max: VP schedule endpoints.
        log_change_high, log_change_low: Shifted-cosine schedule parameters.
        pdg_mode: "disabled" or "path_drop".
        pdg_strength: CFG-like strength for PDG.
        device: Target device.

    Returns:
        Denoised samples [B, C, H, W] in float32.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma for all schedule points.
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        a_t = alpha_sched[i].expand(batch_size)
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction with optional PDG.
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            state,
            t_vec,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        eps_hat = _reconstruct_eps_from_x0(
            x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t
        )
        state = _ddim_step(
            x0_hat=x0_hat,
            eps_hat=eps_hat,
            alpha_next=a_next,
            sigma_next=s_next,
            ref=state,
        )

    return state


def run_dpmpp_2m(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run the DPM++2M sampling loop with path-drop PDG support.

    Multi-step solver using the exponential-integrator formulation in
    half-lambda space.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma, half-lambda for all schedule points.
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
    half_lambda = 0.5 * lmb.to(torch.float32)

    x0_prev: Tensor | None = None

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction with optional PDG.
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            state,
            t_vec,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        lam_t = half_lambda[i].expand(batch_size)
        lam_next = half_lambda[i + 1].expand(batch_size)
        h = (lam_next - lam_t).to(torch.float32)
        phi_1 = torch.expm1(-h)

        sigma_ratio = (s_next / s_t).to(torch.float32)

        if i == 0 or x0_prev is None:
            # First-order step.
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - broadcast_time_like(a_next, state).to(torch.float32)
                * broadcast_time_like(phi_1, state).to(torch.float32)
                * x0_hat
            )
        else:
            # Second-order step.
            lam_prev = half_lambda[i - 1].expand(batch_size)
            h_0 = (lam_t - lam_prev).to(torch.float32)
            r0 = h_0 / h
            d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
            common = broadcast_time_like(a_next, state).to(
                torch.float32
            ) * broadcast_time_like(phi_1, state).to(torch.float32)
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - common * x0_hat
                - 0.5 * common * d1_0
            )

        x0_prev = x0_hat

    return state
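The DDIM recursion above can be checked with a scalar, torch-free sketch. This is illustrative only: the decoder is replaced by a hypothetical oracle that returns the true `x0`, so `eps_hat` is recovered exactly and the loop lands back on `x0` up to the small noise floor left by the schedule's finite `logsnr_max`:

```python
import math

def alpha_sigma(logsnr):
    # VP parameterization: alpha^2 + sigma^2 = 1
    a = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    s = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return a, s

def cosine_logsnr(t, logsnr_min=-10.0, logsnr_max=10.0):
    # Cosine-interpolated schedule: logsnr(0)=logsnr_max, logsnr(1)=logsnr_min
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return -2.0 * math.log(math.tan(a * t + b))

x0, eps = 0.7, -1.3   # illustrative scalar "image" and noise
steps = 8
ts = [1.0 - i / steps for i in range(steps + 1)]  # descending t: 1.0 -> 0.0

a0, s0 = alpha_sigma(cosine_logsnr(ts[0]))
state = a0 * x0 + s0 * eps  # exact noised state at t=1

for i in range(steps):
    a_t, s_t = alpha_sigma(cosine_logsnr(ts[i]))
    a_n, s_n = alpha_sigma(cosine_logsnr(ts[i + 1]))
    x0_hat = x0  # oracle stands in for the decoder's x0 prediction
    eps_hat = (state - a_t * x0_hat) / s_t           # _reconstruct_eps_from_x0
    state = a_n * x0_hat + s_n * eps_hat             # _ddim_step
```

With a perfect predictor the update is exact, so the only residual error is `sigma(0) * eps`, which shrinks as `logsnr_max` grows.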
capacitor_decoder/straight_through_encoder.py ADDED
@@ -0,0 +1,27 @@
"""PixelUnshuffle-based patchifier (no residual conv path)."""

from __future__ import annotations

from torch import Tensor, nn


class Patchify(nn.Module):
    """PixelUnshuffle(patch) -> Conv2d 1x1 projection.

    Converts [B, C, H, W] images into [B, out_channels, H/patch, W/patch] features.
    """

    def __init__(self, in_channels: int, patch: int, out_channels: int) -> None:
        super().__init__()
        self.patch = int(patch)
        self.unshuffle = nn.PixelUnshuffle(self.patch)
        in_after = in_channels * (self.patch * self.patch)
        self.proj = nn.Conv2d(in_after, out_channels, kernel_size=1, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        if x.shape[2] % self.patch != 0 or x.shape[3] % self.patch != 0:
            raise ValueError(
                f"Input H={x.shape[2]} and W={x.shape[3]} must be divisible by patch={self.patch}"
            )
        y = self.unshuffle(x)
        return self.proj(y)
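The shape arithmetic of this patchifier can be sketched without torch. The numbers below come from `config.json` in this repo (`in_channels=3`, `patch_size=16`, `model_dim=896`); the helper itself is illustrative:

```python
def patchify_shape(b, c, h, w, patch, out_channels):
    # PixelUnshuffle folds each patch x patch spatial block into channels:
    # [B, C, H, W] -> [B, C*patch*patch, H/patch, W/patch];
    # the 1x1 conv then projects C*patch*patch -> out_channels on the same grid.
    if h % patch != 0 or w % patch != 0:
        raise ValueError("H and W must be divisible by patch")
    return (b, out_channels, h // patch, w // patch)

# A 1024x1024 RGB batch becomes a 64x64 grid of 896-dim features.
print(patchify_shape(2, 3, 1024, 1024, 16, 896))  # (2, 896, 64, 64)
```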
capacitor_decoder/time_embed.py ADDED
@@ -0,0 +1,83 @@
"""Sinusoidal timestep embedding with MLP projection."""

from __future__ import annotations

import math

import torch
from torch import Tensor, nn


def _log_spaced_frequencies(
    half: int, max_period: float, *, device: torch.device | None = None
) -> Tensor:
    """Log-spaced frequencies for sinusoidal embedding."""
    return torch.exp(
        -math.log(max_period)
        * torch.arange(half, device=device, dtype=torch.float32)
        / max(float(half - 1), 1.0)
    )


def sinusoidal_time_embedding(
    t: Tensor,
    dim: int,
    *,
    max_period: float = 10000.0,
    scale: float | None = None,
    freqs: Tensor | None = None,
) -> Tensor:
    """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32."""
    t32 = t.to(torch.float32)
    if scale is not None:
        t32 = t32 * float(scale)
    half = dim // 2
    if freqs is not None:
        freqs = freqs.to(device=t32.device, dtype=torch.float32)
    else:
        freqs = _log_spaced_frequencies(half, max_period, device=t32.device)
    angles = t32[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class SinusoidalTimeEmbeddingMLP(nn.Module):
    """Sinusoidal time embedding followed by Linear -> SiLU -> Linear."""

    def __init__(
        self,
        dim: int,
        *,
        freq_dim: int = 256,
        hidden_mult: float = 1.0,
        time_scale: float = 1000.0,
        max_period: float = 10000.0,
    ) -> None:
        super().__init__()
        self.dim = int(dim)
        self.freq_dim = int(freq_dim)
        self.time_scale = float(time_scale)
        self.max_period = float(max_period)
        hidden_dim = max(int(round(int(dim) * float(hidden_mult))), 1)

        freqs = _log_spaced_frequencies(self.freq_dim // 2, self.max_period)
        self.register_buffer("freqs", freqs, persistent=True)

        self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
        self.act = nn.SiLU()
        self.proj_out = nn.Linear(hidden_dim, self.dim)

    def forward(self, t: Tensor) -> Tensor:
        freqs: Tensor = self.freqs  # type: ignore[assignment]
        emb_freq = sinusoidal_time_embedding(
            t.to(torch.float32),
            self.freq_dim,
            max_period=self.max_period,
            scale=self.time_scale,
            freqs=freqs,
        )
        dtype_in = self.proj_in.weight.dtype
        hidden = self.proj_in(emb_freq.to(dtype_in))
        hidden = self.act(hidden)
        if hidden.dtype != self.proj_out.weight.dtype:
            hidden = hidden.to(self.proj_out.weight.dtype)
        return self.proj_out(hidden)
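The frequency layout used by `_log_spaced_frequencies` is easy to verify in plain Python: frequencies decay geometrically from 1.0 down to `1/max_period`. This list-based mirror of the tensor code is for illustration:

```python
import math

def log_spaced_frequencies(half, max_period=10000.0):
    # freqs[i] = exp(-log(max_period) * i / (half - 1)):
    # geometric decay from 1.0 (i=0) down to 1/max_period (i=half-1).
    denom = max(float(half - 1), 1.0)
    return [math.exp(-math.log(max_period) * i / denom) for i in range(half)]

f = log_spaced_frequencies(128)
print(f[0], f[-1])  # highest frequency 1.0, lowest ~1/max_period
```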
capacitor_decoder/vp_diffusion.py ADDED
@@ -0,0 +1,151 @@
"""VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""

from __future__ import annotations

import math

import torch
from torch import Tensor


def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
    """Compute (alpha, sigma) from logSNR in float32.

    VP constraint: alpha^2 + sigma^2 = 1.
    """
    lmb32 = lmb.to(dtype=torch.float32)
    alpha = torch.sqrt(torch.sigmoid(lmb32))
    sigma = torch.sqrt(torch.sigmoid(-lmb32))
    return alpha, sigma


def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
    """Broadcast a [B] coefficient to match x for per-sample scaling."""
    view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
    return coeff.view(view_shape)


def _cosine_interpolated_params(
    logsnr_min: float, logsnr_max: float
) -> tuple[float, float]:
    """Compute (a, b) for the cosine-interpolated logSNR schedule.

    logsnr(t) = -2 * log(tan(a*t + b))
    logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
    """
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return a, b


def cosine_interpolated_logsnr_from_t(
    t: Tensor, *, logsnr_min: float, logsnr_max: float
) -> Tensor:
    """Map t in [0, 1] to logSNR via the cosine-interpolated schedule. Always float32."""
    a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
    t32 = t.to(dtype=torch.float32)
    a_t = torch.tensor(a, device=t32.device, dtype=torch.float32)
    b_t = torch.tensor(b, device=t32.device, dtype=torch.float32)
    u = a_t * t32 + b_t
    return -2.0 * torch.log(torch.tan(u))


def shifted_cosine_interpolated_logsnr_from_t(
    t: Tensor,
    *,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.

    lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
    """
    base = cosine_interpolated_logsnr_from_t(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    t32 = t.to(dtype=torch.float32)
    high = base + float(log_change_high)
    low = base + float(log_change_low)
    return (1.0 - t32) * high + t32 * low


def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
    """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.

    ``num_steps`` is the number of function evaluations (NFE = decoder forward
    passes). Internally the schedule has ``num_steps + 1`` time points
    (including both endpoints).

    Args:
        schedule_type: "linear" or "cosine".
        num_steps: Number of decoder forward passes (NFE), >= 1.

    Returns:
        Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
    """
    # NOTE: the upstream training code (src/ode/time_schedules.py) uses a
    # different convention where num_steps counts schedule *points* (so NFE =
    # num_steps - 1). This export package corrects the off-by-one so that
    # num_steps means NFE directly. TODO: align the upstream convention.
    n = max(int(num_steps) + 1, 2)
    if schedule_type == "linear":
        base = torch.linspace(0.0, 1.0, n)
    elif schedule_type == "cosine":
        i = torch.arange(n, dtype=torch.float32)
        base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
    else:
        raise ValueError(
            f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
        )
    # Descending: high t (noisy) -> low t (clean).
    return torch.flip(base, dims=[0])


def make_initial_state(
    *,
    noise: Tensor,
    t_start: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """Construct the VP initial state x_t0 = sigma_start * noise (since x0 = 0).

    All math in float32.
    """
    batch = int(noise.shape[0])
    lmb_start = shifted_cosine_interpolated_logsnr_from_t(
        t_start.expand(batch).to(dtype=torch.float32),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    _alpha_start, sigma_start = alpha_sigma_from_logsnr(lmb_start)
    sigma_view = broadcast_time_like(sigma_start, noise)
    return sigma_view * noise.to(dtype=torch.float32)


def sample_noise(
    shape: tuple[int, ...],
    *,
    noise_std: float = 1.0,
    seed: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
    if seed is None:
        noise = torch.randn(
            shape, device=device or torch.device("cpu"), dtype=torch.float32
        )
    else:
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed))
        noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
    noise = noise.mul(float(noise_std))
    target_device = device if device is not None else torch.device("cpu")
    return noise.to(device=target_device, dtype=dtype)
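The NFE convention of `get_schedule` and the VP constraint can both be illustrated with a dependency-free sketch (plain lists standing in for tensors; mirrors the functions above rather than importing them):

```python
import math

def get_schedule(schedule_type, num_steps):
    # num_steps = NFE; the returned schedule has num_steps + 1 points.
    n = max(int(num_steps) + 1, 2)
    if schedule_type == "linear":
        base = [i / (n - 1) for i in range(n)]
    elif schedule_type == "cosine":
        base = [0.5 * (1.0 - math.cos(math.pi * i / (n - 1))) for i in range(n)]
    else:
        raise ValueError(schedule_type)
    return base[::-1]  # descending: 1.0 (noisy) -> 0.0 (clean)

sched = get_schedule("linear", 4)
print(len(sched), sched[0], sched[-1])  # 5 points for 4 decoder calls

# VP constraint at any logSNR: sigmoid(x) + sigmoid(-x) = 1,
# so alpha^2 + sigma^2 = 1 by construction.
lmb = 3.7
alpha2 = 1.0 / (1.0 + math.exp(-lmb))
sigma2 = 1.0 / (1.0 + math.exp(lmb))
print(abs(alpha2 + sigma2 - 1.0) < 1e-12)
```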
config.json ADDED
@@ -0,0 +1,16 @@
{
  "in_channels": 3,
  "patch_size": 16,
  "model_dim": 896,
  "decoder_depth": 8,
  "decoder_start_blocks": 2,
  "decoder_end_blocks": 2,
  "bottleneck_dim": 128,
  "mlp_ratio": 4.0,
  "depthwise_kernel_size": 7,
  "adaln_low_rank_rank": 128,
  "logsnr_min": -10.0,
  "logsnr_max": 10.0,
  "pixel_noise_std": 0.558,
  "flux2_batch_norm_eps": 0.0001
}
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3a5b67fdf0c67b43123505a30ccbcc0b195e067f0b578f92e5353f343d5359ba
size 247744864
technical_report_capacitor_decoder.md ADDED
@@ -0,0 +1,77 @@
# data-archetype/capacitor_decoder — Technical Report

**Capacitor decoder** is a FLUX.2-compatible latent decoder built on the
[SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae)
architecture.

This document covers only the export contract.

- Model card: https://huggingface.co/data-archetype/capacitor_decoder
- SemDisDiffAE technical report: https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md

## 1. Latent Interface

The exported runtime defaults to the usual FLUX.2 pipeline convention:
FLUX.2 patchified latents with FLUX.2 BN whitening still applied. The decoder
first unwhitens those latents using the saved FLUX.2 BN running statistics and
then decodes them.

## 2. Two Supported Input Modes

There are two valid ways to call the export:

1. **Default mode**: upstream produces FLUX.2 patchified + whitened latents.
   Use `decode(latents, ..., latents_are_flux2_whitened=True)`.
2. **Raw-latent mode**: upstream produces raw patchified decoder-space latents
   with no FLUX.2 BN whitening. Use
   `decode(latents, ..., latents_are_flux2_whitened=False)`.

The important invariant is that whitening must be handled consistently on both
sides. If whitening is enabled upstream, keep the decoder default. If whitening
is disabled upstream, disable dewhitening in the decoder too.

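The whiten/dewhiten pairing can be sketched in plain Python. This is an illustrative scalar model of BN-style whitening, not the decoder's actual code: `mean` and `var` below are hypothetical stand-ins for the saved running statistics, while `EPS` is `flux2_batch_norm_eps` from `config.json`. The point is only that dewhitening exactly inverts whitening, which is why the two sides must agree:

```python
import math

EPS = 1e-4  # flux2_batch_norm_eps from config.json

def whiten(z, mean, var):
    # BN-style whitening per channel: (z - mean) / sqrt(var + eps)
    return [(v - mean) / math.sqrt(var + EPS) for v in z]

def dewhiten(zw, mean, var):
    # Exact inverse of whiten()
    return [v * math.sqrt(var + EPS) + mean for v in zw]

z = [0.3, -1.2, 2.5]            # illustrative latent values
mean, var = 0.4, 1.8            # stand-ins for running statistics
roundtrip = dewhiten(whiten(z, mean, var), mean, var)
print(max(abs(a - b) for a, b in zip(z, roundtrip)) < 1e-12)
```

Mismatched modes correspond to whitening (or dewhitening) being applied twice or not at all, which scales and shifts every latent channel.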
## 3. Training

This export corresponds to roughly **300k training steps**. The saved run
configuration uses:

| Parameter | Value |
|---|---|
| Optimizer | AdamW |
| AdamW betas | `(0.9, 0.99)` |
| AdamW epsilon | `1e-8` |
| Weight decay | `0.0` |
| Learning rate | `1e-4` |
| LR schedule | constant after warmup |
| Warmup steps | `2,000` |
| Batch size | `128` |
| Gradient accumulation | `1` |
| Gradient clip | `1.0` max norm |
| Precision | AMP bfloat16 |
| FP32 matmul precision | TF32 |
| EMA decay | `0.9995` |
| EMA dtype | FP32 |
| EMA update cadence | every step |
| Compilation | `torch.compile` enabled |
| Validation / checkpoint cadence | every `1,000` steps |

## 4. Links

- [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae)
- [This model card](https://huggingface.co/data-archetype/capacitor_decoder)
- [SemDisDiffAE technical report](https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md)
- [Results viewer](https://huggingface.co/spaces/data-archetype/capacitor_decoder-results)

## Citation

```bibtex
@misc{capacitor_decoder,
  title  = {Capacitor Decoder: A Faster, Lighter FLUX.2-Compatible Latent Decoder},
  author = {data-archetype},
  email  = {data-archetype@proton.me},
  year   = {2026},
  month  = apr,
  url    = {https://huggingface.co/data-archetype/capacitor_decoder},
}
```