data-archetype committed on
Commit 720fb6d · verified · 1 Parent(s): dc28b34

Upload folder using huggingface_hub
README.md ADDED
@@ -0,0 +1,91 @@
+ ---
+ license: apache-2.0
+ tags:
+ - diffusion
+ - autoencoder
+ - image-reconstruction
+ - image-tokenizer
+ - pytorch
+ - fcdm
+ - semantic-alignment
+ library_name: fcdm_diffae
+ ---
+
+ # data-archetype/semdisdiffae_p32
+
+ ### Version History
+
+ | Date | Change |
+ |------|--------|
+ | 2026-04-08 | Initial release |
+
+ **Experimental patch-32 version** of
+ [SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae).
+
+ This model extends the patch-16 SemDisDiffAE with a 2x2 bottleneck
+ patchification after the encoder, producing **512-channel latents at H/32 x W/32**
+ instead of the base model's 128-channel latents at H/16 x W/16. The decoder
+ unpatchifies back to 128 channels before reconstruction.
+
+ See the [patch-16 SemDisDiffAE model card](https://huggingface.co/data-archetype/semdisdiffae)
+ and its [technical report](https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md)
+ for full architectural details. The [p32 technical report](technical_report_p32.md)
+ covers only the differences.
+
+ ## Architecture
+
+ | Property | p32 (this model) | p16 (base) |
+ |----------|-----------------|------------|
+ | Latent channels | 512 | 128 |
+ | Effective patch | 32 | 16 |
+ | Latent grid | H/32 x W/32 | H/16 x W/16 |
+ | Encoder patch | 16 (same) | 16 |
+ | Bottleneck patchify | 2x2 | none |
+ | Parameters | 88.8M (same) | 88.8M |
+
+ ## Quick Start
+
+ ```python
+ from fcdm_diffae import FCDMDiffAE
+
+ model = FCDMDiffAE.from_pretrained("data-archetype/semdisdiffae_p32", device="cuda")
+
+ # Encode: returns whitened 512-channel latents at H/32 x W/32
+ latents = model.encode(images)  # [B,3,H,W] in [-1,1] -> [B,512,H/32,W/32]
+
+ # Decode
+ recon = model.decode(latents, height=H, width=W)
+
+ # Reconstruct (encode + decode in one call)
+ recon = model.reconstruct(images)
+ ```
+
+ ## Training
+
+ Same losses and hyperparameters as the base SemDisDiffAE (DINOv2 semantic
+ alignment, VP posterior variance expansion, latent scale penalty). Trained
+ for 275k steps. See the
+ [base model training section](https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md#6-training)
+ for details.
+
+ ## Dependencies
+
+ - PyTorch >= 2.0
+ - safetensors
+
+ ## Citation
+
+ ```bibtex
+ @misc{semdisdiffae,
+   title  = {SemDisDiffAE: A Semantically Disentangled Diffusion Autoencoder},
+   author = {data-archetype},
+   email  = {data-archetype@proton.me},
+   year   = {2026},
+   month  = apr,
+   url    = {https://huggingface.co/data-archetype/semdisdiffae},
+ }
+ ```
+
+ ## License
+
+ Apache 2.0
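The 2x2 bottleneck patchification described in the card above is a pure, lossless reshape. A standalone sketch of the same shape arithmetic (mirroring the `_patchify`/`_unpatchify` helpers in `fcdm_diffae/model.py` from this commit; the function names here are illustrative):

```python
import torch

def patchify_2x2(z: torch.Tensor) -> torch.Tensor:
    """[B, C, H, W] -> [B, 4C, H/2, W/2] by folding each 2x2 tile into channels."""
    b, c, h, w = z.shape
    z = z.reshape(b, c, h // 2, 2, w // 2, 2).permute(0, 1, 3, 5, 2, 4)
    return z.reshape(b, c * 4, h // 2, w // 2)

def unpatchify_2x2(z: torch.Tensor) -> torch.Tensor:
    """[B, 4C, H/2, W/2] -> [B, C, H, W], exact inverse of patchify_2x2."""
    b, c, h, w = z.shape
    z = z.reshape(b, c // 4, 2, 2, h, w).permute(0, 1, 4, 2, 5, 3)
    return z.reshape(b, c // 4, h * 2, w * 2)

# A 256x256 image gives 128 channels on the 16x16 (H/16) grid after the encoder...
z = torch.randn(1, 128, 16, 16)
zp = patchify_2x2(z)
assert zp.shape == (1, 512, 8, 8)          # ...and 512 channels on the H/32 grid
assert torch.equal(unpatchify_2x2(zp), z)  # lossless round trip
```

Since only memory layout changes, the p32 latent space carries exactly the same information per image as the p16 latents it is folded from.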
config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "in_channels": 3,
+   "patch_size": 16,
+   "model_dim": 896,
+   "encoder_depth": 4,
+   "decoder_depth": 8,
+   "decoder_start_blocks": 2,
+   "decoder_end_blocks": 2,
+   "bottleneck_dim": 128,
+   "mlp_ratio": 4.0,
+   "depthwise_kernel_size": 7,
+   "adaln_low_rank_rank": 128,
+   "bottleneck_posterior_kind": "diagonal_gaussian",
+   "bottleneck_norm_mode": "disabled",
+   "bottleneck_patchify_mode": "patch_2x2",
+   "logsnr_min": -10.0,
+   "logsnr_max": 10.0,
+   "pixel_noise_std": 0.558
+ }
fcdm_diffae/__init__.py ADDED
@@ -0,0 +1,33 @@
+ """FCDMDiffAE: Standalone diffusion autoencoder with FCDM blocks.
+
+ FCDM DiffAE: a fast diffusion autoencoder with a 128-channel spatial
+ bottleneck and a VP-parameterized diagonal Gaussian posterior. Built on FCDM
+ (Fully Convolutional Diffusion Model) blocks with GRN and scale+gate AdaLN.
+
+ Usage::
+
+     from fcdm_diffae import FCDMDiffAE, FCDMDiffAEInferenceConfig
+
+     model = FCDMDiffAE.from_pretrained("path/to/weights", device="cuda")
+
+     # Encode (returns posterior mode by default)
+     latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]
+
+     # Decode: PSNR-optimal (1 step, default)
+     recon = model.decode(latents, height=H, width=W)
+
+     # Decode: perceptual sharpness (10 steps + path-drop PDG)
+     cfg = FCDMDiffAEInferenceConfig(num_steps=10, pdg=True, pdg_strength=2.0)
+     recon = model.decode(latents, height=H, width=W, inference_config=cfg)
+ """
+
+ from .config import FCDMDiffAEConfig, FCDMDiffAEInferenceConfig
+ from .encoder import EncoderPosterior
+ from .model import FCDMDiffAE
+
+ __all__ = [
+     "EncoderPosterior",
+     "FCDMDiffAE",
+     "FCDMDiffAEConfig",
+     "FCDMDiffAEInferenceConfig",
+ ]
fcdm_diffae/adaln.py ADDED
@@ -0,0 +1,50 @@
+ """Scale+Gate AdaLN (2-way) for FCDM decoder blocks."""
+
+ from __future__ import annotations
+
+ from torch import Tensor, nn
+
+
+ class AdaLNScaleGateZeroProjector(nn.Module):
+     """Packed 2-way AdaLN projection (SiLU -> Linear), zero-initialized.
+
+     Outputs [B, 2*d_model] packed as (scale, gate).
+     """
+
+     def __init__(self, d_model: int, d_cond: int) -> None:
+         super().__init__()
+         self.d_model: int = int(d_model)
+         self.d_cond: int = int(d_cond)
+         self.act: nn.SiLU = nn.SiLU()
+         self.proj: nn.Linear = nn.Linear(self.d_cond, 2 * self.d_model)
+         nn.init.zeros_(self.proj.weight)
+         nn.init.zeros_(self.proj.bias)
+
+     def forward_activated(self, act_cond: Tensor) -> Tensor:
+         """Return packed modulation for a pre-activated conditioning vector."""
+         return self.proj(act_cond)
+
+     def forward(self, cond: Tensor) -> Tensor:
+         """Return packed modulation [B, 2*d_model]."""
+         return self.forward_activated(self.act(cond))
+
+
+ class AdaLNScaleGateZeroLowRankDelta(nn.Module):
+     """Low-rank delta for 2-way AdaLN: down(d_cond -> rank) -> up(rank -> 2*d_model).
+
+     Zero-initialized up projection preserves zero-output semantics at init.
+     """
+
+     def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
+         super().__init__()
+         self.d_model: int = int(d_model)
+         self.d_cond: int = int(d_cond)
+         self.rank: int = int(rank)
+         self.down: nn.Linear = nn.Linear(self.d_cond, self.rank, bias=False)
+         self.up: nn.Linear = nn.Linear(self.rank, 2 * self.d_model, bias=False)
+         nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
+         nn.init.zeros_(self.up.weight)
+
+     def forward(self, act_cond: Tensor) -> Tensor:
+         """Return packed delta modulation [B, 2*d_model]."""
+         return self.up(self.down(act_cond))
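Because both the projector's weight and bias are zeroed at init (and the low-rank delta's up-projection likewise), every modulated block starts as an identity mapping: scale 0 leaves the `(1 + scale)` multiplier at 1, and gate 0 zeroes the residual branch. A quick standalone check of that property, re-creating the projector with plain `torch.nn`:

```python
import torch
from torch import nn

# Mirror of AdaLNScaleGateZeroProjector: SiLU -> zero-initialized Linear.
d_model, d_cond = 8, 8
proj = nn.Linear(d_cond, 2 * d_model)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)

cond = torch.randn(4, d_cond)
m = proj(nn.functional.silu(cond))  # packed (scale, gate): all zeros at init
scale, gate = m.chunk(2, dim=-1)

# scale == 0 keeps the block's (1 + scale) multiplier at identity, and
# gate == 0 zeroes the residual branch, so conditioning starts as a no-op.
assert torch.all(m == 0)
```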
fcdm_diffae/config.py ADDED
@@ -0,0 +1,80 @@
+ """Frozen model architecture and user-tunable inference configuration."""
+
+ from __future__ import annotations
+
+ import json
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+
+
+ @dataclass(frozen=True)
+ class FCDMDiffAEConfig:
+     """Frozen model architecture config. Stored alongside weights as config.json."""
+
+     in_channels: int = 3
+     patch_size: int = 16
+     model_dim: int = 896
+     encoder_depth: int = 4
+     decoder_depth: int = 8
+     decoder_start_blocks: int = 2
+     decoder_end_blocks: int = 2
+     bottleneck_dim: int = 128
+     mlp_ratio: float = 4.0
+     depthwise_kernel_size: int = 7
+     adaln_low_rank_rank: int = 128
+     # Encoder posterior kind: "diagonal_gaussian" or "deterministic"
+     bottleneck_posterior_kind: str = "diagonal_gaussian"
+     # Post-bottleneck normalization: "channel_wise" or "disabled"
+     bottleneck_norm_mode: str = "disabled"
+     # Bottleneck patchification: "off" or "patch_2x2"
+     # When "patch_2x2", encoder latents are 2x2 patchified after the bottleneck
+     # (channels * 4, spatial / 2), and decode unpatchifies before the decoder.
+     bottleneck_patchify_mode: str = "off"
+     # VP diffusion schedule endpoints
+     logsnr_min: float = -10.0
+     logsnr_max: float = 10.0
+     # Pixel-space noise std for VP diffusion initialization
+     pixel_noise_std: float = 0.558
+
+     @property
+     def latent_channels(self) -> int:
+         """Channel width of the exported latent space."""
+         if self.bottleneck_patchify_mode == "patch_2x2":
+             return self.bottleneck_dim * 4
+         return self.bottleneck_dim
+
+     @property
+     def effective_patch_size(self) -> int:
+         """Effective spatial stride from image to latent grid."""
+         if self.bottleneck_patchify_mode == "patch_2x2":
+             return self.patch_size * 2
+         return self.patch_size
+
+     def save(self, path: str | Path) -> None:
+         """Save config as JSON."""
+         p = Path(path)
+         p.parent.mkdir(parents=True, exist_ok=True)
+         p.write_text(json.dumps(asdict(self), indent=2) + "\n")
+
+     @classmethod
+     def load(cls, path: str | Path) -> FCDMDiffAEConfig:
+         """Load config from JSON."""
+         data = json.loads(Path(path).read_text())
+         return cls(**data)
+
+
+ @dataclass
+ class FCDMDiffAEInferenceConfig:
+     """User-tunable inference parameters with sensible defaults.
+
+     PDG (Path-Drop Guidance) sharpens reconstructions by degrading conditioning
+     in one pass and amplifying the difference. When enabled, uses 2 NFE per step.
+     Recommended: ``pdg=True, pdg_strength=2.0, num_steps=10``.
+     """
+
+     num_steps: int = 1  # number of denoising steps (NFE)
+     sampler: str = "ddim"  # "ddim" or "dpmpp_2m"
+     schedule: str = "linear"  # "linear" or "cosine"
+     pdg: bool = False  # enable PDG for perceptual sharpening
+     pdg_strength: float = 2.0  # CFG-like strength when pdg=True
+     seed: int | None = None
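A small worked example of the two derived properties, using this checkpoint's values (`patch_size=16`, `bottleneck_dim=128`, `bottleneck_patchify_mode="patch_2x2"` per config.json) and an assumed 256x512 input resolution:

```python
# Values from config.json above.
patch_size = 16
bottleneck_dim = 128
bottleneck_patchify_mode = "patch_2x2"

# latent_channels: the 2x2 patchify folds 4 spatial positions into channels.
latent_channels = (
    bottleneck_dim * 4 if bottleneck_patchify_mode == "patch_2x2" else bottleneck_dim
)
# effective_patch_size: encoder stride 16 times the 2x bottleneck patchify.
effective_patch = (
    patch_size * 2 if bottleneck_patchify_mode == "patch_2x2" else patch_size
)

H, W = 256, 512  # example input; H and W must be divisible by effective_patch
latent_shape = (latent_channels, H // effective_patch, W // effective_patch)
print(latent_shape)  # (512, 8, 16)
```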
fcdm_diffae/decoder.py ADDED
@@ -0,0 +1,168 @@
+ """FCDM DiffAE decoder: skip-concat topology with FCDM blocks and path-drop PDG.
+
+ No outer RMSNorms (use_other_outer_rms_norms=False during training):
+ norm_in, latent_norm, and norm_out are all absent.
+ """
+
+ from __future__ import annotations
+
+ import torch
+ from torch import Tensor, nn
+
+ from .adaln import AdaLNScaleGateZeroLowRankDelta, AdaLNScaleGateZeroProjector
+ from .fcdm_block import FCDMBlock
+ from .straight_through_encoder import Patchify
+ from .time_embed import SinusoidalTimeEmbeddingMLP
+
+
+ class Decoder(nn.Module):
+     """VP diffusion decoder conditioned on encoder latents and timestep.
+
+     Architecture (skip-concat, 2+4+2 default):
+         Patchify x_t -> Fuse with upsampled z
+         -> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
+         -> Conv1x1 -> PixelShuffle
+
+     Path-Drop Guidance (PDG) at inference:
+     - Replace middle block output with ``path_drop_mask_feature`` to create
+       an unconditional prediction, then extrapolate.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         patch_size: int,
+         model_dim: int,
+         depth: int,
+         start_block_count: int,
+         end_block_count: int,
+         bottleneck_dim: int,
+         mlp_ratio: float,
+         depthwise_kernel_size: int,
+         adaln_low_rank_rank: int,
+     ) -> None:
+         super().__init__()
+         self.patch_size = int(patch_size)
+         self.model_dim = int(model_dim)
+
+         # Input processing (no norm_in)
+         self.patchify = Patchify(in_channels, patch_size, model_dim)
+
+         # Latent conditioning path (no latent_norm)
+         self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
+         self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
+
+         # Time embedding
+         self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)
+
+         # 2-way AdaLN: shared base projector + per-block low-rank deltas
+         self.adaln_base = AdaLNScaleGateZeroProjector(
+             d_model=model_dim, d_cond=model_dim
+         )
+         self.adaln_deltas = nn.ModuleList(
+             [
+                 AdaLNScaleGateZeroLowRankDelta(
+                     d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
+                 )
+                 for _ in range(depth)
+             ]
+         )
+
+         # Block layout: start + middle + end
+         middle_count = depth - start_block_count - end_block_count
+         self._middle_start_idx = start_block_count
+         self._end_start_idx = start_block_count + middle_count
+
+         def _make_blocks(count: int) -> nn.ModuleList:
+             return nn.ModuleList(
+                 [
+                     FCDMBlock(
+                         model_dim,
+                         mlp_ratio,
+                         depthwise_kernel_size=depthwise_kernel_size,
+                         use_external_adaln=True,
+                     )
+                     for _ in range(count)
+                 ]
+             )
+
+         self.start_blocks = _make_blocks(start_block_count)
+         self.middle_blocks = _make_blocks(middle_count)
+         self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
+         self.end_blocks = _make_blocks(end_block_count)
+
+         # Learned mask feature for path-drop PDG
+         self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))
+
+         # Output head (no norm_out)
+         self.out_proj = nn.Conv2d(
+             model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
+         )
+         self.unpatchify = nn.PixelShuffle(patch_size)
+
+     def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
+         """Compute packed AdaLN modulation = shared_base + per-layer delta."""
+         act = self.adaln_base.act(cond)
+         base_m = self.adaln_base.forward_activated(act)
+         delta_m = self.adaln_deltas[layer_idx](act)
+         return base_m + delta_m
+
+     def _run_blocks(
+         self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
+     ) -> Tensor:
+         """Run a group of decoder blocks with per-block AdaLN modulation."""
+         for local_idx, block in enumerate(blocks):
+             adaln_m = self._adaln_m_for_layer(cond, layer_idx=start_index + local_idx)
+             x = block(x, adaln_m=adaln_m)
+         return x
+
+     def forward(
+         self,
+         x_t: Tensor,
+         t: Tensor,
+         latents: Tensor,
+         *,
+         drop_middle_blocks: bool = False,
+     ) -> Tensor:
+         """Single decoder forward pass.
+
+         Args:
+             x_t: Noised image [B, C, H, W].
+             t: Timestep [B] in [0, 1].
+             latents: Encoder latents [B, bottleneck_dim, h, w].
+             drop_middle_blocks: Replace middle block output with mask feature (PDG).
+
+         Returns:
+             x0 prediction [B, C, H, W].
+         """
+         x_feat = self.patchify(x_t)
+         z_up = self.latent_up(latents)
+
+         fused = torch.cat([x_feat, z_up], dim=1)
+         fused = self.fuse_in(fused)
+
+         cond = self.time_embed(t.to(device=x_t.device, dtype=torch.float32))
+
+         start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)
+
+         if drop_middle_blocks:
+             middle_out = self.path_drop_mask_feature.to(
+                 device=x_t.device, dtype=x_t.dtype
+             ).expand_as(start_out)
+         else:
+             middle_out = self._run_blocks(
+                 self.middle_blocks,
+                 start_out,
+                 cond,
+                 start_index=self._middle_start_idx,
+             )
+
+         skip_fused = torch.cat([start_out, middle_out], dim=1)
+         skip_fused = self.fuse_skip(skip_fused)
+
+         end_out = self._run_blocks(
+             self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
+         )
+
+         patches = self.out_proj(end_out)
+         return self.unpatchify(patches)
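The extrapolation step itself lives in `fcdm_diffae/samplers.py`, which is not part of this commit view. Based on the "CFG-like strength" comment in the inference config, a plausible sketch of the guidance combine follows; the function name is hypothetical and the exact formula used by the samplers may differ:

```python
# cond_pred: decoder output with the middle blocks active.
# uncond_pred: decoder output with drop_middle_blocks=True (mask feature),
# i.e. the deliberately degraded conditioning path.
# pdg_extrapolate is a hypothetical name; the real combine is in samplers.py.
def pdg_extrapolate(cond_pred: float, uncond_pred: float, strength: float) -> float:
    return uncond_pred + strength * (cond_pred - uncond_pred)

# strength == 1.0 recovers the conditional prediction unchanged; strength > 1.0
# pushes the output further away from the degraded, unconditional path.
```

This is the standard classifier-free-guidance shape, which is why PDG costs 2 NFE per step when enabled: one conditional and one path-dropped forward pass.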
fcdm_diffae/encoder.py ADDED
@@ -0,0 +1,129 @@
+ """FCDM DiffAE encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.
+
+ No input RMSNorm (use_other_outer_rms_norms=False during training).
+ When bottleneck_norm_mode == "channel_wise", a post-bottleneck RMSNorm
+ (affine=False) is applied on the mean branch.
+ Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import torch
+ from torch import Tensor, nn
+
+ from .fcdm_block import FCDMBlock
+ from .norms import ChannelWiseRMSNorm
+ from .straight_through_encoder import Patchify
+
+
+ @dataclass(frozen=True)
+ class EncoderPosterior:
+     """VP-parameterized diagonal Gaussian posterior.
+
+     mean: Clean signal branch mu [B, bottleneck_dim, h, w]
+     logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w]
+     """
+
+     mean: Tensor
+     logsnr: Tensor
+
+     @property
+     def alpha(self) -> Tensor:
+         """VP signal coefficient: sqrt(sigmoid(logsnr))."""
+         return torch.sigmoid(self.logsnr).sqrt()
+
+     @property
+     def sigma(self) -> Tensor:
+         """VP noise coefficient: sqrt(sigmoid(-logsnr))."""
+         return torch.sigmoid(-self.logsnr).sqrt()
+
+     def mode(self) -> Tensor:
+         """Posterior mode in token space: alpha * mean."""
+         return self.alpha.to(dtype=self.mean.dtype) * self.mean
+
+     def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
+         """Sample from posterior: alpha * mean + sigma * eps."""
+         # torch.randn_like does not accept a generator; use torch.randn instead.
+         eps = torch.randn(
+             self.mean.shape,
+             generator=generator,
+             device=self.mean.device,
+             dtype=self.mean.dtype,
+         )
+         alpha = self.alpha.to(dtype=self.mean.dtype)
+         sigma = self.sigma.to(dtype=self.mean.dtype)
+         return alpha * self.mean + sigma * eps
+
+
+ class Encoder(nn.Module):
+     """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].
+
+     With diagonal_gaussian posterior, the to_bottleneck projection outputs
+     2 * bottleneck_dim channels, split into mean and logsnr. The default
+     encode() returns the posterior mode: alpha * RMSNorm(mean).
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         patch_size: int,
+         model_dim: int,
+         depth: int,
+         bottleneck_dim: int,
+         mlp_ratio: float,
+         depthwise_kernel_size: int,
+         bottleneck_posterior_kind: str = "diagonal_gaussian",
+         bottleneck_norm_mode: str = "disabled",
+     ) -> None:
+         super().__init__()
+         self.bottleneck_dim = int(bottleneck_dim)
+         self.bottleneck_posterior_kind = bottleneck_posterior_kind
+         self.bottleneck_norm_mode = bottleneck_norm_mode
+         self.patchify = Patchify(in_channels, patch_size, model_dim)
+         self.blocks = nn.ModuleList(
+             [
+                 FCDMBlock(
+                     model_dim,
+                     mlp_ratio,
+                     depthwise_kernel_size=depthwise_kernel_size,
+                     use_external_adaln=False,
+                 )
+                 for _ in range(depth)
+             ]
+         )
+         out_dim = (
+             2 * bottleneck_dim
+             if bottleneck_posterior_kind == "diagonal_gaussian"
+             else bottleneck_dim
+         )
+         self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
+         if bottleneck_norm_mode == "channel_wise":
+             self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
+         else:
+             self.norm_out = nn.Identity()
+
+     def encode_posterior(self, images: Tensor) -> EncoderPosterior:
+         """Encode images and return the full posterior (mean + logsnr).
+
+         Only valid when bottleneck_posterior_kind == "diagonal_gaussian".
+         """
+         z = self.patchify(images)
+         for block in self.blocks:
+             z = block(z)
+         projection = self.to_bottleneck(z)
+         mean, logsnr = projection.chunk(2, dim=1)
+         mean = self.norm_out(mean)
+         return EncoderPosterior(mean=mean, logsnr=logsnr)
+
+     def forward(self, images: Tensor) -> Tensor:
+         """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].
+
+         Returns posterior mode (alpha * mean) for diagonal_gaussian,
+         or deterministic latents otherwise.
+         """
+         z = self.patchify(images)
+         for block in self.blocks:
+             z = block(z)
+         projection = self.to_bottleneck(z)
+         if self.bottleneck_posterior_kind == "diagonal_gaussian":
+             mean, logsnr = projection.chunk(2, dim=1)
+             mean = self.norm_out(mean)
+             alpha = torch.sigmoid(logsnr).sqrt().to(dtype=mean.dtype)
+             return alpha * mean
+         return self.norm_out(projection)
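The VP parameterization above guarantees alpha^2 + sigma^2 = 1 elementwise, because sigmoid(x) + sigmoid(-x) = 1 for any x, so the posterior sample alpha * mean + sigma * eps is variance preserving by construction. A scalar sanity check of that identity:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

logsnr = 1.7  # arbitrary per-element log-SNR value
alpha = math.sqrt(sigmoid(logsnr))   # VP signal coefficient
sigma = math.sqrt(sigmoid(-logsnr))  # VP noise coefficient

# sigmoid(x) + sigmoid(-x) == 1, so alpha^2 + sigma^2 == 1 for any logsnr.
assert abs(alpha**2 + sigma**2 - 1.0) < 1e-12
```

As logsnr grows large the posterior collapses toward the deterministic mode (alpha -> 1, sigma -> 0); as it goes very negative the latent is dominated by noise.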
fcdm_diffae/fcdm_block.py ADDED
@@ -0,0 +1,103 @@
+ """FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from .norms import ChannelWiseRMSNorm
+
+
+ class GRN(nn.Module):
+     """Global Response Normalization for NCHW tensors."""
+
+     def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.eps: float = float(eps)
+         c = int(channels)
+         self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
+         self.beta = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
+
+     def forward(self, x: Tensor) -> Tensor:
+         g = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
+         g_fp32 = g.to(dtype=torch.float32)
+         n = (g_fp32 / (g_fp32.mean(dim=1, keepdim=True) + self.eps)).to(dtype=x.dtype)
+         gamma = self.gamma.to(device=x.device, dtype=x.dtype)
+         beta = self.beta.to(device=x.device, dtype=x.dtype)
+         return gamma * (x * n) + beta + x
+
+
+ class FCDMBlock(nn.Module):
+     """ConvNeXt-style block with scale+gate AdaLN and GRN.
+
+     Two modes:
+     - Unconditioned (encoder): uses learned layer-scale for near-identity init.
+     - External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate).
+       The gate is applied raw (no tanh).
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         mlp_ratio: float,
+         *,
+         depthwise_kernel_size: int = 7,
+         use_external_adaln: bool = False,
+         norm_eps: float = 1e-6,
+         layer_scale_init: float = 1e-3,
+     ) -> None:
+         super().__init__()
+         self.channels: int = int(channels)
+         self.mlp_ratio: float = float(mlp_ratio)
+
+         self.dwconv = nn.Conv2d(
+             channels,
+             channels,
+             kernel_size=depthwise_kernel_size,
+             padding=depthwise_kernel_size // 2,
+             stride=1,
+             groups=channels,
+             bias=True,
+         )
+         self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
+         hidden = max(int(float(channels) * float(mlp_ratio)), 1)
+         self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
+         self.grn = GRN(hidden, eps=1e-6)
+         self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)
+
+         if not use_external_adaln:
+             self.layer_scale = nn.Parameter(
+                 torch.full((channels,), float(layer_scale_init))
+             )
+         else:
+             self.register_parameter("layer_scale", None)
+
+     def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
+         b, c, _, _ = x.shape
+
+         if adaln_m is not None:
+             m = adaln_m.to(device=x.device, dtype=x.dtype)
+             scale, gate = m.chunk(2, dim=-1)
+         else:
+             scale = gate = None
+
+         h = self.dwconv(x)
+         h = self.norm(h)
+
+         if scale is not None:
+             h = h * (1.0 + scale.view(b, c, 1, 1))
+
+         h = self.pwconv1(h)
+         h = F.gelu(h)
+         h = self.grn(h)
+         h = self.pwconv2(h)
+
+         if gate is not None:
+             gate_view = gate.view(b, c, 1, 1)
+         else:
+             gate_view = self.layer_scale.view(1, c, 1, 1).to(  # type: ignore[union-attr]
+                 device=h.device, dtype=h.dtype
+             )
+
+         return x + gate_view * h
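With `gamma` and `beta` zero-initialized, `GRN.forward` reduces exactly to the identity, so a fresh block's MLP path starts unmodulated by the global response statistics. A standalone numeric check of the same arithmetic (ignoring the fp32 casts, which don't matter for float32 inputs):

```python
import torch

# Mirror of GRN.forward with zero-initialized gamma/beta (as in GRN.__init__):
# out = gamma * (x * n) + beta + x, which is exactly x when gamma = beta = 0.
x = torch.randn(2, 4, 8, 8)
g = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)  # per-channel L2
n = g / (g.mean(dim=1, keepdim=True) + 1e-6)                      # normalized response
gamma = torch.zeros(1, 4, 1, 1)
beta = torch.zeros(1, 4, 1, 1)
out = gamma * (x * n) + beta + x
assert torch.equal(out, x)  # identity at init
```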
fcdm_diffae/model.py ADDED
@@ -0,0 +1,385 @@
+ """FCDMDiffAE: standalone HuggingFace-compatible diffusion autoencoder."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ import torch
+ from torch import Tensor, nn
+
+ from .config import FCDMDiffAEConfig, FCDMDiffAEInferenceConfig
+ from .decoder import Decoder
+ from .encoder import Encoder, EncoderPosterior
+ from .samplers import run_ddim, run_dpmpp_2m
+ from .vp_diffusion import get_schedule, make_initial_state, sample_noise
+
+
+ def _resolve_model_dir(
+     path_or_repo_id: str | Path,
+     *,
+     revision: str | None,
+     cache_dir: str | Path | None,
+ ) -> Path:
+     """Resolve a local path or HuggingFace Hub repo ID to a local directory."""
+     local = Path(path_or_repo_id)
+     if local.is_dir():
+         return local
+     repo_id = str(path_or_repo_id)
+     try:
+         from huggingface_hub import snapshot_download
+     except ImportError:
+         raise ImportError(
+             f"'{repo_id}' is not an existing local directory. "
+             "To download from HuggingFace Hub, install huggingface_hub: "
+             "pip install huggingface_hub"
+         )
+     cache_dir_str = str(cache_dir) if cache_dir is not None else None
+     local_dir = snapshot_download(
+         repo_id,
+         revision=revision,
+         cache_dir=cache_dir_str,
+     )
+     return Path(local_dir)
+
+
+ class FCDMDiffAE(nn.Module):
+     """Standalone FCDM DiffAE model for HuggingFace distribution.
+
+     A diffusion autoencoder built on FCDM (Fully Convolutional Diffusion Model)
+     blocks. Encodes images to compact 128-channel spatial latents via a
+     VP-parameterized diagonal Gaussian posterior, and decodes them back via
+     iterative VP diffusion with a skip-concat decoder.
+
+     Usage::
+
+         model = FCDMDiffAE.from_pretrained("path/to/weights")
+         model = model.to("cuda", dtype=torch.bfloat16)
+
+         # Encode (returns posterior mode by default)
+         latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]
+
+         # Decode (1 step by default, PSNR-optimal)
+         recon = model.decode(latents, height=H, width=W)
+
+         # Reconstruct (encode + 1-step decode)
+         recon = model.reconstruct(images)
+     """
+
+     _LATENT_NORM_EPS: float = 1e-4
+
+     def __init__(self, config: FCDMDiffAEConfig) -> None:
+         super().__init__()
+         self.config = config
+
+         # Latent running stats for whitening/dewhitening (at exported latent channels)
+         self.register_buffer(
+             "latent_norm_running_mean",
+             torch.zeros((config.latent_channels,), dtype=torch.float32),
+         )
+         self.register_buffer(
+             "latent_norm_running_var",
+             torch.ones((config.latent_channels,), dtype=torch.float32),
+         )
+
+         self.encoder = Encoder(
+             in_channels=config.in_channels,
+             patch_size=config.patch_size,
+             model_dim=config.model_dim,
+             depth=config.encoder_depth,
+             bottleneck_dim=config.bottleneck_dim,
+             mlp_ratio=config.mlp_ratio,
+             depthwise_kernel_size=config.depthwise_kernel_size,
+             bottleneck_posterior_kind=config.bottleneck_posterior_kind,
+             bottleneck_norm_mode=config.bottleneck_norm_mode,
+         )
+
+         self.decoder = Decoder(
+             in_channels=config.in_channels,
+             patch_size=config.patch_size,
+             model_dim=config.model_dim,
+             depth=config.decoder_depth,
+             start_block_count=config.decoder_start_blocks,
+             end_block_count=config.decoder_end_blocks,
+             bottleneck_dim=config.bottleneck_dim,
+             mlp_ratio=config.mlp_ratio,
+             depthwise_kernel_size=config.depthwise_kernel_size,
+             adaln_low_rank_rank=config.adaln_low_rank_rank,
+         )
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         path_or_repo_id: str | Path,
+         *,
+         dtype: torch.dtype = torch.bfloat16,
+         device: str | torch.device = "cpu",
+         revision: str | None = None,
+         cache_dir: str | Path | None = None,
+     ) -> FCDMDiffAE:
+         """Load a pretrained model from a local directory or HuggingFace Hub.
+
+         The directory (or repo) should contain:
+         - config.json: Model architecture config.
+         - model.safetensors (preferred) or model.pt: Model weights.
+
+         Args:
+             path_or_repo_id: Local directory path or HuggingFace Hub repo ID.
+             dtype: Load weights in this dtype (float32 or bfloat16).
+             device: Target device.
+             revision: Git revision for Hub downloads.
+             cache_dir: Where to cache Hub downloads.
+
+         Returns:
+             Loaded model in eval mode.
+         """
+         model_dir = _resolve_model_dir(
+             path_or_repo_id, revision=revision, cache_dir=cache_dir
+         )
+         config = FCDMDiffAEConfig.load(model_dir / "config.json")
+         model = cls(config)
+
+         safetensors_path = model_dir / "model.safetensors"
+         pt_path = model_dir / "model.pt"
+
+         if safetensors_path.exists():
+             try:
+                 from safetensors.torch import load_file
+
+                 state_dict = load_file(str(safetensors_path), device=str(device))
+             except ImportError:
+                 raise ImportError(
+                     "safetensors package required to load .safetensors files. "
+                     "Install with: pip install safetensors"
+                 )
+         elif pt_path.exists():
+             state_dict = torch.load(
+                 str(pt_path), map_location=device, weights_only=True
+             )
+         else:
+             raise FileNotFoundError(
+                 f"No model weights found in {model_dir}. "
+                 "Expected model.safetensors or model.pt."
+             )
+
+         model.load_state_dict(state_dict)
+         model = model.to(dtype=dtype, device=torch.device(device))
+         model.eval()
+         return model
+
+     def _latent_norm_stats(self) -> tuple[Tensor, Tensor]:
+         """Return (mean, std) tensors for latent whitening, shaped [1,C,1,1]."""
+         mean = self.latent_norm_running_mean.view(1, -1, 1, 1)
+         var = self.latent_norm_running_var.view(1, -1, 1, 1)
+         std = torch.sqrt(var.to(torch.float32) + self._LATENT_NORM_EPS)
+         return mean.to(torch.float32), std
+
+     def whiten(self, latents: Tensor) -> Tensor:
+         """Whiten encoder latents using per-channel running stats.
+
+         Use this before passing latents to a downstream latent-space
+         diffusion model. The whitened latents have approximately zero mean
+         and unit variance per channel.
+
+         Args:
+             latents: [B, latent_channels, h, w] raw encoder output.
+
+         Returns:
+             Whitened latents [B, latent_channels, h, w] in float32.
+         """
+         z = latents.to(torch.float32)
+         mean, std = self._latent_norm_stats()
+         return (z - mean.to(device=z.device)) / std.to(device=z.device)
+
+     def dewhiten(self, latents: Tensor) -> Tensor:
+         """Undo whitening to recover raw encoder latent scale.
+
+         Use this before passing whitened latents back to ``decode()``.
+
+         Args:
+             latents: [B, latent_channels, h, w] whitened latents.
+
+         Returns:
+             Dewhitened latents [B, latent_channels, h, w] in float32.
+         """
+         z = latents.to(torch.float32)
+         mean, std = self._latent_norm_stats()
+         return z * std.to(device=z.device) + mean.to(device=z.device)
+
+     def _patchify(self, z: Tensor) -> Tensor:
+         """2x2 patchify: [B, C, H, W] -> [B, 4C, H/2, W/2]."""
+         b, c, h, w = z.shape
+         z = z.reshape(b, c, h // 2, 2, w // 2, 2)
+         z = z.permute(0, 1, 3, 5, 2, 4)
+         return z.reshape(b, c * 4, h // 2, w // 2)
+
+     def _unpatchify(self, z: Tensor) -> Tensor:
+         """2x2 unpatchify: [B, 4C, H/2, W/2] -> [B, C, H, W]."""
+         b, c, h, w = z.shape
+         z = z.reshape(b, c // 4, 2, 2, h, w)
+         z = z.permute(0, 1, 4, 2, 5, 3)
+         return z.reshape(b, c // 4, h * 2, w * 2)
+
+     def encode(self, images: Tensor) -> Tensor:
+         """Encode images to whitened latents (posterior mode).
+
+         Returns latents whitened using per-channel running stats, ready for
+         use by downstream latent-space diffusion models.
+
+         Args:
+             images: [B, 3, H, W] in [-1, 1], H and W divisible by
+                 effective_patch_size.
+
+         Returns:
+             Whitened latents [B, latent_channels, H/effective_patch, W/effective_patch].
+         """
+         try:
+             model_dtype = next(self.parameters()).dtype
+         except StopIteration:
+             model_dtype = torch.float32
+         z = self.encoder(images.to(dtype=model_dtype))
+         if self.config.bottleneck_patchify_mode == "patch_2x2":
+             z = self._patchify(z)
+         return self.whiten(z).to(dtype=model_dtype)
+
+     def encode_posterior(self, images: Tensor) -> EncoderPosterior:
+         """Encode images and return the full posterior (mean + logsnr).
+
+         Args:
+             images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
+
+         Returns:
+             EncoderPosterior with mean and logsnr tensors.
+         """
+         try:
+             model_dtype = next(self.parameters()).dtype
+         except StopIteration:
+             model_dtype = torch.float32
257
+ return self.encoder.encode_posterior(images.to(dtype=model_dtype))
258
+
259
+ @torch.no_grad()
260
+ def decode(
261
+ self,
262
+ latents: Tensor,
263
+ height: int,
264
+ width: int,
265
+ *,
266
+ inference_config: FCDMDiffAEInferenceConfig | None = None,
267
+ ) -> Tensor:
268
+ """Decode whitened latents to images via VP diffusion.
269
+
270
+ Latents are dewhitened and (if applicable) unpatchified internally
271
+ before being passed to the decoder.
272
+
273
+ Args:
274
+ latents: [B, latent_channels, h, w] whitened encoder latents.
275
+ height: Output image height (divisible by effective_patch_size).
276
+ width: Output image width (divisible by effective_patch_size).
277
+ inference_config: Optional inference parameters.
278
+
279
+ Returns:
280
+ Reconstructed images [B, 3, H, W] in float32.
281
+ """
282
+ cfg = inference_config or FCDMDiffAEInferenceConfig()
283
+ config = self.config
284
+ batch = int(latents.shape[0])
285
+ device = latents.device
286
+
287
+ try:
288
+ model_dtype = next(self.parameters()).dtype
289
+ except StopIteration:
290
+ model_dtype = torch.float32
291
+
292
+ # Dewhiten and unpatchify back to raw encoder scale for the decoder
293
+ latents = self.dewhiten(latents)
294
+ if config.bottleneck_patchify_mode == "patch_2x2":
295
+ latents = self._unpatchify(latents)
296
+ latents = latents.to(dtype=model_dtype)
297
+
298
+ if height % config.patch_size != 0 or width % config.patch_size != 0:
299
+ raise ValueError(
300
+ f"height={height} and width={width} must be divisible by "
301
+ f"patch_size={config.patch_size}"
302
+ )
303
+
304
+ shape = (batch, config.in_channels, height, width)
305
+ noise = sample_noise(
306
+ shape,
307
+ noise_std=config.pixel_noise_std,
308
+ seed=cfg.seed,
309
+ device=torch.device("cpu"),
310
+ dtype=torch.float32,
311
+ )
312
+
313
+ schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)
314
+ initial_state = make_initial_state(
315
+ noise=noise.to(device=device),
316
+ t_start=schedule[0:1],
317
+ logsnr_min=config.logsnr_min,
318
+ logsnr_max=config.logsnr_max,
319
+ )
320
+
321
+ device_type = "cuda" if device.type == "cuda" else "cpu"
322
+ with torch.autocast(device_type=device_type, enabled=False):
323
+ latents_in = latents.to(device=device)
324
+
325
+ def _forward_fn(
326
+ x_t: Tensor,
327
+ t: Tensor,
328
+ latents: Tensor,
329
+ *,
330
+ drop_middle_blocks: bool = False,
331
+ mask_latent_tokens: bool = False,
332
+ ) -> Tensor:
333
+ return self.decoder(
334
+ x_t.to(dtype=model_dtype),
335
+ t,
336
+ latents.to(dtype=model_dtype),
337
+ drop_middle_blocks=drop_middle_blocks,
338
+ )
339
+
340
+ pdg_mode = "path_drop" if cfg.pdg else "disabled"
341
+
342
+ if cfg.sampler == "ddim":
343
+ sampler_fn = run_ddim
344
+ elif cfg.sampler == "dpmpp_2m":
345
+ sampler_fn = run_dpmpp_2m
346
+ else:
347
+ raise ValueError(
348
+ f"Unsupported sampler: {cfg.sampler!r}. Use 'ddim' or 'dpmpp_2m'."
349
+ )
350
+
351
+ result = sampler_fn(
352
+ forward_fn=_forward_fn,
353
+ initial_state=initial_state,
354
+ schedule=schedule,
355
+ latents=latents_in,
356
+ logsnr_min=config.logsnr_min,
357
+ logsnr_max=config.logsnr_max,
358
+ pdg_mode=pdg_mode,
359
+ pdg_strength=cfg.pdg_strength,
360
+ device=device,
361
+ )
362
+
363
+ return result
364
+
365
+ @torch.no_grad()
366
+ def reconstruct(
367
+ self,
368
+ images: Tensor,
369
+ *,
370
+ inference_config: FCDMDiffAEInferenceConfig | None = None,
371
+ ) -> Tensor:
372
+ """Encode then decode. Convenience wrapper.
373
+
374
+ Args:
375
+ images: [B, 3, H, W] in [-1, 1].
376
+ inference_config: Optional inference parameters.
377
+
378
+ Returns:
379
+ Reconstructed images [B, 3, H, W] in float32.
380
+ """
381
+ latents = self.encode(images)
382
+ _, _, h, w = images.shape
383
+ return self.decode(
384
+ latents, height=h, width=w, inference_config=inference_config
385
+ )
fcdm_diffae/norms.py ADDED
@@ -0,0 +1,39 @@
+"""Channel-wise RMSNorm for NCHW tensors."""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor, nn
+
+
+class ChannelWiseRMSNorm(nn.Module):
+    """Channel-wise RMSNorm with float32 reduction for numerical stability.
+
+    Normalizes across channels per spatial position. Supports optional
+    per-channel affine weight and bias.
+    """
+
+    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
+        super().__init__()
+        self.channels: int = int(channels)
+        self._eps: float = float(eps)
+        if affine:
+            self.weight = nn.Parameter(torch.ones(self.channels))
+            self.bias = nn.Parameter(torch.zeros(self.channels))
+        else:
+            self.register_parameter("weight", None)
+            self.register_parameter("bias", None)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if x.dim() < 2:
+            return x
+        # Float32 accumulation for stability
+        ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
+        inv_rms = torch.rsqrt(ms + self._eps)
+        y = x * inv_rms
+        # Compute the broadcast shape once so both affine branches can use it
+        shape = (1, -1) + (1,) * (x.dim() - 2)
+        if self.weight is not None:
+            y = y * self.weight.view(shape).to(dtype=y.dtype)
+        if self.bias is not None:
+            y = y + self.bias.view(shape).to(dtype=y.dtype)
+        return y.to(dtype=x.dtype)
fcdm_diffae/samplers.py ADDED
@@ -0,0 +1,255 @@
+"""DDIM and DPM++2M samplers for VP diffusion with path-drop PDG support."""
+
+from __future__ import annotations
+
+from typing import Protocol
+
+import torch
+from torch import Tensor
+
+from .vp_diffusion import (
+    alpha_sigma_from_logsnr,
+    broadcast_time_like,
+    shifted_cosine_interpolated_logsnr_from_t,
+)
+
+
+class DecoderForwardFn(Protocol):
+    """Callable that predicts x0 from (x_t, t, latents) with path-drop PDG flag."""
+
+    def __call__(
+        self,
+        x_t: Tensor,
+        t: Tensor,
+        latents: Tensor,
+        *,
+        drop_middle_blocks: bool = False,
+        mask_latent_tokens: bool = False,
+    ) -> Tensor: ...
+
+
+def _reconstruct_eps_from_x0(
+    *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
+) -> Tensor:
+    """Reconstruct eps_hat from (x_t, x0_hat) under VP parameterization.
+
+    eps_hat = (x_t - alpha * x0_hat) / sigma. All float32.
+    """
+    alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
+    sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
+    x_t_f32 = x_t.to(torch.float32)
+    x0_f32 = x0_hat.to(torch.float32)
+    return (x_t_f32 - alpha_view * x0_f32) / sigma_view
+
+
+def _ddim_step(
+    *,
+    x0_hat: Tensor,
+    eps_hat: Tensor,
+    alpha_next: Tensor,
+    sigma_next: Tensor,
+    ref: Tensor,
+) -> Tensor:
+    """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
+    a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
+    s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
+    return a * x0_hat + s * eps_hat
+
+
+def _predict_with_pdg(
+    forward_fn: DecoderForwardFn,
+    state: Tensor,
+    t_vec: Tensor,
+    latents: Tensor,
+    *,
+    pdg_mode: str,
+    pdg_strength: float,
+) -> Tensor:
+    """Run decoder forward with optional PDG guidance.
+
+    Args:
+        forward_fn: Decoder forward function.
+        state: Current noised state [B, C, H, W].
+        t_vec: Timestep vector [B].
+        latents: Encoder latents.
+        pdg_mode: "disabled" or "path_drop".
+        pdg_strength: CFG-like strength for PDG.
+
+    Returns:
+        x0_hat prediction in float32.
+    """
+    if pdg_mode == "path_drop":
+        x0_uncond = forward_fn(state, t_vec, latents, drop_middle_blocks=True).to(
+            torch.float32
+        )
+        x0_cond = forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
+            torch.float32
+        )
+        return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
+    else:
+        return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
+            torch.float32
+        )
+
+
+def run_ddim(
+    *,
+    forward_fn: DecoderForwardFn,
+    initial_state: Tensor,
+    schedule: Tensor,
+    latents: Tensor,
+    logsnr_min: float,
+    logsnr_max: float,
+    log_change_high: float = 0.0,
+    log_change_low: float = 0.0,
+    pdg_mode: str = "disabled",
+    pdg_strength: float = 1.5,
+    device: torch.device | None = None,
+) -> Tensor:
+    """Run DDIM sampling loop with path-drop PDG support.
+
+    Args:
+        forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
+        initial_state: Starting noised state [B, C, H, W] in float32.
+        schedule: Descending t-schedule [num_steps + 1] in [0, 1].
+        latents: Encoder latents [B, bottleneck_dim, h, w].
+        logsnr_min, logsnr_max: VP schedule endpoints.
+        log_change_high, log_change_low: Shifted-cosine schedule parameters.
+        pdg_mode: "disabled" or "path_drop".
+        pdg_strength: CFG-like strength for PDG.
+        device: Target device.
+
+    Returns:
+        Denoised samples [B, C, H, W] in float32.
+    """
+    run_device = device or initial_state.device
+    batch_size = int(initial_state.shape[0])
+    state = initial_state.to(device=run_device, dtype=torch.float32)
+
+    # Precompute logSNR, alpha, sigma for all schedule points
+    lmb = shifted_cosine_interpolated_logsnr_from_t(
+        schedule.to(device=run_device),
+        logsnr_min=logsnr_min,
+        logsnr_max=logsnr_max,
+        log_change_high=log_change_high,
+        log_change_low=log_change_low,
+    )
+    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
+
+    for i in range(int(schedule.numel()) - 1):
+        t_i = schedule[i]
+        a_t = alpha_sched[i].expand(batch_size)
+        s_t = sigma_sched[i].expand(batch_size)
+        a_next = alpha_sched[i + 1].expand(batch_size)
+        s_next = sigma_sched[i + 1].expand(batch_size)
+
+        # Model prediction with optional PDG
+        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
+        x0_hat = _predict_with_pdg(
+            forward_fn,
+            state,
+            t_vec,
+            latents,
+            pdg_mode=pdg_mode,
+            pdg_strength=pdg_strength,
+        )
+
+        eps_hat = _reconstruct_eps_from_x0(
+            x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t
+        )
+        state = _ddim_step(
+            x0_hat=x0_hat,
+            eps_hat=eps_hat,
+            alpha_next=a_next,
+            sigma_next=s_next,
+            ref=state,
+        )
+
+    return state
+
+
+def run_dpmpp_2m(
+    *,
+    forward_fn: DecoderForwardFn,
+    initial_state: Tensor,
+    schedule: Tensor,
+    latents: Tensor,
+    logsnr_min: float,
+    logsnr_max: float,
+    log_change_high: float = 0.0,
+    log_change_low: float = 0.0,
+    pdg_mode: str = "disabled",
+    pdg_strength: float = 1.5,
+    device: torch.device | None = None,
+) -> Tensor:
+    """Run DPM++2M sampling loop with path-drop PDG support.
+
+    Multi-step solver using exponential integrator formulation in half-lambda space.
+    """
+    run_device = device or initial_state.device
+    batch_size = int(initial_state.shape[0])
+    state = initial_state.to(device=run_device, dtype=torch.float32)
+
+    # Precompute logSNR, alpha, sigma, half-lambda for all schedule points
+    lmb = shifted_cosine_interpolated_logsnr_from_t(
+        schedule.to(device=run_device),
+        logsnr_min=logsnr_min,
+        logsnr_max=logsnr_max,
+        log_change_high=log_change_high,
+        log_change_low=log_change_low,
+    )
+    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
+    half_lambda = 0.5 * lmb.to(torch.float32)
+
+    x0_prev: Tensor | None = None
+
+    for i in range(int(schedule.numel()) - 1):
+        t_i = schedule[i]
+        s_t = sigma_sched[i].expand(batch_size)
+        a_next = alpha_sched[i + 1].expand(batch_size)
+        s_next = sigma_sched[i + 1].expand(batch_size)
+
+        # Model prediction with optional PDG
+        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
+        x0_hat = _predict_with_pdg(
+            forward_fn,
+            state,
+            t_vec,
+            latents,
+            pdg_mode=pdg_mode,
+            pdg_strength=pdg_strength,
+        )
+
+        lam_t = half_lambda[i].expand(batch_size)
+        lam_next = half_lambda[i + 1].expand(batch_size)
+        h = (lam_next - lam_t).to(torch.float32)
+        phi_1 = torch.expm1(-h)
+
+        sigma_ratio = (s_next / s_t).to(torch.float32)
+
+        if i == 0 or x0_prev is None:
+            # First-order step
+            state = (
+                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
+                - broadcast_time_like(a_next, state).to(torch.float32)
+                * broadcast_time_like(phi_1, state).to(torch.float32)
+                * x0_hat
+            )
+        else:
+            # Second-order step
+            lam_prev = half_lambda[i - 1].expand(batch_size)
+            h_0 = (lam_t - lam_prev).to(torch.float32)
+            r0 = h_0 / h
+            d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
+            common = broadcast_time_like(a_next, state).to(
+                torch.float32
+            ) * broadcast_time_like(phi_1, state).to(torch.float32)
+            state = (
+                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
+                - common * x0_hat
+                - 0.5 * common * d1_0
+            )
+
+        x0_prev = x0_hat
+
+    return state
fcdm_diffae/straight_through_encoder.py ADDED
@@ -0,0 +1,27 @@
+"""PixelUnshuffle-based patchifier (no residual conv path)."""
+
+from __future__ import annotations
+
+from torch import Tensor, nn
+
+
+class Patchify(nn.Module):
+    """PixelUnshuffle(patch) -> Conv2d 1x1 projection.
+
+    Converts [B, C, H, W] images into [B, out_channels, H/patch, W/patch] features.
+    """
+
+    def __init__(self, in_channels: int, patch: int, out_channels: int) -> None:
+        super().__init__()
+        self.patch = int(patch)
+        self.unshuffle = nn.PixelUnshuffle(self.patch)
+        in_after = in_channels * (self.patch * self.patch)
+        self.proj = nn.Conv2d(in_after, out_channels, kernel_size=1, bias=True)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if x.shape[2] % self.patch != 0 or x.shape[3] % self.patch != 0:
+            raise ValueError(
+                f"Input H={x.shape[2]} and W={x.shape[3]} must be divisible by patch={self.patch}"
+            )
+        y = self.unshuffle(x)
+        return self.proj(y)
fcdm_diffae/time_embed.py ADDED
@@ -0,0 +1,83 @@
+"""Sinusoidal timestep embedding with MLP projection."""
+
+from __future__ import annotations
+
+import math
+
+import torch
+from torch import Tensor, nn
+
+
+def _log_spaced_frequencies(
+    half: int, max_period: float, *, device: torch.device | None = None
+) -> Tensor:
+    """Log-spaced frequencies for sinusoidal embedding."""
+    return torch.exp(
+        -math.log(max_period)
+        * torch.arange(half, device=device, dtype=torch.float32)
+        / max(float(half - 1), 1.0)
+    )
+
+
+def sinusoidal_time_embedding(
+    t: Tensor,
+    dim: int,
+    *,
+    max_period: float = 10000.0,
+    scale: float | None = None,
+    freqs: Tensor | None = None,
+) -> Tensor:
+    """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32."""
+    t32 = t.to(torch.float32)
+    if scale is not None:
+        t32 = t32 * float(scale)
+    half = dim // 2
+    if freqs is not None:
+        freqs = freqs.to(device=t32.device, dtype=torch.float32)
+    else:
+        freqs = _log_spaced_frequencies(half, max_period, device=t32.device)
+    angles = t32[:, None] * freqs[None, :]
+    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
+
+
+class SinusoidalTimeEmbeddingMLP(nn.Module):
+    """Sinusoidal time embedding followed by Linear -> SiLU -> Linear."""
+
+    def __init__(
+        self,
+        dim: int,
+        *,
+        freq_dim: int = 256,
+        hidden_mult: float = 1.0,
+        time_scale: float = 1000.0,
+        max_period: float = 10000.0,
+    ) -> None:
+        super().__init__()
+        self.dim = int(dim)
+        self.freq_dim = int(freq_dim)
+        self.time_scale = float(time_scale)
+        self.max_period = float(max_period)
+        hidden_dim = max(int(round(int(dim) * float(hidden_mult))), 1)
+
+        freqs = _log_spaced_frequencies(self.freq_dim // 2, self.max_period)
+        self.register_buffer("freqs", freqs, persistent=True)
+
+        self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
+        self.act = nn.SiLU()
+        self.proj_out = nn.Linear(hidden_dim, self.dim)
+
+    def forward(self, t: Tensor) -> Tensor:
+        freqs: Tensor = self.freqs  # type: ignore[assignment]
+        emb_freq = sinusoidal_time_embedding(
+            t.to(torch.float32),
+            self.freq_dim,
+            max_period=self.max_period,
+            scale=self.time_scale,
+            freqs=freqs,
+        )
+        dtype_in = self.proj_in.weight.dtype
+        hidden = self.proj_in(emb_freq.to(dtype_in))
+        hidden = self.act(hidden)
+        if hidden.dtype != self.proj_out.weight.dtype:
+            hidden = hidden.to(self.proj_out.weight.dtype)
+        return self.proj_out(hidden)
fcdm_diffae/vp_diffusion.py ADDED
@@ -0,0 +1,151 @@
+"""VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""
+
+from __future__ import annotations
+
+import math
+
+import torch
+from torch import Tensor
+
+
+def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
+    """Compute (alpha, sigma) from logSNR in float32.
+
+    VP constraint: alpha^2 + sigma^2 = 1.
+    """
+    lmb32 = lmb.to(dtype=torch.float32)
+    alpha = torch.sqrt(torch.sigmoid(lmb32))
+    sigma = torch.sqrt(torch.sigmoid(-lmb32))
+    return alpha, sigma
+
+
+def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
+    """Broadcast [B] coefficient to match x for per-sample scaling."""
+    view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
+    return coeff.view(view_shape)
+
+
+def _cosine_interpolated_params(
+    logsnr_min: float, logsnr_max: float
+) -> tuple[float, float]:
+    """Compute (a, b) for cosine-interpolated logSNR schedule.
+
+    logsnr(t) = -2 * log(tan(a*t + b))
+    logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
+    """
+    b = math.atan(math.exp(-0.5 * logsnr_max))
+    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
+    return a, b
+
+
+def cosine_interpolated_logsnr_from_t(
+    t: Tensor, *, logsnr_min: float, logsnr_max: float
+) -> Tensor:
+    """Map t in [0,1] to logSNR via cosine-interpolated schedule. Always float32."""
+    a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
+    t32 = t.to(dtype=torch.float32)
+    a_t = torch.tensor(a, device=t32.device, dtype=torch.float32)
+    b_t = torch.tensor(b, device=t32.device, dtype=torch.float32)
+    u = a_t * t32 + b_t
+    return -2.0 * torch.log(torch.tan(u))
+
+
+def shifted_cosine_interpolated_logsnr_from_t(
+    t: Tensor,
+    *,
+    logsnr_min: float,
+    logsnr_max: float,
+    log_change_high: float = 0.0,
+    log_change_low: float = 0.0,
+) -> Tensor:
+    """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.
+
+    lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
+    """
+    base = cosine_interpolated_logsnr_from_t(
+        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
+    )
+    t32 = t.to(dtype=torch.float32)
+    high = base + float(log_change_high)
+    low = base + float(log_change_low)
+    return (1.0 - t32) * high + t32 * low
+
+
+def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
+    """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.
+
+    ``num_steps`` is the number of function evaluations (NFE = decoder forward
+    passes). Internally the schedule has ``num_steps + 1`` time points
+    (including both endpoints).
+
+    Args:
+        schedule_type: "linear" or "cosine".
+        num_steps: Number of decoder forward passes (NFE), >= 1.
+
+    Returns:
+        Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
+    """
+    # NOTE: the upstream training code (src/ode/time_schedules.py) uses a
+    # different convention where num_steps counts schedule *points* (so NFE =
+    # num_steps - 1). This export package corrects the off-by-one so that
+    # num_steps means NFE directly. TODO: align the upstream convention.
+    n = max(int(num_steps) + 1, 2)
+    if schedule_type == "linear":
+        base = torch.linspace(0.0, 1.0, n)
+    elif schedule_type == "cosine":
+        i = torch.arange(n, dtype=torch.float32)
+        base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
+    else:
+        raise ValueError(
+            f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
+        )
+    # Descending: high t (noisy) -> low t (clean)
+    return torch.flip(base, dims=[0])
+
+
+def make_initial_state(
+    *,
+    noise: Tensor,
+    t_start: Tensor,
+    logsnr_min: float,
+    logsnr_max: float,
+    log_change_high: float = 0.0,
+    log_change_low: float = 0.0,
+) -> Tensor:
+    """Construct VP initial state x_t0 = sigma_start * noise (since x0=0).
+
+    All math in float32.
+    """
+    batch = int(noise.shape[0])
+    lmb_start = shifted_cosine_interpolated_logsnr_from_t(
+        t_start.expand(batch).to(dtype=torch.float32),
+        logsnr_min=logsnr_min,
+        logsnr_max=logsnr_max,
+        log_change_high=log_change_high,
+        log_change_low=log_change_low,
+    )
+    _alpha_start, sigma_start = alpha_sigma_from_logsnr(lmb_start)
+    sigma_view = broadcast_time_like(sigma_start, noise)
+    return sigma_view * noise.to(dtype=torch.float32)
+
+
+def sample_noise(
+    shape: tuple[int, ...],
+    *,
+    noise_std: float = 1.0,
+    seed: int | None = None,
+    device: torch.device | None = None,
+    dtype: torch.dtype = torch.float32,
+) -> Tensor:
+    """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
+    if seed is None:
+        noise = torch.randn(
+            shape, device=device or torch.device("cpu"), dtype=torch.float32
+        )
+    else:
+        gen = torch.Generator(device="cpu")
+        gen.manual_seed(int(seed))
+        noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
+    noise = noise.mul(float(noise_std))
+    target_device = device if device is not None else torch.device("cpu")
+    return noise.to(device=target_device, dtype=dtype)
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:955027ab9ff5fe48382ee163e9025d8d0185a988c3b4e43e0b5eea3c16bc4e11
+size 355104632
technical_report_p32.md ADDED
@@ -0,0 +1,61 @@
+# SemDisDiffAE p32: Technical Report
+
+**Experimental patch-32 variant** of
+[SemDisDiffAE](https://huggingface.co/data-archetype/semdisdiffae).
+
+See the [base model technical report](https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md)
+for full architectural and training details. This document covers only the
+differences.
+
+---
+
+## Bottleneck Patchification
+
+The base SemDisDiffAE produces 128-channel latents at H/16 x W/16. This
+variant adds a 2x2 patchification step after the encoder bottleneck:
+
+$$z_\text{p32} = \text{Patchify}_{2 \times 2}(z_\text{p16})$$
+
+The patchification reshapes each 2x2 spatial block of 128 channels into a
+single spatial position with 512 channels:
+
+```
+z_p16 ∈ ℝ^{B × 128 × H/16 × W/16} → z_p32 ∈ ℝ^{B × 512 × H/32 × W/32}
+```
+
+This is a lossless reshape; no learned parameters are involved. The
+inverse operation (unpatchify) is applied before decoding.
+
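The 2x2 block-to-channel reshape above can be sketched in a few lines of torch. This is a standalone sketch; the function names are illustrative, not the package's public API:

```python
import torch

def patchify_2x2(z: torch.Tensor) -> torch.Tensor:
    # [B, C, H, W] -> [B, 4C, H/2, W/2]: fold each 2x2 spatial block into channels.
    b, c, h, w = z.shape
    z = z.reshape(b, c, h // 2, 2, w // 2, 2).permute(0, 1, 3, 5, 2, 4)
    return z.reshape(b, c * 4, h // 2, w // 2)

def unpatchify_2x2(z: torch.Tensor) -> torch.Tensor:
    # [B, 4C, H/2, W/2] -> [B, C, H, W]: exact inverse of patchify_2x2.
    b, c, h, w = z.shape
    z = z.reshape(b, c // 4, 2, 2, h, w).permute(0, 1, 4, 2, 5, 3)
    return z.reshape(b, c // 4, h * 2, w * 2)

z_p16 = torch.randn(2, 128, 16, 16)  # latents for a 256x256 input image
z_p32 = patchify_2x2(z_p16)
assert z_p32.shape == (2, 512, 8, 8)
assert torch.equal(unpatchify_2x2(z_p32), z_p16)  # lossless round trip
```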
+## Where Patchification Sits in the Pipeline
+
+The encoder and decoder architectures are **identical** to the base model.
+The patchification is applied as an outer wrapper:
+
+- **Encode**: encoder blocks (128ch at H/16) → posterior mode → **patchify** (512ch at H/32) → whiten
+- **Decode**: dewhiten → **unpatchify** (128ch at H/16) → decoder blocks → image
+
+Running stats (mean/variance for whitening) are tracked in the patchified
+512-channel space.
+
+Semantic alignment and the VP posterior variance expansion loss operate
+**before** patchification, in the 128-channel space, matching the base
+model's behavior.
+
+## Training
+
+Same losses, weights, and hyperparameters as the base SemDisDiffAE.
+Trained for **275k steps**.
+
+## Model Configuration
+
+| Parameter | Value |
+|-----------|-------|
+| Encoder patch | 16 |
+| Bottleneck dim | 128 |
+| Bottleneck patchify | 2x2 |
+| Exported latent channels | 512 |
+| Effective spatial stride | 32 |
+| Parameters | 88.8M |
+
+All other parameters identical to the
+[base model configuration](https://huggingface.co/data-archetype/semdisdiffae/blob/main/technical_report_semantic.md#7-model-configuration).