data-archetype committed on
Commit bdf1427 · verified · 1 Parent(s): 7658892

Upload semdisdiffae_p32_v2 private export
README.md ADDED
@@ -0,0 +1,124 @@
---
license: apache-2.0
tags:
- diffusion
- autoencoder
- image-reconstruction
- latent-space
- pytorch
- fcdm
library_name: fcdm_diffae
---

# data-archetype/semdisdiffae_p32_v2

**semdisdiffae_p32_v2** is a native patch-32 SemDisDiffAE diffusion autoencoder. It
keeps the same FCDM decoder family as
[SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae), with an
8-block encoder, an 8-block decoder, and a 384-channel spatial latent at
`H/32 x W/32`.

Relative to the original
[SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae), this model
is optimized for a lower-resolution latent grid and downstream latent diffusion:
patch size `32` instead of `16`, `384` latent channels instead of `128`, an
8-block encoder instead of a 4-block encoder, and DINOv3 ConvNeXt-B semantic
alignment instead of the original DINO semantic alignment setup.

See the original
[SemDisDiffAE technical report](https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md)
for the shared architecture background.

The p32 checkpoint was trained at `384` resolution rather than the original
`256`-scale recipe. With patch size `32`, this gives a `12x12` latent grid
instead of `8x8`, reducing the impact of 7x7-convolution border effects during
training.
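The grid arithmetic above is just resolution divided by patch size; a quick sketch:

```python
# Latent grid side = training resolution / patch size.
def latent_grid(resolution: int, patch_size: int) -> int:
    assert resolution % patch_size == 0, "resolution must be divisible by patch size"
    return resolution // patch_size

print(latent_grid(384, 32))  # p32 recipe -> 12 (a 12x12 grid)
print(latent_grid(256, 32))  # 256-scale recipe at patch 32 -> 8 (an 8x8 grid)
```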
## 2k PSNR Benchmark

Evaluated on `2000` images, split as `1333` Pexels images and `667` Amazon book
covers. Reconstruction uses the default 1-step VP/DDIM path in `bfloat16`.

| Model | Mean PSNR (dB) | Std (dB) | Median (dB) | P5 (dB) | P95 (dB) |
|---|---:|---:|---:|---:|---:|
| semdisdiffae_p32_v2 | `36.06` | `5.47` | `35.80` | `27.63` | `45.02` |
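The evaluation script is not part of this export; as a reference for how the dB
figures are read, a minimal PSNR sketch assuming images in `[-1, 1]` (data range
`2.0`):

```python
import numpy as np

def psnr_db(ref: np.ndarray, recon: np.ndarray, data_range: float = 2.0) -> float:
    """PSNR in dB for images in [-1, 1], i.e. a dynamic range of 2.0."""
    mse = np.mean((ref.astype(np.float64) - recon.astype(np.float64)) ** 2)
    return float(10.0 * np.log10(data_range**2 / mse))

ref = np.zeros((3, 64, 64))
recon = np.full((3, 64, 64), 0.02)  # uniform 0.02 error everywhere
print(round(psnr_db(ref, recon), 2))  # -> 40.0
```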
## Throughput

Measured on an `NVIDIA GeForce RTX 5090` in `bfloat16`, with `5` warmup batches
and `20` timed batches. Decode uses the default 1-step sampler with PDG
disabled.

| Operation | Resolution | Batch Size | Mean (ms/batch) | Images/s | Peak Allocated VRAM |
|---|---:|---:|---:|---:|---:|
| Encode | `256x256` | `128` | `12.57` | `10186.8` | `574 MiB` |
| Decode | `256x256` | `128` | `98.93` | `1293.9` | `1042 MiB` |
| Encode | `512x512` | `32` | `12.08` | `2649.9` | `579 MiB` |
| Decode | `512x512` | `32` | `100.36` | `318.8` | `1042 MiB` |
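Images/s follows from batch size and mean latency as `batch / (ms / 1000)`; a
quick check (the table's figures differ slightly, most likely because they
average per-batch rates rather than dividing by the mean latency):

```python
def images_per_second(batch_size: int, mean_ms_per_batch: float) -> float:
    # Convert ms/batch to seconds, then divide the batch size by it.
    return batch_size / (mean_ms_per_batch / 1000.0)

print(round(images_per_second(128, 12.57), 1))  # ~10183.0 (table: 10186.8)
print(round(images_per_second(128, 98.93), 1))  # ~1293.8 (table: 1293.9)
```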
## Results Viewer

The 39-image visual viewer shows originals, FLUX.2 reconstructions,
semdisdiffae_p32_v2 reconstructions, deltas, and latent PCA side by side:
[semdisdiffae_p32_v2 results](https://huggingface.co/spaces/data-archetype/semdisdiffae_p32_v2-results).
## Latent Interface

- `encode()` returns whitened latents using the model's saved running statistics.
- `decode()` expects those whitened latents and dewhitens internally.
- `whiten()` and `dewhiten()` expose the transform explicitly.
- `encode_posterior()` returns the raw exported posterior before whitening.

Weights are stored in `float32`. The recommended runtime path is `bfloat16` for
the encoder and decoder, while whitening, dewhitening, posterior moment math,
VP schedule math, and sampler state updates are kept in `float32`.
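A minimal NumPy sketch of what the whitening transform computes, assuming the
standard per-channel `(z - mean) / sqrt(var + eps)` form with the model's saved
running statistics (the actual implementation is `model.whiten()` /
`model.dewhiten()`):

```python
import numpy as np

def whiten(z, mean, var, eps=1e-4):
    # z: [B, C, h, w]; mean/var: [C] running statistics
    return (z - mean[None, :, None, None]) / np.sqrt(var[None, :, None, None] + eps)

def dewhiten(z_w, mean, var, eps=1e-4):
    # Exact inverse of whiten() with the same statistics.
    return z_w * np.sqrt(var[None, :, None, None] + eps) + mean[None, :, None, None]

rng = np.random.default_rng(0)
z = rng.normal(size=(2, 4, 3, 3))
mean, var = rng.normal(size=4), rng.uniform(0.5, 2.0, size=4)
round_trip = dewhiten(whiten(z, mean, var), mean, var)
print(np.allclose(round_trip, z))  # -> True
```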
## Usage

```python
import torch

from fcdm_diffae import FCDMDiffAE, FCDMDiffAEInferenceConfig


device = "cuda"
model = FCDMDiffAE.from_pretrained(
    "data-archetype/semdisdiffae_p32_v2",
    device=device,
    dtype=torch.bfloat16,
)

image = ...  # [B, 3, H, W] in [-1, 1], H and W divisible by 32

with torch.inference_mode():
    latents = model.encode(image.to(device=device, dtype=torch.bfloat16))
    recon = model.decode(
        latents,
        height=int(image.shape[-2]),
        width=int(image.shape[-1]),
        inference_config=FCDMDiffAEInferenceConfig(num_steps=1),
    )
```
## Details

- Architecture: patch-32 FCDM DiffAE, `156.6M` parameters, `384` latent channels.
- Encoder / decoder depth: `8` blocks each.
- Training resolution: `384` AR buckets and `384x384` square crops.
- Semantic alignment: DINOv3 ConvNeXt-B/LVD1689M, 50/50 MSE plus negative cosine.
- Posterior: diagonal Gaussian with VP log-SNR parameterization.
- Export variant: EMA weights.
- [Technical report](https://huggingface.co/data-archetype/semdisdiffae_p32_v2/blob/main/technical_report_fcdm_diffae.md)
## Citation

```bibtex
@misc{semdisdiffae_p32_v2,
  title  = {SemDisDiffAE p32 v2: a patch-32 FCDM diffusion autoencoder},
  author = {data-archetype},
  email  = {data-archetype@proton.me},
  year   = {2026},
  month  = apr,
  url    = {https://huggingface.co/data-archetype/semdisdiffae_p32_v2},
}
```
config.json ADDED
@@ -0,0 +1,19 @@
{
  "in_channels": 3,
  "patch_size": 32,
  "model_dim": 1024,
  "encoder_depth": 8,
  "decoder_depth": 8,
  "decoder_start_blocks": 2,
  "decoder_end_blocks": 2,
  "bottleneck_dim": 384,
  "mlp_ratio": 4.0,
  "depthwise_kernel_size": 7,
  "adaln_low_rank_rank": 128,
  "bottleneck_posterior_kind": "diagonal_gaussian",
  "bottleneck_norm_mode": "disabled",
  "logsnr_min": -10.0,
  "logsnr_max": 10.0,
  "pixel_noise_std": 0.558,
  "latent_running_stats_eps": 0.0001
}
fcdm_diffae/__init__.py ADDED
@@ -0,0 +1,33 @@
"""FCDMDiffAE: Standalone diffusion autoencoder with FCDM blocks.

FCDM DiffAE — a fast diffusion autoencoder with a 128-channel spatial
bottleneck and a VP-parameterized diagonal Gaussian posterior. Built on FCDM
(Fully Convolutional Diffusion Model) blocks with GRN and scale+gate AdaLN.

Usage::

    from fcdm_diffae import FCDMDiffAE, FCDMDiffAEInferenceConfig

    model = FCDMDiffAE.from_pretrained("path/to/weights", device="cuda")

    # Encode (returns posterior mode by default)
    latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]

    # Decode — PSNR-optimal (1 step, default)
    recon = model.decode(latents, height=H, width=W)

    # Decode — perceptual sharpness (10 steps + path-drop PDG)
    cfg = FCDMDiffAEInferenceConfig(num_steps=10, pdg=True, pdg_strength=2.0)
    recon = model.decode(latents, height=H, width=W, inference_config=cfg)
"""

from .config import FCDMDiffAEConfig, FCDMDiffAEInferenceConfig
from .encoder import EncoderPosterior
from .model import FCDMDiffAE

__all__ = [
    "EncoderPosterior",
    "FCDMDiffAE",
    "FCDMDiffAEConfig",
    "FCDMDiffAEInferenceConfig",
]
fcdm_diffae/adaln.py ADDED
@@ -0,0 +1,50 @@
"""Scale+Gate AdaLN (2-way) for FCDM decoder blocks."""

from __future__ import annotations

from torch import Tensor, nn


class AdaLNScaleGateZeroProjector(nn.Module):
    """Packed 2-way AdaLN projection (SiLU -> Linear), zero-initialized.

    Outputs [B, 2*d_model] packed as (scale, gate).
    """

    def __init__(self, d_model: int, d_cond: int) -> None:
        super().__init__()
        self.d_model: int = int(d_model)
        self.d_cond: int = int(d_cond)
        self.act: nn.SiLU = nn.SiLU()
        self.proj: nn.Linear = nn.Linear(self.d_cond, 2 * self.d_model)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward_activated(self, act_cond: Tensor) -> Tensor:
        """Return packed modulation for a pre-activated conditioning vector."""
        return self.proj(act_cond)

    def forward(self, cond: Tensor) -> Tensor:
        """Return packed modulation [B, 2*d_model]."""
        return self.forward_activated(self.act(cond))


class AdaLNScaleGateZeroLowRankDelta(nn.Module):
    """Low-rank delta for 2-way AdaLN: down(d_cond -> rank) -> up(rank -> 2*d_model).

    Zero-initialized up projection preserves zero-output semantics at init.
    """

    def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
        super().__init__()
        self.d_model: int = int(d_model)
        self.d_cond: int = int(d_cond)
        self.rank: int = int(rank)
        self.down: nn.Linear = nn.Linear(self.d_cond, self.rank, bias=False)
        self.up: nn.Linear = nn.Linear(self.rank, 2 * self.d_model, bias=False)
        nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, act_cond: Tensor) -> Tensor:
        """Return packed delta modulation [B, 2*d_model]."""
        return self.up(self.down(act_cond))
fcdm_diffae/config.py ADDED
@@ -0,0 +1,74 @@
"""Frozen model architecture and user-tunable inference configuration."""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass(frozen=True)
class FCDMDiffAEConfig:
    """Frozen model architecture config stored alongside exported weights."""

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    encoder_depth: int = 8
    decoder_depth: int = 8
    decoder_start_blocks: int = 2
    decoder_end_blocks: int = 2
    bottleneck_dim: int = 128
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128
    # Encoder posterior kind: "diagonal_gaussian" or "deterministic"
    bottleneck_posterior_kind: str = "diagonal_gaussian"
    # Post-bottleneck normalization: "channel_wise" or "disabled"
    bottleneck_norm_mode: str = "disabled"
    # VP diffusion schedule endpoints
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    # Pixel-space noise std for VP diffusion initialization
    pixel_noise_std: float = 0.558
    # Per-channel latent whitening epsilon used with running stats
    latent_running_stats_eps: float = 1e-4

    @property
    def latent_channels(self) -> int:
        """Channel width of the exported latent space."""
        return self.bottleneck_dim

    @property
    def effective_patch_size(self) -> int:
        """Effective spatial stride from image to latent grid."""
        return self.patch_size

    def save(self, path: str | Path) -> None:
        """Save config as JSON."""
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(asdict(self), indent=2) + "\n")

    @classmethod
    def load(cls, path: str | Path) -> FCDMDiffAEConfig:
        """Load config from JSON."""
        data = json.loads(Path(path).read_text())
        return cls(**data)


@dataclass
class FCDMDiffAEInferenceConfig:
    """User-tunable inference parameters with sensible defaults.

    PDG (Path-Drop Guidance) sharpens reconstructions by degrading conditioning
    in one pass and amplifying the difference. When enabled, uses 2 NFE per step.
    Recommended: ``pdg=True, pdg_strength=2.0, num_steps=10``.
    """

    num_steps: int = 1  # number of denoising steps (NFE)
    sampler: str = "ddim"  # "ddim" or "dpmpp_2m"
    schedule: str = "linear"  # "linear" or "cosine"
    pdg: bool = False  # enable PDG for perceptual sharpening
    pdg_strength: float = 2.0  # CFG-like strength when pdg=True
    seed: int | None = None
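The `save()`/`load()` pair above is a plain dataclass-to-JSON round trip; a
self-contained sketch of the same pattern, using a hypothetical two-field
stand-in for `FCDMDiffAEConfig`:

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path

@dataclass(frozen=True)
class TinyConfig:  # hypothetical stand-in for FCDMDiffAEConfig
    patch_size: int = 16
    bottleneck_dim: int = 128

with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "config.json"
    # save(): dataclass -> dict -> JSON file
    path.write_text(json.dumps(asdict(TinyConfig(patch_size=32)), indent=2) + "\n")
    # load(): JSON file -> dict -> dataclass (unknown keys would raise TypeError)
    loaded = TinyConfig(**json.loads(path.read_text()))

print(loaded.patch_size, loaded.bottleneck_dim)  # -> 32 128
```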
fcdm_diffae/decoder.py ADDED
@@ -0,0 +1,168 @@
"""FCDM DiffAE decoder: skip-concat topology with FCDM blocks and path-drop PDG.

No outer RMSNorms (use_other_outer_rms_norms=False during training):
norm_in, latent_norm, and norm_out are all absent.
"""

from __future__ import annotations

import torch
from torch import Tensor, nn

from .adaln import AdaLNScaleGateZeroLowRankDelta, AdaLNScaleGateZeroProjector
from .fcdm_block import FCDMBlock
from .straight_through_encoder import Patchify
from .time_embed import SinusoidalTimeEmbeddingMLP


class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture (skip-concat, 2+4+2 default):
        Patchify x_t -> Fuse with upsampled z
        -> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
        -> Conv1x1 -> PixelShuffle

    Path-Drop Guidance (PDG) at inference:
    - Replace middle block output with ``path_drop_mask_feature`` to create
      an unconditional prediction, then extrapolate.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        start_block_count: int,
        end_block_count: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
    ) -> None:
        super().__init__()
        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)

        # Input processing (no norm_in)
        self.patchify = Patchify(in_channels, patch_size, model_dim)

        # Latent conditioning path (no latent_norm)
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # 2-way AdaLN: shared base projector + per-block low-rank deltas
        self.adaln_base = AdaLNScaleGateZeroProjector(
            d_model=model_dim, d_cond=model_dim
        )
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNScaleGateZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Block layout: start + middle + end
        middle_count = depth - start_block_count - end_block_count
        self._middle_start_idx = start_block_count
        self._end_start_idx = start_block_count + middle_count

        def _make_blocks(count: int) -> nn.ModuleList:
            return nn.ModuleList(
                [
                    FCDMBlock(
                        model_dim,
                        mlp_ratio,
                        depthwise_kernel_size=depthwise_kernel_size,
                        use_external_adaln=True,
                    )
                    for _ in range(count)
                ]
            )

        self.start_blocks = _make_blocks(start_block_count)
        self.middle_blocks = _make_blocks(middle_count)
        self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
        self.end_blocks = _make_blocks(end_block_count)

        # Learned mask feature for path-drop PDG
        self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output head (no norm_out)
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta."""
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)
        delta_m = self.adaln_deltas[layer_idx](act)
        return base_m + delta_m

    def _run_blocks(
        self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
    ) -> Tensor:
        """Run a group of decoder blocks with per-block AdaLN modulation."""
        for local_idx, block in enumerate(blocks):
            adaln_m = self._adaln_m_for_layer(cond, layer_idx=start_index + local_idx)
            x = block(x, adaln_m=adaln_m)
        return x

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1].
            latents: Encoder latents [B, bottleneck_dim, h, w].
            drop_middle_blocks: Replace middle block output with mask feature (PDG).

        Returns:
            x0 prediction [B, C, H, W].
        """
        x_feat = self.patchify(x_t)
        z_up = self.latent_up(latents)

        fused = torch.cat([x_feat, z_up], dim=1)
        fused = self.fuse_in(fused)

        cond = self.time_embed(t.to(torch.float32).to(device=x_t.device))

        start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)

        if drop_middle_blocks:
            middle_out = self.path_drop_mask_feature.to(
                device=x_t.device, dtype=x_t.dtype
            ).expand_as(start_out)
        else:
            middle_out = self._run_blocks(
                self.middle_blocks,
                start_out,
                cond,
                start_index=self._middle_start_idx,
            )

        skip_fused = torch.cat([start_out, middle_out], dim=1)
        skip_fused = self.fuse_skip(skip_fused)

        end_out = self._run_blocks(
            self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
        )

        patches = self.out_proj(end_out)
        return self.unpatchify(patches)
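The PDG extrapolation step itself lives in `fcdm_diffae/samplers.py`, which is
not part of this excerpt. A hypothetical sketch of the usual CFG-style form the
docstring describes (`uncond + strength * (cond - uncond)`), not necessarily
the exact formula used here:

```python
import numpy as np

# Hypothetical CFG-style extrapolation: amplify the conditional prediction
# away from the degraded (path-dropped) prediction by `strength`.
def pdg_extrapolate(cond_pred, uncond_pred, strength=2.0):
    return uncond_pred + strength * (cond_pred - uncond_pred)

cond = np.array([1.0, 2.0])
uncond = np.array([0.5, 1.0])
print(pdg_extrapolate(cond, uncond).tolist())  # strength 2.0 -> [1.5, 3.0]
print(pdg_extrapolate(cond, uncond, 1.0).tolist())  # strength 1.0 -> plain cond
```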
fcdm_diffae/encoder.py ADDED
@@ -0,0 +1,137 @@
"""FCDM DiffAE encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.

No input RMSNorm (use_other_outer_rms_norms=False during training).
Post-bottleneck RMSNorm (affine=False) on the mean branch.
Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .fcdm_block import FCDMBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify


@dataclass(frozen=True)
class EncoderPosterior:
    """VP-parameterized diagonal Gaussian posterior.

    mean: Clean signal branch mu [B, bottleneck_dim, h, w]
    logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w]
    """

    mean: Tensor
    logsnr: Tensor

    @property
    def alpha(self) -> Tensor:
        """VP signal coefficient computed stably in float32."""
        logsnr_fp32 = self.logsnr.to(torch.float32)
        return torch.exp(0.5 * F.logsigmoid(logsnr_fp32))

    @property
    def sigma(self) -> Tensor:
        """VP noise coefficient computed stably in float32."""
        logsnr_fp32 = self.logsnr.to(torch.float32)
        return torch.exp(0.5 * F.logsigmoid(-logsnr_fp32))

    def mode(self) -> Tensor:
        """Posterior mode in token space: alpha * mean, computed in float32."""
        return (self.alpha * self.mean.to(torch.float32)).to(dtype=self.mean.dtype)

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
        """Sample from posterior: alpha * mean + sigma * eps, computed in float32."""
        mean_fp32 = self.mean.to(torch.float32)
        eps = torch.randn(
            mean_fp32.shape,
            device=mean_fp32.device,
            dtype=torch.float32,
            generator=generator,
        )
        return (self.alpha * mean_fp32 + self.sigma * eps).to(dtype=self.mean.dtype)


class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With diagonal_gaussian posterior, the to_bottleneck projection outputs
    2 * bottleneck_dim channels, split into mean and logsnr. The default
    encode() returns the posterior mode: alpha * RMSNorm(mean).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        bottleneck_posterior_kind: str = "diagonal_gaussian",
        bottleneck_norm_mode: str = "disabled",
    ) -> None:
        super().__init__()
        self.bottleneck_dim = int(bottleneck_dim)
        self.bottleneck_posterior_kind = bottleneck_posterior_kind
        self.bottleneck_norm_mode = bottleneck_norm_mode
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                FCDMBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        out_dim = (
            2 * bottleneck_dim
            if bottleneck_posterior_kind == "diagonal_gaussian"
            else bottleneck_dim
        )
        self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
        if bottleneck_norm_mode == "channel_wise":
            self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
        else:
            self.norm_out = nn.Identity()

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Only valid when bottleneck_posterior_kind == "diagonal_gaussian".
        """
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        projection = self.to_bottleneck(z)
        mean, logsnr = projection.chunk(2, dim=1)
        mean = self.norm_out(mean)
        return EncoderPosterior(mean=mean, logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic latents otherwise.
        """
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        projection = self.to_bottleneck(z)
        if self.bottleneck_posterior_kind == "diagonal_gaussian":
            mean, logsnr = projection.chunk(2, dim=1)
            mean = self.norm_out(mean)
            logsnr_fp32 = logsnr.to(torch.float32)
            alpha = torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
            return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
        z = self.norm_out(projection)
        return z
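The `alpha`/`sigma` pair above satisfies the VP identity `alpha^2 + sigma^2 = 1`,
since `exp(logsigmoid(x)) = sigmoid(x)` and `sigmoid(x) + sigmoid(-x) = 1`; a
quick stdlib check:

```python
import math

def alpha(logsnr: float) -> float:
    # exp(0.5 * logsigmoid(logsnr)) == sqrt(sigmoid(logsnr))
    return math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))

def sigma(logsnr: float) -> float:
    # exp(0.5 * logsigmoid(-logsnr)) == sqrt(sigmoid(-logsnr))
    return math.sqrt(1.0 / (1.0 + math.exp(logsnr)))

for snr in (-10.0, 0.0, 3.5, 10.0):
    assert abs(alpha(snr) ** 2 + sigma(snr) ** 2 - 1.0) < 1e-12
print("alpha^2 + sigma^2 == 1 holds")
```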
fcdm_diffae/fcdm_block.py ADDED
@@ -0,0 +1,103 @@
"""FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .norms import ChannelWiseRMSNorm


class GRN(nn.Module):
    """Global Response Normalization for NCHW tensors."""

    def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps: float = float(eps)
        c = int(channels)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        g = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
        g_fp32 = g.to(dtype=torch.float32)
        n = (g_fp32 / (g_fp32.mean(dim=1, keepdim=True) + self.eps)).to(dtype=x.dtype)
        gamma = self.gamma.to(device=x.device, dtype=x.dtype)
        beta = self.beta.to(device=x.device, dtype=x.dtype)
        return gamma * (x * n) + beta + x


class FCDMBlock(nn.Module):
    """ConvNeXt-style block with scale+gate AdaLN and GRN.

    Two modes:
    - Unconditioned (encoder): uses learned layer-scale for near-identity init.
    - External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate).
      The gate is applied raw (no tanh).
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
        layer_scale_init: float = 1e-3,
    ) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self.mlp_ratio: float = float(mlp_ratio)

        self.dwconv = nn.Conv2d(
            channels,
            channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,
            stride=1,
            groups=channels,
            bias=True,
        )
        self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
        hidden = max(int(float(channels) * float(mlp_ratio)), 1)
        self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
        self.grn = GRN(hidden, eps=1e-6)
        self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)

        if not use_external_adaln:
            self.layer_scale = nn.Parameter(
                torch.full((channels,), float(layer_scale_init))
            )
        else:
            self.register_parameter("layer_scale", None)

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        b, c, _, _ = x.shape

        if adaln_m is not None:
            m = adaln_m.to(device=x.device, dtype=x.dtype)
            scale, gate = m.chunk(2, dim=-1)
        else:
            scale = gate = None

        h = self.dwconv(x)
        h = self.norm(h)

        if scale is not None:
            h = h * (1.0 + scale.view(b, c, 1, 1))

        h = self.pwconv1(h)
        h = F.gelu(h)
        h = self.grn(h)
        h = self.pwconv2(h)

        if gate is not None:
            gate_view = gate.view(b, c, 1, 1)
        else:
            gate_view = self.layer_scale.view(1, c, 1, 1).to(  # type: ignore[union-attr]
                device=h.device, dtype=h.dtype
            )

        return x + gate_view * h
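With `gamma` and `beta` zero-initialized as above, GRN reduces to the identity at
init, so blocks start near-identity; a NumPy sketch of the same computation:

```python
import numpy as np

def grn(x, gamma, beta, eps=1e-6):
    # Global response norm over spatial dims for NCHW, mirroring GRN.forward.
    g = np.sqrt((x**2).sum(axis=(2, 3), keepdims=True))  # per-channel L2 norm
    n = g / (g.mean(axis=1, keepdims=True) + eps)        # divisive normalization
    return gamma * (x * n) + beta + x

x = np.random.default_rng(0).normal(size=(2, 8, 4, 4))
zero = np.zeros((1, 8, 1, 1))
print(np.allclose(grn(x, zero, zero), x))  # zero-init gamma/beta -> identity: True
```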
fcdm_diffae/model.py ADDED
@@ -0,0 +1,420 @@
"""FCDMDiffAE: standalone HuggingFace-compatible diffusion autoencoder."""

from __future__ import annotations

from collections.abc import Callable
from pathlib import Path

import torch
from torch import Tensor, nn

from .config import FCDMDiffAEConfig, FCDMDiffAEInferenceConfig
from .decoder import Decoder
from .encoder import Encoder, EncoderPosterior
from .samplers import run_ddim, run_dpmpp_2m
from .vp_diffusion import get_schedule, make_initial_state, sample_noise


def _resolve_model_dir(
    path_or_repo_id: str | Path,
    *,
    revision: str | None,
    cache_dir: str | Path | None,
) -> Path:
    """Resolve a local path or HuggingFace Hub repo ID to a local directory."""
    local = Path(path_or_repo_id)
    if local.is_dir():
        return local
    repo_id = str(path_or_repo_id)
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError(
            f"'{repo_id}' is not an existing local directory. "
            "To download from HuggingFace Hub, install huggingface_hub: "
            "pip install huggingface_hub"
        )
    cache_dir_str = str(cache_dir) if cache_dir is not None else None
    local_dir = snapshot_download(
        repo_id,
        revision=revision,
        cache_dir=cache_dir_str,
    )
    return Path(local_dir)


class FCDMDiffAE(nn.Module):
    """Standalone FCDM DiffAE model for HuggingFace distribution.

    A diffusion autoencoder built on FCDM (Fully Convolutional Diffusion Model)
    blocks. Encodes images to compact 128-channel spatial latents via a
    VP-parameterized diagonal Gaussian posterior, and decodes them back via
    iterative VP diffusion with a skip-concat decoder.

    Usage::

        model = FCDMDiffAE.from_pretrained("path/to/weights")
        model = model.to("cuda", dtype=torch.bfloat16)

        # Encode (returns posterior mode by default)
        latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]

        # Decode (1 step by default — PSNR-optimal)
        recon = model.decode(latents, height=H, width=W)

        # Reconstruct (encode + 1-step decode)
        recon = model.reconstruct(images)
    """

    def __init__(self, config: FCDMDiffAEConfig) -> None:
        super().__init__()
        self.config = config

        # Latent running stats for whitening/dewhitening (at exported latent channels)
        self.register_buffer(
            "latent_norm_running_mean",
            torch.zeros((config.latent_channels,), dtype=torch.float32),
        )
        self.register_buffer(
            "latent_norm_running_var",
            torch.ones((config.latent_channels,), dtype=torch.float32),
        )

        self.encoder = Encoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.encoder_depth,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
            bottleneck_posterior_kind=config.bottleneck_posterior_kind,
            bottleneck_norm_mode=config.bottleneck_norm_mode,
        )

        self.decoder = Decoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.decoder_depth,
            start_block_count=config.decoder_start_blocks,
            end_block_count=config.decoder_end_blocks,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
            adaln_low_rank_rank=config.adaln_low_rank_rank,
        )

    def _restore_float32_norm_buffers(self) -> None:
        """Keep latent whitening statistics in float32 after dtype/device moves."""
        self.latent_norm_running_mean = self.latent_norm_running_mean.to(
            dtype=torch.float32
        )
        self.latent_norm_running_var = self.latent_norm_running_var.to(
            dtype=torch.float32
        )

    def _apply(
        self,
        fn: Callable[[Tensor], Tensor],
        recurse: bool = True,
    ) -> FCDMDiffAE:
        """Apply module tensor transforms while preserving float32 latent stats."""
        applied = super()._apply(fn, recurse=recurse)
        if not isinstance(applied, FCDMDiffAE):
            raise RuntimeError(
                f"Expected FCDMDiffAE after nn.Module._apply(), got {type(applied).__name__}"
            )
        applied._restore_float32_norm_buffers()
        return applied

    def to(self, *args: object, **kwargs: object) -> FCDMDiffAE:
        """Move the module while preserving latent whitening stats in float32."""
        moved = super().to(*args, **kwargs)
        if not isinstance(moved, FCDMDiffAE):
            raise RuntimeError(
                f"Expected FCDMDiffAE after nn.Module.to(), got {type(moved).__name__}"
            )
        moved._restore_float32_norm_buffers()
        return moved

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device = "cpu",
        revision: str | None = None,
        cache_dir: str | Path | None = None,
    ) -> FCDMDiffAE:
        """Load a pretrained model from a local directory or HuggingFace Hub.

        The directory (or repo) should contain:
        - config.json: Model architecture config.
        - model.safetensors (preferred) or model.pt: Model weights.

        Args:
            path_or_repo_id: Local directory path or HuggingFace Hub repo ID.
            dtype: Load weights in this dtype (float32 or bfloat16).
            device: Target device.
            revision: Git revision for Hub downloads.
            cache_dir: Where to cache Hub downloads.

        Returns:
            Loaded model in eval mode.
        """
        model_dir = _resolve_model_dir(
            path_or_repo_id, revision=revision, cache_dir=cache_dir
        )
        config = FCDMDiffAEConfig.load(model_dir / "config.json")
        model = cls(config)

        safetensors_path = model_dir / "model.safetensors"
        pt_path = model_dir / "model.pt"

        if safetensors_path.exists():
            try:
                from safetensors.torch import load_file

                state_dict = load_file(str(safetensors_path), device=str(device))
            except ImportError:
                raise ImportError(
                    "safetensors package required to load .safetensors files. "
                    "Install with: pip install safetensors"
                )
        elif pt_path.exists():
            state_dict = torch.load(
                str(pt_path), map_location=device, weights_only=True
            )
+ )
193
+ else:
194
+ raise FileNotFoundError(
195
+ f"No model weights found in {model_dir}. "
196
+ "Expected model.safetensors or model.pt."
197
+ )
198
+
199
+ model.load_state_dict(state_dict)
200
+ model = model.to(dtype=dtype, device=torch.device(device))
201
+ model.eval()
202
+ return model
203
+
204
+ def _latent_norm_stats(self) -> tuple[Tensor, Tensor]:
205
+ """Return (mean, std) tensors for latent whitening, shaped [1,C,1,1]."""
206
+ mean = self.latent_norm_running_mean.view(1, -1, 1, 1)
207
+ var = self.latent_norm_running_var.view(1, -1, 1, 1)
208
+ std = torch.sqrt(
209
+ var.to(torch.float32) + float(self.config.latent_running_stats_eps)
210
+ )
211
+ return mean.to(torch.float32), std
212
+
213
+ def whiten(self, latents: Tensor) -> Tensor:
214
+ """Whiten encoder latents using per-channel running stats.
215
+
216
+ Use this before passing latents to a downstream latent-space
217
+ diffusion model. The whitened latents have approximately zero mean
218
+ and unit variance per channel.
219
+
220
+ Args:
221
+ latents: [B, bottleneck_dim, h, w] raw encoder output.
222
+
223
+ Returns:
224
+ Whitened latents [B, bottleneck_dim, h, w] in float32.
225
+ """
226
+ z = latents.to(torch.float32)
227
+ mean, std = self._latent_norm_stats()
228
+ return (z - mean.to(device=z.device)) / std.to(device=z.device)
229
+
230
+ def dewhiten(self, latents: Tensor) -> Tensor:
231
+ """Undo whitening to recover raw encoder latent scale.
232
+
233
+ Use this before passing whitened latents back to ``decode()``.
234
+
235
+ Args:
236
+ latents: [B, bottleneck_dim, h, w] whitened latents.
237
+
238
+ Returns:
239
+ Dewhitened latents [B, bottleneck_dim, h, w] in float32.
240
+ """
241
+ z = latents.to(torch.float32)
242
+ mean, std = self._latent_norm_stats()
243
+ return z * std.to(device=z.device) + mean.to(device=z.device)
244
+
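The `whiten`/`dewhiten` pair is a fixed per-channel affine map, so dewhitening is an exact inverse of whitening. A minimal standalone sketch of the same broadcasting (arbitrary illustrative stats, not the shipped running buffers):

```python
import torch

def whiten(z: torch.Tensor, mean: torch.Tensor, var: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Broadcast per-channel [C] stats over [B, C, h, w], as _latent_norm_stats does.
    m = mean.view(1, -1, 1, 1)
    s = torch.sqrt(var.view(1, -1, 1, 1) + eps)
    return (z - m) / s

def dewhiten(z: torch.Tensor, mean: torch.Tensor, var: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    m = mean.view(1, -1, 1, 1)
    s = torch.sqrt(var.view(1, -1, 1, 1) + eps)
    return z * s + m

torch.manual_seed(0)
mean = torch.tensor([0.5, -1.0, 2.0])  # hypothetical running means
var = torch.tensor([4.0, 0.25, 1.0])   # hypothetical running variances
z = torch.randn(2, 3, 12, 12)          # a 12x12 grid, as for 384px inputs at patch 32
roundtrip = dewhiten(whiten(z, mean, var), mean, var)
assert torch.allclose(roundtrip, z, atol=1e-5)
```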
    def encode(self, images: Tensor) -> Tensor:
        """Encode images to whitened latents (posterior mode).

        Returns latents whitened using per-channel running stats, ready for
        use by downstream latent-space diffusion models.

        Args:
            images: [B, 3, H, W] in [-1, 1], H and W divisible by
                effective_patch_size.

        Returns:
            Whitened latents [B, latent_channels, H/effective_patch, W/effective_patch].
        """
        eff_patch = self.config.effective_patch_size
        h, w = int(images.shape[2]), int(images.shape[3])
        if h % eff_patch != 0 or w % eff_patch != 0:
            raise ValueError(
                f"Image height={h} and width={w} must be divisible by "
                f"effective_patch_size={eff_patch}"
            )
        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32
        z = self.encoder(images.to(dtype=model_dtype))
        return self.whiten(z).to(dtype=model_dtype)

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Args:
            images: [B, 3, H, W] in [-1, 1], H and W divisible by
                effective_patch_size.

        Returns:
            EncoderPosterior with mean and logsnr tensors in the exported
            latent space.
        """
        eff_patch = self.config.effective_patch_size
        h, w = int(images.shape[2]), int(images.shape[3])
        if h % eff_patch != 0 or w % eff_patch != 0:
            raise ValueError(
                f"Image height={h} and width={w} must be divisible by "
                f"effective_patch_size={eff_patch}"
            )
        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32
        posterior = self.encoder.encode_posterior(images.to(dtype=model_dtype))
        return posterior

    @torch.no_grad()
    def decode(
        self,
        latents: Tensor,
        height: int,
        width: int,
        *,
        inference_config: FCDMDiffAEInferenceConfig | None = None,
    ) -> Tensor:
        """Decode whitened latents to images via VP diffusion.

        Latents are dewhitened before being passed to the decoder.

        Args:
            latents: [B, latent_channels, h, w] whitened encoder latents.
            height: Output image height (divisible by effective_patch_size).
            width: Output image width (divisible by effective_patch_size).
            inference_config: Optional inference parameters.

        Returns:
            Reconstructed images [B, 3, H, W] in float32.
        """
        cfg = inference_config or FCDMDiffAEInferenceConfig()
        config = self.config
        batch = int(latents.shape[0])
        device = latents.device

        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32

        eff_patch = config.effective_patch_size
        if height % eff_patch != 0 or width % eff_patch != 0:
            raise ValueError(
                f"height={height} and width={width} must be divisible by "
                f"effective_patch_size={eff_patch}"
            )

        latents = self.dewhiten(latents)
        latents = latents.to(dtype=model_dtype)

        shape = (batch, config.in_channels, height, width)
        noise = sample_noise(
            shape,
            noise_std=config.pixel_noise_std,
            seed=cfg.seed,
            device=torch.device("cpu"),
            dtype=torch.float32,
        )

        schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)
        initial_state = make_initial_state(
            noise=noise.to(device=device),
            t_start=schedule[0:1],
            logsnr_min=config.logsnr_min,
            logsnr_max=config.logsnr_max,
        )

        device_type = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            latents_in = latents.to(device=device)

            def _forward_fn(
                x_t: Tensor,
                t: Tensor,
                latents: Tensor,
                *,
                drop_middle_blocks: bool = False,
                mask_latent_tokens: bool = False,
            ) -> Tensor:
                return self.decoder(
                    x_t.to(dtype=model_dtype),
                    t,
                    latents.to(dtype=model_dtype),
                    drop_middle_blocks=drop_middle_blocks,
                )

            pdg_mode = "path_drop" if cfg.pdg else "disabled"

            if cfg.sampler == "ddim":
                sampler_fn = run_ddim
            elif cfg.sampler == "dpmpp_2m":
                sampler_fn = run_dpmpp_2m
            else:
                raise ValueError(
                    f"Unsupported sampler: {cfg.sampler!r}. Use 'ddim' or 'dpmpp_2m'."
                )

            result = sampler_fn(
                forward_fn=_forward_fn,
                initial_state=initial_state,
                schedule=schedule,
                latents=latents_in,
                logsnr_min=config.logsnr_min,
                logsnr_max=config.logsnr_max,
                pdg_mode=pdg_mode,
                pdg_strength=cfg.pdg_strength,
                device=device,
            )

        return result

    @torch.no_grad()
    def reconstruct(
        self,
        images: Tensor,
        *,
        inference_config: FCDMDiffAEInferenceConfig | None = None,
    ) -> Tensor:
        """Encode then decode. Convenience wrapper.

        Args:
            images: [B, 3, H, W] in [-1, 1].
            inference_config: Optional inference parameters.

        Returns:
            Reconstructed images [B, 3, H, W] in float32.
        """
        latents = self.encode(images)
        _, _, h, w = images.shape
        return self.decode(
            latents, height=h, width=w, inference_config=inference_config
        )
fcdm_diffae/norms.py ADDED
@@ -0,0 +1,39 @@
"""Channel-wise RMSNorm for NCHW tensors."""

from __future__ import annotations

import torch
from torch import Tensor, nn


class ChannelWiseRMSNorm(nn.Module):
    """Channel-wise RMSNorm with float32 reduction for numerical stability.

    Normalizes across channels per spatial position. Supports optional
    per-channel affine weight and bias.
    """

    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self._eps: float = float(eps)
        if affine:
            self.weight = nn.Parameter(torch.ones(self.channels))
            self.bias = nn.Parameter(torch.zeros(self.channels))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: Tensor) -> Tensor:
        if x.dim() < 2:
            return x
        # Float32 accumulation for stability
        ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
        inv_rms = torch.rsqrt(ms + self._eps)
        y = x * inv_rms.to(dtype=x.dtype)
        # Define the broadcast shape once so both affine branches can use it.
        shape = (1, -1) + (1,) * (x.dim() - 2)
        if self.weight is not None:
            y = y * self.weight.view(shape).to(dtype=x.dtype)
        if self.bias is not None:
            y = y + self.bias.view(shape).to(dtype=x.dtype)
        return y
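The core identity of this normalization can be checked standalone: with the affine parameters at their initial ones/zeros, every spatial position ends up with (approximately) unit RMS across channels. A minimal sketch of the same math, outside the module:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 8, 4, 4)

# RMS over the channel dim per spatial position, float32 accumulation.
ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
y = x * torch.rsqrt(ms + 1e-6)

# Each (b, h, w) position now has roughly unit RMS across its 8 channels.
rms = torch.sqrt(torch.mean(torch.square(y), dim=1))
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
```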
fcdm_diffae/samplers.py ADDED
@@ -0,0 +1,255 @@
"""DDIM and DPM++2M samplers for VP diffusion with path-drop PDG support."""

from __future__ import annotations

from typing import Protocol

import torch
from torch import Tensor

from .vp_diffusion import (
    alpha_sigma_from_logsnr,
    broadcast_time_like,
    shifted_cosine_interpolated_logsnr_from_t,
)


class DecoderForwardFn(Protocol):
    """Callable that predicts x0 from (x_t, t, latents) with path-drop PDG flag."""

    def __call__(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
        mask_latent_tokens: bool = False,
    ) -> Tensor: ...


def _reconstruct_eps_from_x0(
    *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
) -> Tensor:
    """Reconstruct eps_hat from (x_t, x0_hat) under VP parameterization.

    eps_hat = (x_t - alpha * x0_hat) / sigma. All float32.
    """
    alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
    sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
    x_t_f32 = x_t.to(torch.float32)
    x0_f32 = x0_hat.to(torch.float32)
    return (x_t_f32 - alpha_view * x0_f32) / sigma_view


def _ddim_step(
    *,
    x0_hat: Tensor,
    eps_hat: Tensor,
    alpha_next: Tensor,
    sigma_next: Tensor,
    ref: Tensor,
) -> Tensor:
    """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
    a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
    s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
    return a * x0_hat + s * eps_hat

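These two helpers invert and re-apply the VP forward map `x_t = alpha * x0 + sigma * eps`. A numeric sketch with a hand-picked pair (0.8, 0.6 satisfies alpha^2 + sigma^2 = 1): with a perfect x0 prediction, eps is recovered exactly, and a DDIM step back to the same noise level reproduces x_t:

```python
import torch

torch.manual_seed(0)
x0 = torch.randn(4, 3, 8, 8)
eps = torch.randn(4, 3, 8, 8)
alpha, sigma = 0.8, 0.6  # VP: alpha**2 + sigma**2 == 1
x_t = alpha * x0 + sigma * eps

# Invert the forward map given a (here: perfect) x0 prediction.
eps_hat = (x_t - alpha * x0) / sigma
assert torch.allclose(eps_hat, eps, atol=1e-5)

# A DDIM step to the same (alpha, sigma) level reproduces x_t exactly.
x_same = alpha * x0 + sigma * eps_hat
assert torch.allclose(x_same, x_t, atol=1e-5)
```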
def _predict_with_pdg(
    forward_fn: DecoderForwardFn,
    state: Tensor,
    t_vec: Tensor,
    latents: Tensor,
    *,
    pdg_mode: str,
    pdg_strength: float,
) -> Tensor:
    """Run decoder forward with optional PDG guidance.

    Args:
        forward_fn: Decoder forward function.
        state: Current noised state [B, C, H, W].
        t_vec: Timestep vector [B].
        latents: Encoder latents.
        pdg_mode: "disabled" or "path_drop".
        pdg_strength: CFG-like strength for PDG.

    Returns:
        x0_hat prediction in float32.
    """
    if pdg_mode == "path_drop":
        x0_uncond = forward_fn(state, t_vec, latents, drop_middle_blocks=True).to(
            torch.float32
        )
        x0_cond = forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
            torch.float32
        )
        return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
    else:
        return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
            torch.float32
        )

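The path-drop combination is the usual CFG-style extrapolation from the path-dropped prediction toward the full one, so strength `1.0` reduces to the plain conditional output. A scalar sketch (floats standing in for the x0 tensors):

```python
x0_uncond, x0_cond = 0.2, 1.0  # illustrative path-dropped / full predictions

def guided(strength: float) -> float:
    # Same combination as _predict_with_pdg's "path_drop" branch.
    return x0_uncond + strength * (x0_cond - x0_uncond)

assert abs(guided(1.0) - x0_cond) < 1e-12    # strength 1.0: plain conditional
assert abs(guided(0.0) - x0_uncond) < 1e-12  # strength 0.0: path-dropped only
assert abs(guided(1.5) - 1.4) < 1e-12        # >1 extrapolates past the conditional
```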
def run_ddim(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run DDIM sampling loop with path-drop PDG support.

    Args:
        forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
        initial_state: Starting noised state [B, C, H, W] in float32.
        schedule: Descending t-schedule [num_steps + 1] in [0, 1].
        latents: Encoder latents [B, bottleneck_dim, h, w].
        logsnr_min, logsnr_max: VP schedule endpoints.
        log_change_high, log_change_low: Shifted-cosine schedule parameters.
        pdg_mode: "disabled" or "path_drop".
        pdg_strength: CFG-like strength for PDG.
        device: Target device.

    Returns:
        Denoised samples [B, C, H, W] in float32.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma for all schedule points
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        a_t = alpha_sched[i].expand(batch_size)
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction with optional PDG
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            state,
            t_vec,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        eps_hat = _reconstruct_eps_from_x0(
            x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t
        )
        state = _ddim_step(
            x0_hat=x0_hat,
            eps_hat=eps_hat,
            alpha_next=a_next,
            sigma_next=s_next,
            ref=state,
        )

    return state


def run_dpmpp_2m(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run DPM++2M sampling loop with path-drop PDG support.

    Multi-step solver using exponential integrator formulation in half-lambda space.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma, half-lambda for all schedule points
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
    half_lambda = 0.5 * lmb.to(torch.float32)

    x0_prev: Tensor | None = None

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction with optional PDG
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            state,
            t_vec,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        lam_t = half_lambda[i].expand(batch_size)
        lam_next = half_lambda[i + 1].expand(batch_size)
        h = (lam_next - lam_t).to(torch.float32)
        phi_1 = torch.expm1(-h)

        sigma_ratio = (s_next / s_t).to(torch.float32)

        if i == 0 or x0_prev is None:
            # First-order step
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - broadcast_time_like(a_next, state).to(torch.float32)
                * broadcast_time_like(phi_1, state).to(torch.float32)
                * x0_hat
            )
        else:
            # Second-order step
            lam_prev = half_lambda[i - 1].expand(batch_size)
            h_0 = (lam_t - lam_prev).to(torch.float32)
            r0 = h_0 / h
            d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
            common = broadcast_time_like(a_next, state).to(
                torch.float32
            ) * broadcast_time_like(phi_1, state).to(torch.float32)
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - common * x0_hat
                - 0.5 * common * d1_0
            )

        x0_prev = x0_hat

    return state
fcdm_diffae/straight_through_encoder.py ADDED
@@ -0,0 +1,27 @@
"""PixelUnshuffle-based patchifier (no residual conv path)."""

from __future__ import annotations

from torch import Tensor, nn


class Patchify(nn.Module):
    """PixelUnshuffle(patch) -> Conv2d 1x1 projection.

    Converts [B, C, H, W] images into [B, out_channels, H/patch, W/patch] features.
    """

    def __init__(self, in_channels: int, patch: int, out_channels: int) -> None:
        super().__init__()
        self.patch = int(patch)
        self.unshuffle = nn.PixelUnshuffle(self.patch)
        in_after = in_channels * (self.patch * self.patch)
        self.proj = nn.Conv2d(in_after, out_channels, kernel_size=1, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        if x.shape[2] % self.patch != 0 or x.shape[3] % self.patch != 0:
            raise ValueError(
                f"Input H={x.shape[2]} and W={x.shape[3]} must be divisible by patch={self.patch}"
            )
        y = self.unshuffle(x)
        return self.proj(y)
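The shape bookkeeping of this patchifier can be checked with plain `nn.PixelUnshuffle` plus a 1x1 conv; `out_ch=1024` below is just the model dim of this release, used for illustration:

```python
import torch
from torch import nn

patch, in_ch, out_ch = 32, 3, 1024
unshuffle = nn.PixelUnshuffle(patch)
proj = nn.Conv2d(in_ch * patch * patch, out_ch, kernel_size=1)

x = torch.randn(2, in_ch, 384, 384)  # the training resolution
u = unshuffle(x)
y = proj(u)
# PixelUnshuffle folds each 32x32 block into channels: 3 * 32 * 32 = 3072.
assert u.shape == (2, in_ch * patch * patch, 12, 12)
assert y.shape == (2, out_ch, 12, 12)
```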
fcdm_diffae/time_embed.py ADDED
@@ -0,0 +1,83 @@
"""Sinusoidal timestep embedding with MLP projection."""

from __future__ import annotations

import math

import torch
from torch import Tensor, nn


def _log_spaced_frequencies(
    half: int, max_period: float, *, device: torch.device | None = None
) -> Tensor:
    """Log-spaced frequencies for sinusoidal embedding."""
    return torch.exp(
        -math.log(max_period)
        * torch.arange(half, device=device, dtype=torch.float32)
        / max(float(half - 1), 1.0)
    )


def sinusoidal_time_embedding(
    t: Tensor,
    dim: int,
    *,
    max_period: float = 10000.0,
    scale: float | None = None,
    freqs: Tensor | None = None,
) -> Tensor:
    """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32."""
    t32 = t.to(torch.float32)
    if scale is not None:
        t32 = t32 * float(scale)
    half = dim // 2
    if freqs is not None:
        freqs = freqs.to(device=t32.device, dtype=torch.float32)
    else:
        freqs = _log_spaced_frequencies(half, max_period, device=t32.device)
    angles = t32[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class SinusoidalTimeEmbeddingMLP(nn.Module):
    """Sinusoidal time embedding followed by Linear -> SiLU -> Linear."""

    def __init__(
        self,
        dim: int,
        *,
        freq_dim: int = 256,
        hidden_mult: float = 1.0,
        time_scale: float = 1000.0,
        max_period: float = 10000.0,
    ) -> None:
        super().__init__()
        self.dim = int(dim)
        self.freq_dim = int(freq_dim)
        self.time_scale = float(time_scale)
        self.max_period = float(max_period)
        hidden_dim = max(int(round(int(dim) * float(hidden_mult))), 1)

        freqs = _log_spaced_frequencies(self.freq_dim // 2, self.max_period)
        self.register_buffer("freqs", freqs, persistent=True)

        self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
        self.act = nn.SiLU()
        self.proj_out = nn.Linear(hidden_dim, self.dim)

    def forward(self, t: Tensor) -> Tensor:
        freqs: Tensor = self.freqs  # type: ignore[assignment]
        emb_freq = sinusoidal_time_embedding(
            t.to(torch.float32),
            self.freq_dim,
            max_period=self.max_period,
            scale=self.time_scale,
            freqs=freqs,
        )
        dtype_in = self.proj_in.weight.dtype
        hidden = self.proj_in(emb_freq.to(dtype_in))
        hidden = self.act(hidden)
        if hidden.dtype != self.proj_out.weight.dtype:
            hidden = hidden.to(self.proj_out.weight.dtype)
        return self.proj_out(hidden)
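A minimal standalone re-derivation of `sinusoidal_time_embedding` (same log-spaced frequencies, no MLP), checking the fixed point at t = 0 where all sines are 0 and all cosines are 1:

```python
import math
import torch

def sinusoidal(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    half = dim // 2
    # Log-spaced frequencies from 1.0 down to 1/max_period.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(half, dtype=torch.float32)
        / max(float(half - 1), 1.0)
    )
    angles = t.to(torch.float32)[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

emb = sinusoidal(torch.tensor([0.0, 1000.0]), 8)
assert emb.shape == (2, 8)
assert torch.allclose(emb[0, :4], torch.zeros(4))  # sin(0) = 0
assert torch.allclose(emb[0, 4:], torch.ones(4))   # cos(0) = 1
```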
fcdm_diffae/vp_diffusion.py ADDED
@@ -0,0 +1,150 @@
"""VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""

from __future__ import annotations

import math

import torch
import torch.nn.functional as F
from torch import Tensor


def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
    """Compute (alpha, sigma) from logSNR in float32.

    VP constraint: alpha^2 + sigma^2 = 1.
    """
    lmb32 = lmb.to(dtype=torch.float32)
    alpha = torch.exp(0.5 * F.logsigmoid(lmb32))
    sigma = torch.exp(0.5 * F.logsigmoid(-lmb32))
    return alpha, sigma

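Since exp(0.5 * logsigmoid(x)) = sqrt(sigmoid(x)) and sigmoid(x) + sigmoid(-x) = 1, the VP constraint holds by construction. A pure-Python scalar stand-in for the tensor version:

```python
import math

def alpha_sigma(logsnr: float) -> tuple[float, float]:
    # alpha = sqrt(sigmoid(logsnr)), sigma = sqrt(sigmoid(-logsnr))
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return alpha, sigma

for lmb in (-15.0, 0.0, 15.0):
    a, s = alpha_sigma(lmb)
    assert abs(a * a + s * s - 1.0) < 1e-12  # VP: alpha^2 + sigma^2 = 1

# logsnr = 0 is the balanced point: alpha = sigma = 1/sqrt(2).
a0, s0 = alpha_sigma(0.0)
assert abs(a0 - s0) < 1e-12
```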
def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
    """Broadcast [B] coefficient to match x for per-sample scaling."""
    view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
    return coeff.view(view_shape)


def _cosine_interpolated_params(
    logsnr_min: float, logsnr_max: float
) -> tuple[float, float]:
    """Compute (a, b) for cosine-interpolated logSNR schedule.

    logsnr(t) = -2 * log(tan(a*t + b))
    logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
    """
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return a, b


def cosine_interpolated_logsnr_from_t(
    t: Tensor, *, logsnr_min: float, logsnr_max: float
) -> Tensor:
    """Map t in [0,1] to logSNR via cosine-interpolated schedule. Always float32."""
    a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
    t32 = t.to(dtype=torch.float32)
    a_t = torch.tensor(a, device=t32.device, dtype=torch.float32)
    b_t = torch.tensor(b, device=t32.device, dtype=torch.float32)
    u = a_t * t32 + b_t
    return -2.0 * torch.log(torch.tan(u))


def shifted_cosine_interpolated_logsnr_from_t(
    t: Tensor,
    *,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.

    lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
    """
    base = cosine_interpolated_logsnr_from_t(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    t32 = t.to(dtype=torch.float32)
    high = base + float(log_change_high)
    low = base + float(log_change_low)
    return (1.0 - t32) * high + t32 * low


def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
    """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.

    ``num_steps`` is the number of function evaluations (NFE = decoder forward
    passes). Internally the schedule has ``num_steps + 1`` time points
    (including both endpoints).

    Args:
        schedule_type: "linear" or "cosine".
        num_steps: Number of decoder forward passes (NFE), >= 1.

    Returns:
        Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
    """
    # num_steps is the number of decoder evaluations exposed to users.
    # The schedule therefore needs one additional endpoint.
    n = max(int(num_steps) + 1, 2)
    if schedule_type == "linear":
        base = torch.linspace(0.0, 1.0, n)
    elif schedule_type == "cosine":
        i = torch.arange(n, dtype=torch.float32)
        base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
    else:
        raise ValueError(
            f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
        )
    # Descending: high t (noisy) -> low t (clean)
    return torch.flip(base, dims=[0])

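The NFE convention can be illustrated with a list-based stand-in for `get_schedule` (same endpoints and reversal; `t_schedule` is a hypothetical name, not repo code):

```python
import math

def t_schedule(kind: str, num_steps: int) -> list[float]:
    # num_steps forward passes need num_steps + 1 time points.
    n = max(int(num_steps) + 1, 2)
    if kind == "linear":
        base = [i / (n - 1) for i in range(n)]
    elif kind == "cosine":
        base = [0.5 * (1.0 - math.cos(math.pi * i / (n - 1))) for i in range(n)]
    else:
        raise ValueError(kind)
    return base[::-1]  # descending: noisy -> clean

# The 1-step VP/DDIM path used by the 2k PSNR benchmark: two endpoints, one NFE.
assert t_schedule("linear", 1) == [1.0, 0.0]
sched = t_schedule("cosine", 4)
assert len(sched) == 5
assert sched[0] == 1.0 and sched[-1] == 0.0
```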
def make_initial_state(
    *,
    noise: Tensor,
    t_start: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """Construct VP initial state x_t0 = sigma_start * noise (since x0=0).

    All math in float32.
    """
    batch = int(noise.shape[0])
    lmb_start = shifted_cosine_interpolated_logsnr_from_t(
        t_start.expand(batch).to(dtype=torch.float32),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    _alpha_start, sigma_start = alpha_sigma_from_logsnr(lmb_start)
    sigma_view = broadcast_time_like(sigma_start, noise)
    return sigma_view * noise.to(dtype=torch.float32)


def sample_noise(
    shape: tuple[int, ...],
    *,
    noise_std: float = 1.0,
    seed: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
    if seed is None:
        noise = torch.randn(
            shape, device=device or torch.device("cpu"), dtype=torch.float32
        )
    else:
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed))
        noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
    noise = noise.mul(float(noise_std))
    target_device = device if device is not None else torch.device("cpu")
    return noise.to(device=target_device, dtype=dtype)
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:566aaf108b41ceb4a55fa2c3e4f42321e2962fafba56d5fccd16670383df20b7
size 626569856
technical_report_fcdm_diffae.md ADDED
@@ -0,0 +1,196 @@
# semdisdiffae_p32_v2 Technical Report

`semdisdiffae_p32_v2` is a native patch-32 SemDisDiffAE diffusion autoencoder. It
uses the same FCDM block family as
[SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae), but with a
native `32x32` pixel token grid, an 8-block encoder, and an 8-block decoder. The
model is intended as a compact latent autoencoder for downstream latent
diffusion work.

This report focuses on the differences from the original SemDisDiffAE release.
For the shared FCDM block, VP decoder, stochastic posterior, and general design
background, see the original
[SemDisDiffAE technical report](https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md).
## Compared With SemDisDiffAE

| Component | SemDisDiffAE | semdisdiffae_p32_v2 |
|---|---:|---:|
| HF repo | [`data-archetype/semdisdiffae`](https://huggingface.co/data-archetype/semdisdiffae) | [`data-archetype/semdisdiffae_p32_v2`](https://huggingface.co/data-archetype/semdisdiffae_p32_v2) |
| Patch size | `16` | `32` |
| Latent grid | `H/16 x W/16` | `H/32 x W/32` |
| Latent channels | `128` | `384` |
| Model dim | `896` | `1024` |
| Encoder blocks | `4` | `8` |
| Decoder blocks | `8` | `8` |
| Parameters | `88.8M` | `156.6M` |
| Semantic teacher | DINOv3 ViT-S/16 LVD1689M, `vit_small_patch16_dinov3.lvd_1689m` | DINOv3 ConvNeXt-B/LVD1689M, `convnext_base.dinov3_lvd1689m` |
| Semantic loss | negative cosine | 50/50 MSE plus negative cosine |

The main change is the native patch-32 latent interface. It reduces the spatial
token count by `4x` relative to patch-16 while increasing channel width and
encoder depth. The result is a lower-resolution latent grid intended to be
easier and cheaper for downstream diffusion models, while still preserving the
SemDisDiffAE-style VP diffusion decoder and stochastic posterior.
36
+ ## Architecture
37
+
38
+ | Component | Value |
39
+ |---|---:|
40
+ | Parameters | `156.6M` |
41
+ | Encoder blocks | `8` |
42
+ | Decoder blocks | `8` |
43
+ | Patch size | `32` |
44
+ | Model dim | `1024` |
45
+ | Bottleneck dim | `384` |
46
+ | Spatial compression | `32x` |
47
+ | Posterior | `diagonal_gaussian` |
48
+ | Bottleneck norm | `disabled` |

The decoder uses the start / middle / end skip-concat layout with `2` start
blocks and `2` end blocks. The encoder and decoder both operate natively at
patch size `32`; this is not a patch-16 model with an additional latent
downsampling step. The posterior is VP-parameterized as mean plus log-SNR, with
the exported `encode()` path returning the whitened posterior mode.

## Training

The checkpoint exported here is the EMA model at step `400107`. The training
recipe follows the original SemDisDiffAE objective, with changes for the native
patch-32 latent grid.

### Data And Resolution

Training uses the same bucketed image mixture style as the original
SemDisDiffAE recipe: photographs and aesthetic web images, posters/stills, and
book-cover/document-like images. Images are loaded without captions.

Unlike the original `256`-scale SemDisDiffAE recipe, p32 was trained with `384`
aspect-ratio buckets and `384x384` square crops at batch size `128`; validation
used `384` AR buckets with patch samples disabled. At patch size `32`, `256x256`
images give only an `8x8` latent grid, so the 7x7 depthwise convolutions are
dominated by border padding. `384x384` gives a `12x12` grid and a better-trained
center region.
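
The grid arithmetic above is direct; a quick plain-Python sanity check (nothing model-specific is assumed beyond the patch sizes stated in this card):

```python
def latent_grid(resolution: int, patch_size: int) -> int:
    """Side length of the spatial latent grid for a square input."""
    assert resolution % patch_size == 0
    return resolution // patch_size

# At patch 32: 256 px -> 8x8 and 384 px -> 12x12. With a 7x7 kernel, an 8x8
# grid has only a 2x2 fully-interior region free of padding influence,
# versus 6x6 on a 12x12 grid.
print(latent_grid(256, 32), latent_grid(384, 32))  # 8 12
```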

### Losses

The active losses are:

| Loss | Weight | Description |
|---|---:|---|
| Reconstruction | `1.0` | VP `x_pred` reconstruction objective with SiD2-style sigmoid weighting. |
| Semantic alignment | `0.01` | DINOv3 ConvNeXt-B/LVD1689M alignment on noisy latent tokens. |
| Latent scale penalty | `1e-4` | Per-channel latent variance regularization toward unit variance. |
| Posterior variance expansion | `1e-5` | Encourages a non-collapsed stochastic posterior variance. |

The auxiliary weights are intentionally small. The native SiD2-style VP
`x_pred` reconstruction loss has a small effective scale after the log-SNR
Jacobian and sigmoid weighting are applied; larger auxiliary weights would
dwarf the reconstruction objective rather than act as regularizers.
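
This card does not spell out the sigmoid weighting itself. As a hedged sketch only: sigmoid-weighted diffusion losses in this style typically weight the `x_pred` term as `sigmoid(b - logsnr)` for some bias `b`. The bias used by this checkpoint is not documented here, so `b` below is a hypothetical placeholder:

```python
import math

def xpred_weight(logsnr: float, b: float = 0.0) -> float:
    """Hypothetical sigmoid loss weight; the bias `b` is an assumption."""
    return 1.0 / (1.0 + math.exp(-(b - logsnr)))

# Nearly clean inputs (high log-SNR) are strongly down-weighted, which keeps
# the effective scale of the reconstruction term small.
print(xpred_weight(10.0), xpred_weight(-10.0))
```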

Semantic alignment uses the DINOv3 ConvNeXt-B/LVD1689M teacher at token stride
`32`, matching the model's native latent grid. The training-only projection head
maps noisy student latents to `1024`-dimensional teacher tokens. It is a compact
DiT-style head: `1x1` input projection from `384` latent channels to model dim
`1024`, one DiT block with `16` attention heads of head dim `64`, axial RoPE over
patch-index coordinates, sinusoidal timestep embedding fed through AdaLN-Zero,
RMSNorm, and a final `1x1` projection to the teacher token width.

As in the original SemDisDiffAE alignment setup, alignment is performed on
noisy latents rather than only on clean latents. The clean latent `z` is linearly
interpolated with Gaussian noise:

```text
z_noisy = (1 - tau) * z + tau * eps,  eps ~ N(0, I),  tau ~ Beta(2, 2)
```
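
A minimal pure-Python sketch of this mixing rule on a flattened latent (illustrative only; the real model applies it element-wise on the spatial latent tensor):

```python
import random

def noisy_latent(z, rng):
    """Blend a clean latent with Gaussian noise for the alignment loss."""
    tau = rng.betavariate(2.0, 2.0)          # tau ~ Beta(2, 2), mean 0.5
    eps = [rng.gauss(0.0, 1.0) for _ in z]   # eps ~ N(0, I)
    return [(1.0 - tau) * zi + tau * ei for zi, ei in zip(z, eps)], tau

rng = random.Random(0)
z_noisy, tau = noisy_latent([1.0, -2.0, 0.5], rng)
print(round(tau, 3), [round(v, 3) for v in z_noisy])
```

Beta(2, 2) concentrates `tau` around 0.5, so most alignment samples see a substantial but not overwhelming noise fraction.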

The p32 semantic loss is exactly the 50/50 mix configured in the checkpoint:

```text
semantic_loss = 0.5 * (MSE(student_tokens, teacher_tokens)
                       + negative_cosine(student_tokens, teacher_tokens))
```
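
Written out as a sketch on single flattened token vectors (the real loss averages over all spatial tokens in the batch):

```python
import math

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def negative_cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return -dot / norm

def semantic_loss(student, teacher):
    """50/50 mix of MSE and negative cosine, as configured for p32."""
    return 0.5 * (mse(student, teacher) + negative_cosine(student, teacher))

# Perfectly aligned tokens: the MSE term is 0 and the cosine term is -1,
# so the minimum per-token value is -0.5.
print(semantic_loss([1.0, 2.0], [1.0, 2.0]))  # -0.5
```

The MSE term also penalizes magnitude mismatches that the cosine term alone would ignore.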

This differs from the original SemDisDiffAE report, which used the DINOv3
ViT-S/16 LVD1689M teacher and negative-cosine-only alignment. The projection
head is training-only and is not included in the exported inference package.

### Optimizer And Schedule

| Parameter | Value |
|---|---:|
| Optimizer | AdamW |
| Learning rate | `1e-4` |
| Betas | `(0.9, 0.99)` |
| Epsilon | `1e-8` |
| Weight decay | `0.0` |
| LR schedule | constant after warmup |
| Warmup steps | `10000` |
| Gradient clip | `1.0` |
| Precision | AMP `bfloat16`, TF32 matmul |
| EMA decay | `0.9995` |
| EMA update | every step |
| Compile | enabled |
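
The EMA rows describe the standard per-parameter exponential moving average, applied every optimizer step:

```python
def ema_update(ema: float, param: float, decay: float = 0.9995) -> float:
    """One EMA step with the decay from the table above."""
    return decay * ema + (1.0 - decay) * param

# With decay 0.9995 the EMA reaches only ~39% of a constant target after
# 1000 steps, so the exported EMA weights lag recent updates by design.
ema = 0.0
for _ in range(1000):
    ema = ema_update(ema, 1.0)
print(ema)
```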

Latent running statistics are tracked during training with momentum `0.0001`
and epsilon `0.0001`; those statistics are stored in the export and used for
the `whiten()` / `dewhiten()` interface.

## Export Contract

- `encode()` returns whitened latents using the model's saved running statistics.
- `decode()` expects those whitened latents and dewhitens internally.
- `whiten()` and `dewhiten()` expose the transform explicitly.
- `encode_posterior()` returns the raw exported posterior before whitening.
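
A minimal sketch of the whitening contract, assuming the saved per-channel running mean/variance and that the tracked `1e-4` epsilon enters the normalization (the exact epsilon placement is an assumption here, not taken from the export code):

```python
import math

EPS = 1e-4  # assumed to match the tracked-statistics epsilon

def whiten(z: float, mean: float, var: float) -> float:
    """Map a raw latent value to the whitened space `encode()` returns."""
    return (z - mean) / math.sqrt(var + EPS)

def dewhiten(z_w: float, mean: float, var: float) -> float:
    """Inverse transform; `decode()` applies this internally."""
    return z_w * math.sqrt(var + EPS) + mean

# Round trip is exact up to floating point.
z = 3.2
print(dewhiten(whiten(z, 1.0, 4.0), 1.0, 4.0))
```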

Weights are stored in `float32`. The recommended runtime path is `bfloat16` for
the encoder and decoder. Whitening, dewhitening, posterior alpha/sigma
computation, diffusion schedule computation, and sampler state updates are kept
in `float32`.

## Export Verification

The exported EMA checkpoint was checked against the source training checkpoint
on the same input image. In `float32`, reconstruction mean absolute error was
`8.95e-5` with max absolute error `7.86e-4`; whitened latent mean absolute
error was `1.84e-4` with max absolute error `2.91e-3`. In the recommended
`bfloat16` runtime path, reconstruction mean absolute error was `0.0012`.

## Reconstruction Quality

The `2k` benchmark uses a fixed validation list with `1333` Pexels images and
`667` Amazon book covers. The export uses `bfloat16` inference and the default
1-step reconstruction path.

| Model | Mean PSNR (dB) | Std (dB) | Median (dB) | P5 (dB) | P95 (dB) |
|---|---:|---:|---:|---:|---:|
| semdisdiffae_p32_v2 | `36.06` | `5.47` | `35.80` | `27.63` | `45.02` |
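
PSNR here is the standard peak signal-to-noise ratio; a quick sketch, assuming pixel values in `[0, 1]`:

```python
import math

def psnr_db(mse: float, peak: float = 1.0) -> float:
    """PSNR in dB from a per-image mean squared error."""
    return 10.0 * math.log10((peak * peak) / mse)

# A mean of ~36 dB corresponds to an RMS pixel error of roughly 0.016.
print(psnr_db(0.016 ** 2))
```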

The fixed 39-image visual set is exported separately as a results viewer with
side-by-side originals, FLUX.2 reconstructions, semdisdiffae_p32_v2
reconstructions, error deltas, and PCA visualizations.

## Throughput

Measured on an `NVIDIA GeForce RTX 5090` in `bfloat16`, with `5` warmup batches
and `20` timed batches. Decode uses the default 1-step sampler with PDG
disabled.

| Operation | Resolution | Batch Size | Mean (ms/batch) | Median (ms/batch) | P95 (ms/batch) | ms/image | Images/s | Peak Allocated VRAM |
|---|---:|---:|---:|---:|---:|---:|---:|---:|
| Encode | `256x256` | `128` | `12.57` | `12.41` | `13.27` | `0.098` | `10186.8` | `574 MiB` |
| Decode | `256x256` | `128` | `98.93` | `99.22` | `100.16` | `0.773` | `1293.9` | `1042 MiB` |
| Encode | `512x512` | `32` | `12.08` | `11.98` | `12.46` | `0.377` | `2649.9` | `579 MiB` |
| Decode | `512x512` | `32` | `100.36` | `99.39` | `105.84` | `3.136` | `318.8` | `1042 MiB` |
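
The `ms/image` and `Images/s` columns follow directly from the mean batch latency; for example, the `256x256` encode row:

```python
def derived_throughput(mean_ms_per_batch: float, batch_size: int):
    """Return (ms/image, images/s) from a mean batch latency."""
    ms_per_image = mean_ms_per_batch / batch_size
    return ms_per_image, 1000.0 / ms_per_image

ms_img, img_s = derived_throughput(12.57, 128)
print(round(ms_img, 3), round(img_s, 1))
```

Small differences from the table come from rounding the reported mean latency.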

## VP Stability

The VP posterior and diffusion schedule compute alpha/sigma with a stable
`logsigmoid` formulation:

```text
alpha = exp(0.5 * logsigmoid(logsnr))
sigma = exp(0.5 * logsigmoid(-logsnr))
```

This avoids unstable `sqrt(sigmoid(...))` behavior at extreme log-SNR values.
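
A self-contained sketch of this formulation, with a plain-Python stand-in for the framework `logsigmoid`:

```python
import math

def logsigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)) for any magnitude of x."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def alpha_sigma(logsnr: float):
    alpha = math.exp(0.5 * logsigmoid(logsnr))
    sigma = math.exp(0.5 * logsigmoid(-logsnr))
    return alpha, sigma

# At logsnr = -800 the naive sigmoid(logsnr) would overflow/underflow, while
# the logsigmoid path still yields a strictly positive alpha and preserves
# the VP identity alpha^2 + sigma^2 = 1 to machine precision.
a, s = alpha_sigma(-800.0)
print(a > 0.0, abs(a * a + s * s - 1.0) < 1e-12)
```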