data-archetype committed on
Commit
433bab6
·
verified ·
1 Parent(s): 8f5592f

Upload folder using huggingface_hub

Browse files
fcdm_diffae/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FCDMDiffAE: Standalone diffusion autoencoder with FCDM blocks.
2
+
3
+ FCDM DiffAE — a fast diffusion autoencoder with a 128-channel spatial
4
+ bottleneck and a VP-parameterized diagonal Gaussian posterior. Built on FCDM
5
+ (Fully Convolutional Diffusion Model) blocks with GRN and scale+gate AdaLN.
6
+
7
+ Usage::
8
+
9
+ from fcdm_diffae import FCDMDiffAE, FCDMDiffAEInferenceConfig
10
+
11
+ model = FCDMDiffAE.from_pretrained("path/to/weights", device="cuda")
12
+
13
+ # Encode (returns posterior mode by default)
14
+ latents = model.encode(images) # images: [B,3,H,W] in [-1,1]
15
+
16
+ # Decode — PSNR-optimal (1 step, default)
17
+ recon = model.decode(latents, height=H, width=W)
18
+
19
+ # Decode — perceptual sharpness (10 steps + path-drop PDG)
20
+ cfg = FCDMDiffAEInferenceConfig(num_steps=10, pdg=True, pdg_strength=2.0)
21
+ recon = model.decode(latents, height=H, width=W, inference_config=cfg)
22
+ """
23
+
24
+ from .config import FCDMDiffAEConfig, FCDMDiffAEInferenceConfig
25
+ from .encoder import EncoderPosterior
26
+ from .model import FCDMDiffAE
27
+
28
+ __all__ = [
29
+ "FCDMDiffAE",
30
+ "FCDMDiffAEConfig",
31
+ "FCDMDiffAEInferenceConfig",
32
+ "EncoderPosterior",
33
+ ]
fcdm_diffae/adaln.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scale+Gate AdaLN (2-way) for FCDM decoder blocks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from torch import Tensor, nn
6
+
7
+
8
class AdaLNScaleGateZeroProjector(nn.Module):
    """Shared 2-way AdaLN modulation head: SiLU followed by a zero-init Linear.

    Given a conditioning vector of width ``d_cond``, produces a packed
    ``[B, 2 * d_model]`` tensor holding (scale, gate). Zero initialization
    makes the modulation exactly zero at the start of training.
    """

    def __init__(self, d_model: int, d_cond: int) -> None:
        super().__init__()
        model_width = int(d_model)
        cond_width = int(d_cond)
        self.d_model: int = model_width
        self.d_cond: int = cond_width
        self.act: nn.SiLU = nn.SiLU()
        self.proj: nn.Linear = nn.Linear(cond_width, 2 * model_width)
        # Identity-like start: both scale and gate come out as zeros.
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward_activated(self, act_cond: Tensor) -> Tensor:
        """Project a conditioning vector that is already SiLU-activated."""
        return self.proj(act_cond)

    def forward(self, cond: Tensor) -> Tensor:
        """Activate then project; returns packed modulation [B, 2*d_model]."""
        activated = self.act(cond)
        return self.forward_activated(activated)
30
+
31
+
32
class AdaLNScaleGateZeroLowRankDelta(nn.Module):
    """Per-block low-rank correction to the shared 2-way AdaLN projection.

    Factorized as down(d_cond -> rank) then up(rank -> 2*d_model), both
    bias-free. The up projection is zero-initialized, so the delta adds
    nothing at init and preserves the base projector's zero-output state.
    """

    def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
        super().__init__()
        self.d_model: int = int(d_model)
        self.d_cond: int = int(d_cond)
        self.rank: int = int(rank)
        self.down: nn.Linear = nn.Linear(self.d_cond, self.rank, bias=False)
        self.up: nn.Linear = nn.Linear(self.rank, 2 * self.d_model, bias=False)
        # Small random down-projection; zeroed up-projection keeps the
        # overall delta at exactly zero until training moves it.
        nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, act_cond: Tensor) -> Tensor:
        """Return the packed delta modulation [B, 2*d_model]."""
        reduced = self.down(act_cond)
        return self.up(reduced)
fcdm_diffae/config.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Frozen model architecture and user-tunable inference configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from dataclasses import asdict, dataclass
7
+ from pathlib import Path
8
+
9
+
10
@dataclass(frozen=True)
class FCDMDiffAEConfig:
    """Immutable model architecture config, stored alongside weights as config.json."""

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    encoder_depth: int = 4
    decoder_depth: int = 8
    decoder_start_blocks: int = 2
    decoder_end_blocks: int = 2
    bottleneck_dim: int = 128
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128
    # Encoder posterior kind: "diagonal_gaussian" or "deterministic"
    bottleneck_posterior_kind: str = "diagonal_gaussian"
    # Post-bottleneck normalization: "channel_wise" or "disabled"
    bottleneck_norm_mode: str = "disabled"
    # VP diffusion schedule endpoints
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    # Pixel-space noise std for VP diffusion initialization
    pixel_noise_std: float = 0.558

    def save(self, path: str | Path) -> None:
        """Write this config to *path* as pretty-printed JSON."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(asdict(self), indent=2)
        target.write_text(payload + "\n")

    @classmethod
    def load(cls, path: str | Path) -> FCDMDiffAEConfig:
        """Read a config previously written by :meth:`save`."""
        raw = Path(path).read_text()
        return cls(**json.loads(raw))
46
+
47
+
48
@dataclass
class FCDMDiffAEInferenceConfig:
    """User-tunable inference parameters with sensible defaults.

    PDG (Path-Drop Guidance) sharpens reconstructions by degrading conditioning
    in one pass and amplifying the difference. When enabled, uses 2 NFE per step.
    Recommended: ``pdg=True, pdg_strength=2.0, num_steps=10``.
    """

    num_steps: int = 1  # number of denoising steps (NFE; doubled when pdg=True)
    sampler: str = "ddim"  # "ddim" or "dpmpp_2m"
    schedule: str = "linear"  # "linear" or "cosine"
    pdg: bool = False  # enable PDG for perceptual sharpening
    pdg_strength: float = 2.0  # CFG-like guidance strength, used only when pdg=True
    seed: int | None = None  # RNG seed for initial noise; None = nondeterministic
fcdm_diffae/decoder.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Capacitor decoder: skip-concat topology with FCDM blocks and dual PDG.
2
+
3
+ No outer RMSNorms (use_other_outer_rms_norms=False during training):
4
+ norm_in, latent_norm, and norm_out are all absent.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+ from torch import Tensor, nn
11
+
12
+ from .adaln import AdaLNScaleGateZeroLowRankDelta, AdaLNScaleGateZeroProjector
13
+ from .fcdm_block import FCDMBlock
14
+ from .straight_through_encoder import Patchify
15
+ from .time_embed import SinusoidalTimeEmbeddingMLP
16
+
17
+
18
class Decoder(nn.Module):
    """VP diffusion decoder conditioned on encoder latents and timestep.

    Architecture (skip-concat, 2+4+2 default):
        Patchify x_t -> Fuse with upsampled z
        -> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
        -> Conv1x1 -> PixelShuffle

    PDG at inference:
      - Path drop: replace middle block output with ``path_drop_mask_feature``
        (``drop_middle_blocks=True`` in :meth:`forward`).

    NOTE(review): earlier docs also described a "token mask" PDG variant
    (replacing latent tokens with a ``latent_mask_feature`` before fusion),
    but no such parameter or logic exists in this class — only path drop is
    implemented here. Confirm against the sampler before relying on it.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        start_block_count: int,
        end_block_count: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        adaln_low_rank_rank: int,
    ) -> None:
        super().__init__()
        self.patch_size = int(patch_size)
        self.model_dim = int(model_dim)

        # Input processing (no norm_in; trained with use_other_outer_rms_norms=False)
        self.patchify = Patchify(in_channels, patch_size, model_dim)

        # Latent conditioning path (no latent_norm): 1x1 conv lifts the
        # bottleneck channels to model width, then concat+1x1 fuses with x_t.
        self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
        self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)

        # 2-way AdaLN: shared base projector + per-block low-rank deltas
        self.adaln_base = AdaLNScaleGateZeroProjector(
            d_model=model_dim, d_cond=model_dim
        )
        self.adaln_deltas = nn.ModuleList(
            [
                AdaLNScaleGateZeroLowRankDelta(
                    d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
                )
                for _ in range(depth)
            ]
        )

        # Block layout: start + middle + end; indices into adaln_deltas.
        middle_count = depth - start_block_count - end_block_count
        self._middle_start_idx = start_block_count
        self._end_start_idx = start_block_count + middle_count

        def _make_blocks(count: int) -> nn.ModuleList:
            # Decoder blocks take externally-computed AdaLN modulation.
            return nn.ModuleList(
                [
                    FCDMBlock(
                        model_dim,
                        mlp_ratio,
                        depthwise_kernel_size=depthwise_kernel_size,
                        use_external_adaln=True,
                    )
                    for _ in range(count)
                ]
            )

        self.start_blocks = _make_blocks(start_block_count)
        self.middle_blocks = _make_blocks(middle_count)
        self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
        self.end_blocks = _make_blocks(end_block_count)

        # Learned mask feature for path-drop PDG (broadcast over B, H, W)
        self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))

        # Output head (no norm_out): 1x1 conv to patch pixels, then unshuffle.
        self.out_proj = nn.Conv2d(
            model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
        )
        self.unpatchify = nn.PixelShuffle(patch_size)

    def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
        """Compute packed AdaLN modulation = shared_base + per-layer delta.

        SiLU is applied once and shared by both branches.
        """
        act = self.adaln_base.act(cond)
        base_m = self.adaln_base.forward_activated(act)
        delta_m = self.adaln_deltas[layer_idx](act)
        return base_m + delta_m

    def _run_blocks(
        self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
    ) -> Tensor:
        """Run a group of decoder blocks with per-block AdaLN modulation.

        ``start_index`` offsets into the flat ``adaln_deltas`` list.
        """
        for local_idx, block in enumerate(blocks):
            adaln_m = self._adaln_m_for_layer(cond, layer_idx=start_index + local_idx)
            x = block(x, adaln_m=adaln_m)
        return x

    def forward(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
    ) -> Tensor:
        """Single decoder forward pass.

        Args:
            x_t: Noised image [B, C, H, W].
            t: Timestep [B] in [0, 1].
            latents: Encoder latents [B, bottleneck_dim, h, w].
            drop_middle_blocks: Replace middle block output with mask feature (PDG).

        Returns:
            x0 prediction [B, C, H, W].
        """
        x_feat = self.patchify(x_t)
        z_up = self.latent_up(latents)

        fused = torch.cat([x_feat, z_up], dim=1)
        fused = self.fuse_in(fused)

        # Time embedding computed in float32 on the input's device.
        cond = self.time_embed(t.to(torch.float32).to(device=x_t.device))

        start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)

        if drop_middle_blocks:
            # PDG degraded path: middle output replaced by the learned mask.
            middle_out = self.path_drop_mask_feature.to(
                device=x_t.device, dtype=x_t.dtype
            ).expand_as(start_out)
        else:
            middle_out = self._run_blocks(
                self.middle_blocks,
                start_out,
                cond,
                start_index=self._middle_start_idx,
            )

        # Skip-concat: start-block features bypass the middle stack.
        skip_fused = torch.cat([start_out, middle_out], dim=1)
        skip_fused = self.fuse_skip(skip_fused)

        end_out = self._run_blocks(
            self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
        )

        patches = self.out_proj(end_out)
        return self.unpatchify(patches)
fcdm_diffae/encoder.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Capacitor encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.
2
+
3
+ No input RMSNorm (use_other_outer_rms_norms=False during training).
4
+ Post-bottleneck RMSNorm (affine=False) on the mean branch.
5
+ Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+
12
+ import torch
13
+ from torch import Tensor, nn
14
+
15
+ from .fcdm_block import FCDMBlock
16
+ from .norms import ChannelWiseRMSNorm
17
+ from .straight_through_encoder import Patchify
18
+
19
+
20
@dataclass(frozen=True)
class EncoderPosterior:
    """VP-parameterized diagonal Gaussian posterior.

    Attributes:
        mean: Clean signal branch mu [B, bottleneck_dim, h, w].
        logsnr: Per-element log signal-to-noise ratio, same shape as ``mean``.
    """

    mean: Tensor
    logsnr: Tensor

    @property
    def alpha(self) -> Tensor:
        """VP signal coefficient: sqrt(sigmoid(logsnr))."""
        return torch.sigmoid(self.logsnr).sqrt()

    @property
    def sigma(self) -> Tensor:
        """VP noise coefficient: sqrt(sigmoid(-logsnr))."""
        return torch.sigmoid(-self.logsnr).sqrt()

    def mode(self) -> Tensor:
        """Posterior mode in token space: alpha * mean."""
        return self.alpha.to(dtype=self.mean.dtype) * self.mean

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
        """Sample from the posterior: alpha * mean + sigma * eps.

        Args:
            generator: Optional RNG for reproducible sampling. Its device
                must be compatible with ``mean``'s device.

        Returns:
            A sample with the same shape, dtype, and device as ``mean``.
        """
        # Bug fix: ``torch.randn_like`` does not accept a ``generator``
        # keyword (the previous call raised TypeError at runtime, which the
        # old ``type: ignore[call-overload]`` was masking). Draw the noise
        # with ``torch.randn``, which does support an explicit generator.
        eps = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.mean.device,
            dtype=self.mean.dtype,
        )
        alpha = self.alpha.to(dtype=self.mean.dtype)
        sigma = self.sigma.to(dtype=self.mean.dtype)
        return alpha * self.mean + sigma * eps
51
+
52
+
53
class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With diagonal_gaussian posterior, the to_bottleneck projection outputs
    2 * bottleneck_dim channels, split into mean and logsnr. The default
    encode() returns the posterior mode: alpha * RMSNorm(mean).

    No input RMSNorm (use_other_outer_rms_norms=False during training).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        bottleneck_posterior_kind: str = "diagonal_gaussian",
        bottleneck_norm_mode: str = "disabled",
    ) -> None:
        super().__init__()
        self.bottleneck_dim = int(bottleneck_dim)
        self.bottleneck_posterior_kind = bottleneck_posterior_kind
        self.bottleneck_norm_mode = bottleneck_norm_mode
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                FCDMBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        # Diagonal Gaussian posterior needs both mean and logsnr channels.
        out_dim = (
            2 * bottleneck_dim
            if bottleneck_posterior_kind == "diagonal_gaussian"
            else bottleneck_dim
        )
        self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
        if bottleneck_norm_mode == "channel_wise":
            self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
        else:
            self.norm_out = nn.Identity()

    def _project(self, images: Tensor) -> Tensor:
        """Shared trunk: patchify, run encoder blocks, project to bottleneck.

        Deduplicates the identical pipeline previously repeated in both
        ``encode_posterior`` and ``forward``.
        """
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        return self.to_bottleneck(z)

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Only valid when bottleneck_posterior_kind == "diagonal_gaussian"
        (otherwise the projection has too few channels to split).
        """
        projection = self._project(images)
        mean, logsnr = projection.chunk(2, dim=1)
        mean = self.norm_out(mean)
        return EncoderPosterior(mean=mean, logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic latents otherwise.
        """
        if self.bottleneck_posterior_kind == "diagonal_gaussian":
            # mode() == sigmoid(logsnr).sqrt().to(mean.dtype) * mean —
            # identical to the previous inline computation.
            return self.encode_posterior(images).mode()
        return self.norm_out(self._project(images))
fcdm_diffae/fcdm_block.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor, nn
8
+
9
+ from .norms import ChannelWiseRMSNorm
10
+
11
+
12
class GRN(nn.Module):
    """Global Response Normalization for NCHW tensors.

    Residual form: ``gamma * (x * n) + beta + x`` where ``n`` is each
    channel's spatial L2 norm divided by the mean norm over channels.
    With gamma and beta zero-initialized the module is an identity at init.
    """

    def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps: float = float(eps)
        param_shape = (1, int(channels), 1, 1)
        self.gamma = nn.Parameter(torch.zeros(param_shape, dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros(param_shape, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        # Per-channel spatial L2 norm; ratio computed in float32 for stability.
        spatial_norm = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
        norm32 = spatial_norm.to(dtype=torch.float32)
        mean_norm = norm32.mean(dim=1, keepdim=True)
        n = (norm32 / (mean_norm + self.eps)).to(dtype=x.dtype)
        gamma = self.gamma.to(device=x.device, dtype=x.dtype)
        beta = self.beta.to(device=x.device, dtype=x.dtype)
        return gamma * (x * n) + beta + x
29
+
30
+
31
class FCDMBlock(nn.Module):
    """ConvNeXt-style block with scale+gate AdaLN and GRN.

    Two operating modes:
      - Unconditioned (encoder): a learned layer-scale gates the residual,
        giving a near-identity mapping at init.
      - External AdaLN (decoder): callers pass a packed [B, 2*C] modulation
        holding (scale, gate); the gate is applied raw (no tanh).
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
        layer_scale_init: float = 1e-3,
    ) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self.mlp_ratio: float = float(mlp_ratio)

        # Depthwise spatial mixing with "same" padding (odd kernel assumed).
        self.dwconv = nn.Conv2d(
            channels,
            channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,
            stride=1,
            groups=channels,
            bias=True,
        )
        self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
        hidden = max(int(float(channels) * float(mlp_ratio)), 1)
        self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
        self.grn = GRN(hidden, eps=1e-6)
        self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)

        if use_external_adaln:
            # Residual gate comes from the AdaLN modulation instead.
            self.register_parameter("layer_scale", None)
        else:
            self.layer_scale = nn.Parameter(
                torch.full((channels,), float(layer_scale_init))
            )

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        batch, chans, _, _ = x.shape

        scale: Tensor | None = None
        gate: Tensor | None = None
        if adaln_m is not None:
            packed = adaln_m.to(device=x.device, dtype=x.dtype)
            scale, gate = packed.chunk(2, dim=-1)

        out = self.norm(self.dwconv(x))

        # AdaLN pre-MLP scale (1 + s keeps identity at zero modulation).
        if scale is not None:
            out = out * (1.0 + scale.view(batch, chans, 1, 1))

        out = self.pwconv2(self.grn(F.gelu(self.pwconv1(out))))

        if gate is None:
            gate_view = self.layer_scale.view(1, chans, 1, 1).to(  # type: ignore[union-attr]
                device=out.device, dtype=out.dtype
            )
        else:
            gate_view = gate.view(batch, chans, 1, 1)

        return x + gate_view * out
fcdm_diffae/model.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FCDMDiffAE: standalone HuggingFace-compatible diffusion autoencoder."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+ from .config import FCDMDiffAEConfig, FCDMDiffAEInferenceConfig
11
+ from .decoder import Decoder
12
+ from .encoder import Encoder, EncoderPosterior
13
+ from .samplers import run_ddim, run_dpmpp_2m
14
+ from .vp_diffusion import get_schedule, make_initial_state, sample_noise
15
+
16
+
17
+ def _resolve_model_dir(
18
+ path_or_repo_id: str | Path,
19
+ *,
20
+ revision: str | None,
21
+ cache_dir: str | Path | None,
22
+ ) -> Path:
23
+ """Resolve a local path or HuggingFace Hub repo ID to a local directory."""
24
+ local = Path(path_or_repo_id)
25
+ if local.is_dir():
26
+ return local
27
+ repo_id = str(path_or_repo_id)
28
+ try:
29
+ from huggingface_hub import snapshot_download
30
+ except ImportError:
31
+ raise ImportError(
32
+ f"'{repo_id}' is not an existing local directory. "
33
+ "To download from HuggingFace Hub, install huggingface_hub: "
34
+ "pip install huggingface_hub"
35
+ )
36
+ cache_dir_str = str(cache_dir) if cache_dir is not None else None
37
+ local_dir = snapshot_download(
38
+ repo_id,
39
+ revision=revision,
40
+ cache_dir=cache_dir_str,
41
+ )
42
+ return Path(local_dir)
43
+
44
+
45
class FCDMDiffAE(nn.Module):
    """Standalone FCDM DiffAE model for HuggingFace distribution.

    A diffusion autoencoder built on FCDM (Fully Convolutional Diffusion Model)
    blocks. Encodes images to compact 128-channel spatial latents via a
    VP-parameterized diagonal Gaussian posterior, and decodes them back via
    iterative VP diffusion with a skip-concat decoder.

    Usage::

        model = FCDMDiffAE.from_pretrained("path/to/weights")
        model = model.to("cuda", dtype=torch.bfloat16)

        # Encode (returns posterior mode by default)
        latents = model.encode(images)  # images: [B,3,H,W] in [-1,1]

        # Decode (1 step by default — PSNR-optimal)
        recon = model.decode(latents, height=H, width=W)

        # Reconstruct (encode + 1-step decode)
        recon = model.reconstruct(images)
    """

    # Epsilon added to the running variance before sqrt during whitening.
    _LATENT_NORM_EPS: float = 1e-4

    def __init__(self, config: FCDMDiffAEConfig) -> None:
        super().__init__()
        self.config = config

        # Latent running stats for whitening/dewhitening. Registered as
        # buffers so they ship in state_dict and follow .to() moves without
        # being trained.
        self.register_buffer(
            "latent_norm_running_mean",
            torch.zeros((config.bottleneck_dim,), dtype=torch.float32),
        )
        self.register_buffer(
            "latent_norm_running_var",
            torch.ones((config.bottleneck_dim,), dtype=torch.float32),
        )

        self.encoder = Encoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.encoder_depth,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
            bottleneck_posterior_kind=config.bottleneck_posterior_kind,
            bottleneck_norm_mode=config.bottleneck_norm_mode,
        )

        self.decoder = Decoder(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            model_dim=config.model_dim,
            depth=config.decoder_depth,
            start_block_count=config.decoder_start_blocks,
            end_block_count=config.decoder_end_blocks,
            bottleneck_dim=config.bottleneck_dim,
            mlp_ratio=config.mlp_ratio,
            depthwise_kernel_size=config.depthwise_kernel_size,
            adaln_low_rank_rank=config.adaln_low_rank_rank,
        )

    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device = "cpu",
        revision: str | None = None,
        cache_dir: str | Path | None = None,
    ) -> FCDMDiffAE:
        """Load a pretrained model from a local directory or HuggingFace Hub.

        The directory (or repo) should contain:
          - config.json: Model architecture config.
          - model.safetensors (preferred) or model.pt: Model weights.

        Args:
            path_or_repo_id: Local directory path or HuggingFace Hub repo ID.
            dtype: Load weights in this dtype (float32 or bfloat16).
            device: Target device.
            revision: Git revision for Hub downloads.
            cache_dir: Where to cache Hub downloads.

        Returns:
            Loaded model in eval mode.

        Raises:
            ImportError: If safetensors is required but not installed.
            FileNotFoundError: If neither weights file exists.
        """
        model_dir = _resolve_model_dir(
            path_or_repo_id, revision=revision, cache_dir=cache_dir
        )
        config = FCDMDiffAEConfig.load(model_dir / "config.json")
        model = cls(config)

        safetensors_path = model_dir / "model.safetensors"
        pt_path = model_dir / "model.pt"

        if safetensors_path.exists():
            try:
                from safetensors.torch import load_file

                state_dict = load_file(str(safetensors_path), device=str(device))
            except ImportError:
                raise ImportError(
                    "safetensors package required to load .safetensors files. "
                    "Install with: pip install safetensors"
                )
        elif pt_path.exists():
            # weights_only=True avoids arbitrary-code pickle execution.
            state_dict = torch.load(
                str(pt_path), map_location=device, weights_only=True
            )
        else:
            raise FileNotFoundError(
                f"No model weights found in {model_dir}. "
                "Expected model.safetensors or model.pt."
            )

        model.load_state_dict(state_dict)
        model = model.to(dtype=dtype, device=torch.device(device))
        model.eval()
        return model

    def _latent_norm_stats(self) -> tuple[Tensor, Tensor]:
        """Return (mean, std) tensors for latent whitening, shaped [1,C,1,1]."""
        mean = self.latent_norm_running_mean.view(1, -1, 1, 1)
        var = self.latent_norm_running_var.view(1, -1, 1, 1)
        std = torch.sqrt(var.to(torch.float32) + self._LATENT_NORM_EPS)
        return mean.to(torch.float32), std

    def whiten(self, latents: Tensor) -> Tensor:
        """Whiten encoder latents using per-channel running stats.

        Use this before passing latents to a downstream latent-space
        diffusion model. The whitened latents have approximately zero mean
        and unit variance per channel.

        Args:
            latents: [B, bottleneck_dim, h, w] raw encoder output.

        Returns:
            Whitened latents [B, bottleneck_dim, h, w] in float32.
        """
        z = latents.to(torch.float32)
        mean, std = self._latent_norm_stats()
        return (z - mean.to(device=z.device)) / std.to(device=z.device)

    def dewhiten(self, latents: Tensor) -> Tensor:
        """Undo whitening to recover raw encoder latent scale.

        Use this before passing whitened latents back to ``decode()``.

        Args:
            latents: [B, bottleneck_dim, h, w] whitened latents.

        Returns:
            Dewhitened latents [B, bottleneck_dim, h, w] in float32.
        """
        z = latents.to(torch.float32)
        mean, std = self._latent_norm_stats()
        return z * std.to(device=z.device) + mean.to(device=z.device)

    def encode(self, images: Tensor) -> Tensor:
        """Encode images to whitened latents (posterior mode).

        Returns latents whitened using per-channel running stats, ready for
        use by downstream latent-space diffusion models.

        Args:
            images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.

        Returns:
            Whitened latents [B, bottleneck_dim, H/patch, W/patch].
        """
        # A parameterless module has no dtype to probe; fall back to float32.
        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32
        z = self.encoder(images.to(dtype=model_dtype))
        return self.whiten(z).to(dtype=model_dtype)

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Args:
            images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.

        Returns:
            EncoderPosterior with mean and logsnr tensors.
        """
        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32
        return self.encoder.encode_posterior(images.to(dtype=model_dtype))

    @torch.no_grad()
    def decode(
        self,
        latents: Tensor,
        height: int,
        width: int,
        *,
        inference_config: FCDMDiffAEInferenceConfig | None = None,
    ) -> Tensor:
        """Decode whitened latents to images via VP diffusion.

        Latents are dewhitened internally before being passed to the decoder.

        Args:
            latents: [B, bottleneck_dim, h, w] whitened encoder latents.
            height: Output image height (divisible by patch_size).
            width: Output image width (divisible by patch_size).
            inference_config: Optional inference parameters.

        Returns:
            Reconstructed images [B, 3, H, W] in float32.

        Raises:
            ValueError: If height/width are not divisible by patch_size, or
                the sampler name is unknown.
        """
        cfg = inference_config or FCDMDiffAEInferenceConfig()
        config = self.config
        batch = int(latents.shape[0])
        device = latents.device

        try:
            model_dtype = next(self.parameters()).dtype
        except StopIteration:
            model_dtype = torch.float32

        # Dewhiten back to raw encoder scale for the decoder
        latents = self.dewhiten(latents).to(dtype=model_dtype)

        if height % config.patch_size != 0 or width % config.patch_size != 0:
            raise ValueError(
                f"height={height} and width={width} must be divisible by "
                f"patch_size={config.patch_size}"
            )

        shape = (batch, config.in_channels, height, width)
        # Noise is drawn on CPU in float32 — presumably so a fixed seed
        # reproduces across devices; confirm against sample_noise's contract.
        noise = sample_noise(
            shape,
            noise_std=config.pixel_noise_std,
            seed=cfg.seed,
            device=torch.device("cpu"),
            dtype=torch.float32,
        )

        schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)
        initial_state = make_initial_state(
            noise=noise.to(device=device),
            t_start=schedule[0:1],
            logsnr_min=config.logsnr_min,
            logsnr_max=config.logsnr_max,
        )

        # Sampling runs with autocast disabled; dtype is handled explicitly.
        device_type = "cuda" if device.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            latents_in = latents.to(device=device)

            def _forward_fn(
                x_t: Tensor,
                t: Tensor,
                latents: Tensor,
                *,
                drop_middle_blocks: bool = False,
                mask_latent_tokens: bool = False,
            ) -> Tensor:
                # NOTE(review): ``mask_latent_tokens`` is accepted to match
                # the sampler's calling convention but is NOT forwarded to
                # the decoder — token-mask PDG is silently a no-op here.
                return self.decoder(
                    x_t.to(dtype=model_dtype),
                    t,
                    latents.to(dtype=model_dtype),
                    drop_middle_blocks=drop_middle_blocks,
                )

            pdg_mode = "path_drop" if cfg.pdg else "disabled"

            if cfg.sampler == "ddim":
                sampler_fn = run_ddim
            elif cfg.sampler == "dpmpp_2m":
                sampler_fn = run_dpmpp_2m
            else:
                raise ValueError(
                    f"Unsupported sampler: {cfg.sampler!r}. Use 'ddim' or 'dpmpp_2m'."
                )

            result = sampler_fn(
                forward_fn=_forward_fn,
                initial_state=initial_state,
                schedule=schedule,
                latents=latents_in,
                logsnr_min=config.logsnr_min,
                logsnr_max=config.logsnr_max,
                pdg_mode=pdg_mode,
                pdg_strength=cfg.pdg_strength,
                device=device,
            )

        return result

    @torch.no_grad()
    def reconstruct(
        self,
        images: Tensor,
        *,
        inference_config: FCDMDiffAEInferenceConfig | None = None,
    ) -> Tensor:
        """Encode then decode. Convenience wrapper.

        Args:
            images: [B, 3, H, W] in [-1, 1].
            inference_config: Optional inference parameters.

        Returns:
            Reconstructed images [B, 3, H, W] in float32.
        """
        latents = self.encode(images)
        _, _, h, w = images.shape
        return self.decode(
            latents, height=h, width=w, inference_config=inference_config
        )
fcdm_diffae/norms.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Channel-wise RMSNorm for NCHW tensors."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+
9
class ChannelWiseRMSNorm(nn.Module):
    """Channel-wise RMSNorm with float32 reduction for numerical stability.

    Normalizes across the channel dimension (dim=1) independently at each
    remaining position. Supports optional per-channel affine weight and bias.

    Args:
        channels: Number of channels C (size of dim=1).
        eps: Stability epsilon added to the mean-square before rsqrt.
        affine: If True, learn per-channel weight (init 1) and bias (init 0).
    """

    def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self._eps: float = float(eps)
        if affine:
            self.weight = nn.Parameter(torch.ones(self.channels))
            self.bias = nn.Parameter(torch.zeros(self.channels))
        else:
            # Register None so the attribute always exists and state_dict
            # keys stay consistent across configurations.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: Tensor) -> Tensor:
        # Tensors without a channel dimension pass through unchanged.
        if x.dim() < 2:
            return x
        # Mean-square over channels accumulated in float32 for stability.
        ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
        inv_rms = torch.rsqrt(ms + self._eps)
        y = x * inv_rms
        # Broadcast shape [1, C, 1, ..., 1] computed once, up front: the bias
        # branch previously reused a `shape` bound inside the weight branch,
        # which was a latent NameError whenever bias is set but weight is None.
        shape = (1, -1) + (1,) * (x.dim() - 2)
        if self.weight is not None:
            y = y * self.weight.view(shape).to(dtype=y.dtype)
        if self.bias is not None:
            y = y + self.bias.view(shape).to(dtype=y.dtype)
        return y.to(dtype=x.dtype)
fcdm_diffae/samplers.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DDIM and DPM++2M samplers for VP diffusion with dual PDG support."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Protocol
6
+
7
+ import torch
8
+ from torch import Tensor
9
+
10
+ from .vp_diffusion import (
11
+ alpha_sigma_from_logsnr,
12
+ broadcast_time_like,
13
+ shifted_cosine_interpolated_logsnr_from_t,
14
+ )
15
+
16
+
17
class DecoderForwardFn(Protocol):
    """Callable that predicts x0 from (x_t, t, latents) with dual PDG flags.

    Structural interface for the decoder forward pass used by the samplers.
    ``drop_middle_blocks`` and ``mask_latent_tokens`` are keyword-only flags
    used to obtain the "unconditional" branch for the two PDG guidance modes
    ("path_drop" and "token_mask" respectively). Implementations return the
    predicted clean sample x0_hat with the same shape as ``x_t``.
    """

    def __call__(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        drop_middle_blocks: bool = False,
        mask_latent_tokens: bool = False,
    ) -> Tensor: ...
29
+
30
+
31
def _reconstruct_eps_from_x0(
    *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
) -> Tensor:
    """Invert the VP forward relation x_t = alpha*x0 + sigma*eps for eps.

    Returns eps_hat = (x_t - alpha * x0_hat) / sigma, computed in float32.
    """
    alpha_b = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
    sigma_b = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
    return (x_t.to(torch.float32) - alpha_b * x0_hat.to(torch.float32)) / sigma_b
43
+
44
+
45
def _ddim_step(
    *,
    x0_hat: Tensor,
    eps_hat: Tensor,
    alpha_next: Tensor,
    sigma_next: Tensor,
    ref: Tensor,
) -> Tensor:
    """Deterministic DDIM update: x_next = alpha_next * x0_hat + sigma_next * eps_hat.

    ``ref`` supplies the shape used to broadcast the per-sample coefficients.
    """
    alpha_b = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
    sigma_b = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
    return alpha_b * x0_hat + sigma_b * eps_hat
57
+
58
+
59
+ def _predict_with_pdg(
60
+ forward_fn: DecoderForwardFn,
61
+ state: Tensor,
62
+ t_vec: Tensor,
63
+ latents: Tensor,
64
+ *,
65
+ pdg_mode: str,
66
+ pdg_strength: float,
67
+ ) -> Tensor:
68
+ """Run decoder forward with optional PDG guidance.
69
+
70
+ Args:
71
+ forward_fn: Decoder forward function.
72
+ state: Current noised state [B, C, H, W].
73
+ t_vec: Timestep vector [B].
74
+ latents: Encoder latents.
75
+ pdg_mode: "disabled", "path_drop", or "token_mask".
76
+ pdg_strength: CFG-like strength for PDG.
77
+
78
+ Returns:
79
+ x0_hat prediction in float32.
80
+ """
81
+ if pdg_mode == "path_drop":
82
+ x0_uncond = forward_fn(state, t_vec, latents, drop_middle_blocks=True).to(
83
+ torch.float32
84
+ )
85
+ x0_cond = forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
86
+ torch.float32
87
+ )
88
+ return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
89
+ elif pdg_mode == "token_mask":
90
+ x0_uncond = forward_fn(state, t_vec, latents, mask_latent_tokens=True).to(
91
+ torch.float32
92
+ )
93
+ x0_cond = forward_fn(state, t_vec, latents, mask_latent_tokens=False).to(
94
+ torch.float32
95
+ )
96
+ return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
97
+ else:
98
+ return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
99
+ torch.float32
100
+ )
101
+
102
+
103
def run_ddim(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run the DDIM sampling loop with dual PDG support.

    Args:
        forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
        initial_state: Starting noised state [B, C, H, W] in float32.
        schedule: Descending t-schedule [num_steps] in [0, 1].
        latents: Encoder latents [B, bottleneck_dim, h, w].
        logsnr_min, logsnr_max: VP schedule endpoints.
        log_change_high, log_change_low: Shifted-cosine schedule parameters.
        pdg_mode: "disabled", "path_drop", or "token_mask".
        pdg_strength: CFG-like strength for PDG.
        device: Target device (defaults to initial_state's device).

    Returns:
        Denoised samples [B, C, H, W] in float32.
    """
    run_device = device or initial_state.device
    batch = int(initial_state.shape[0])
    x = initial_state.to(device=run_device, dtype=torch.float32)

    # logSNR and the derived (alpha, sigma) for every schedule point, up front.
    logsnr = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alphas, sigmas = alpha_sigma_from_logsnr(logsnr)

    num_points = int(schedule.numel())
    for step in range(num_points - 1):
        # Predict x0 at the current timestep (with optional PDG guidance).
        t_now = schedule[step].expand(batch).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            x,
            t_now,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        # Recover eps from x0 at the current level, then jump to the next one.
        eps_hat = _reconstruct_eps_from_x0(
            x_t=x,
            x0_hat=x0_hat,
            alpha=alphas[step].expand(batch),
            sigma=sigmas[step].expand(batch),
        )
        x = _ddim_step(
            x0_hat=x0_hat,
            eps_hat=eps_hat,
            alpha_next=alphas[step + 1].expand(batch),
            sigma_next=sigmas[step + 1].expand(batch),
            ref=x,
        )

    return x
177
+
178
+
179
def run_dpmpp_2m(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_mode: str = "disabled",
    pdg_strength: float = 1.5,
    device: torch.device | None = None,
) -> Tensor:
    """Run DPM++2M sampling loop with dual PDG support.

    Multi-step solver using exponential integrator formulation in half-lambda space.

    Data-prediction form: with h = lam_next - lam_t (half-logSNR increment) and
    phi_1 = expm1(-h), each update is
        x_next = (sigma_next / sigma_t) * x_t - alpha_next * phi_1 * D
    where D is x0_hat for the first-order step and the second-order step adds
    the multistep correction D = x0_hat + (1 / (2 * r0)) * (x0_hat - x0_prev).

    Args/returns match :func:`run_ddim`. `schedule` must be descending in t.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # Precompute logSNR, alpha, sigma, half-lambda for all schedule points
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
    half_lambda = 0.5 * lmb.to(torch.float32)

    # Previous step's x0 prediction; None until one model call has been made.
    x0_prev: Tensor | None = None

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        # Model prediction with optional PDG
        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        x0_hat = _predict_with_pdg(
            forward_fn,
            state,
            t_vec,
            latents,
            pdg_mode=pdg_mode,
            pdg_strength=pdg_strength,
        )

        # Step size in half-logSNR space and the integrator factor expm1(-h).
        lam_t = half_lambda[i].expand(batch_size)
        lam_next = half_lambda[i + 1].expand(batch_size)
        h = (lam_next - lam_t).to(torch.float32)
        phi_1 = torch.expm1(-h)

        sigma_ratio = (s_next / s_t).to(torch.float32)

        if i == 0 or x0_prev is None:
            # First-order step
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - broadcast_time_like(a_next, state).to(torch.float32)
                * broadcast_time_like(phi_1, state).to(torch.float32)
                * x0_hat
            )
        else:
            # Second-order step
            # r0 = h_0 / h is the previous-to-current step-size ratio;
            # d1_0 / 2 is the multistep correction term (1/(2*r0))*(x0 - x0_prev).
            lam_prev = half_lambda[i - 1].expand(batch_size)
            h_0 = (lam_t - lam_prev).to(torch.float32)
            r0 = h_0 / h
            d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
            common = broadcast_time_like(a_next, state).to(
                torch.float32
            ) * broadcast_time_like(phi_1, state).to(torch.float32)
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - common * x0_hat
                - 0.5 * common * d1_0
            )

        x0_prev = x0_hat

    return state
fcdm_diffae/straight_through_encoder.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PixelUnshuffle-based patchifier (no residual conv path)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from torch import Tensor, nn
6
+
7
+
8
class Patchify(nn.Module):
    """PixelUnshuffle(patch) -> Conv2d 1x1 projection.

    Converts [B, C, H, W] images into [B, out_channels, H/patch, W/patch]
    features. H and W must both be divisible by ``patch``.
    """

    def __init__(self, in_channels: int, patch: int, out_channels: int) -> None:
        super().__init__()
        self.patch = int(patch)
        self.unshuffle = nn.PixelUnshuffle(self.patch)
        # PixelUnshuffle multiplies the channel count by patch**2.
        self.proj = nn.Conv2d(
            in_channels * self.patch * self.patch,
            out_channels,
            kernel_size=1,
            bias=True,
        )

    def forward(self, x: Tensor) -> Tensor:
        height, width = x.shape[2], x.shape[3]
        if height % self.patch or width % self.patch:
            raise ValueError(
                f"Input H={x.shape[2]} and W={x.shape[3]} must be divisible by patch={self.patch}"
            )
        return self.proj(self.unshuffle(x))
fcdm_diffae/time_embed.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sinusoidal timestep embedding with MLP projection."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+
10
+
11
+ def _log_spaced_frequencies(
12
+ half: int, max_period: float, *, device: torch.device | None = None
13
+ ) -> Tensor:
14
+ """Log-spaced frequencies for sinusoidal embedding."""
15
+ return torch.exp(
16
+ -math.log(max_period)
17
+ * torch.arange(half, device=device, dtype=torch.float32)
18
+ / max(float(half - 1), 1.0)
19
+ )
20
+
21
+
22
def sinusoidal_time_embedding(
    t: Tensor,
    dim: int,
    *,
    max_period: float = 10000.0,
    scale: float | None = None,
    freqs: Tensor | None = None,
) -> Tensor:
    """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32.

    Output is [sin(t*f), cos(t*f)] concatenated over dim//2 frequencies, so
    the result has 2*(dim//2) features. If ``freqs`` is supplied it overrides
    the default log-spaced frequencies.
    """
    times = t.to(torch.float32)
    if scale is not None:
        times = times * float(scale)
    if freqs is None:
        freqs = _log_spaced_frequencies(dim // 2, max_period, device=times.device)
    else:
        freqs = freqs.to(device=times.device, dtype=torch.float32)
    phase = times[:, None] * freqs[None, :]
    return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
41
+
42
+
43
class SinusoidalTimeEmbeddingMLP(nn.Module):
    """Sinusoidal time embedding followed by Linear -> SiLU -> Linear.

    The raw sinusoidal features (``freq_dim`` wide) are projected to a hidden
    width of ``round(dim * hidden_mult)`` and then to ``dim``.
    """

    def __init__(
        self,
        dim: int,
        *,
        freq_dim: int = 256,
        hidden_mult: float = 1.0,
        time_scale: float = 1000.0,
        max_period: float = 10000.0,
    ) -> None:
        super().__init__()
        self.dim = int(dim)
        self.freq_dim = int(freq_dim)
        self.time_scale = float(time_scale)
        self.max_period = float(max_period)
        hidden_dim = max(int(round(self.dim * float(hidden_mult))), 1)

        # Frequencies are deterministic, but persisting them keeps
        # checkpoints self-contained.
        self.register_buffer(
            "freqs",
            _log_spaced_frequencies(self.freq_dim // 2, self.max_period),
            persistent=True,
        )

        self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
        self.act = nn.SiLU()
        self.proj_out = nn.Linear(hidden_dim, self.dim)

    def forward(self, t: Tensor) -> Tensor:
        base: Tensor = sinusoidal_time_embedding(
            t.to(torch.float32),
            self.freq_dim,
            max_period=self.max_period,
            scale=self.time_scale,
            freqs=self.freqs,  # type: ignore[arg-type]
        )
        # Follow the layer weights' dtypes so the module works under
        # mixed-precision parameter casting.
        hidden = self.act(self.proj_in(base.to(self.proj_in.weight.dtype)))
        if hidden.dtype != self.proj_out.weight.dtype:
            hidden = hidden.to(self.proj_out.weight.dtype)
        return self.proj_out(hidden)
fcdm_diffae/vp_diffusion.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+
7
+ import torch
8
+ from torch import Tensor
9
+
10
+
11
def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
    """Compute (alpha, sigma) from logSNR in float32.

    VP constraint alpha^2 + sigma^2 = 1 holds because
    sigmoid(x) + sigmoid(-x) = 1.
    """
    logsnr = lmb.to(dtype=torch.float32)
    return torch.sigmoid(logsnr).sqrt(), torch.sigmoid(-logsnr).sqrt()
20
+
21
+
22
def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
    """Broadcast a [B] coefficient to [B, 1, ..., 1] matching x's rank."""
    trailing_ones = (1,) * (x.dim() - 1)
    return coeff.view((int(x.shape[0]),) + trailing_ones)
26
+
27
+
28
+ def _cosine_interpolated_params(
29
+ logsnr_min: float, logsnr_max: float
30
+ ) -> tuple[float, float]:
31
+ """Compute (a, b) for cosine-interpolated logSNR schedule.
32
+
33
+ logsnr(t) = -2 * log(tan(a*t + b))
34
+ logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
35
+ """
36
+ b = math.atan(math.exp(-0.5 * logsnr_max))
37
+ a = math.atan(math.exp(-0.5 * logsnr_min)) - b
38
+ return a, b
39
+
40
+
41
def cosine_interpolated_logsnr_from_t(
    t: Tensor, *, logsnr_min: float, logsnr_max: float
) -> Tensor:
    """Map t in [0,1] to logSNR via the cosine-interpolated schedule (float32)."""
    slope, intercept = _cosine_interpolated_params(logsnr_min, logsnr_max)
    t32 = t.to(dtype=torch.float32)
    slope_t = torch.tensor(slope, device=t32.device, dtype=torch.float32)
    intercept_t = torch.tensor(intercept, device=t32.device, dtype=torch.float32)
    angle = slope_t * t32 + intercept_t
    return -2.0 * torch.log(torch.tan(angle))
51
+
52
+
53
def shifted_cosine_interpolated_logsnr_from_t(
    t: Tensor,
    *,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.

    lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
    which linearly blends the two shifted copies of the base schedule.
    """
    base = cosine_interpolated_logsnr_from_t(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    blend = t.to(dtype=torch.float32)
    shifted_high = base + float(log_change_high)
    shifted_low = base + float(log_change_low)
    return (1.0 - blend) * shifted_high + blend * shifted_low
72
+
73
+
74
def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
    """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.

    ``num_steps`` counts function evaluations (NFE = decoder forward passes),
    so the returned schedule contains ``num_steps + 1`` time points including
    both endpoints.

    Args:
        schedule_type: "linear" or "cosine".
        num_steps: Number of decoder forward passes (NFE), >= 1.

    Returns:
        Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.

    Raises:
        ValueError: If ``schedule_type`` is not "linear" or "cosine".
    """
    # NOTE: the upstream training code (src/ode/time_schedules.py) uses a
    # different convention where num_steps counts schedule *points* (so NFE =
    # num_steps - 1). This export package corrects the off-by-one so that
    # num_steps means NFE directly. TODO: align the upstream convention.
    num_points = max(int(num_steps) + 1, 2)
    if schedule_type == "linear":
        ascending = torch.linspace(0.0, 1.0, num_points)
    elif schedule_type == "cosine":
        idx = torch.arange(num_points, dtype=torch.float32)
        ascending = 0.5 * (1.0 - torch.cos(math.pi * (idx / (num_points - 1))))
    else:
        raise ValueError(
            f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
        )
    # Flip so sampling runs from high t (noisy) down to low t (clean).
    return torch.flip(ascending, dims=[0])
104
+
105
+
106
def make_initial_state(
    *,
    noise: Tensor,
    t_start: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """Construct the VP initial state x_t0 = sigma(t_start) * noise.

    The clean component vanishes because x0 = 0 at the start of sampling.
    All math is done in float32.
    """
    batch = int(noise.shape[0])
    logsnr_start = shifted_cosine_interpolated_logsnr_from_t(
        t_start.expand(batch).to(dtype=torch.float32),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    _, sigma_start = alpha_sigma_from_logsnr(logsnr_start)
    return broadcast_time_like(sigma_start, noise) * noise.to(dtype=torch.float32)
130
+
131
+
132
def sample_noise(
    shape: tuple[int, ...],
    *,
    noise_std: float = 1.0,
    seed: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """Sample Gaussian noise with optional seeding.

    Seeded draws always happen on the CPU so the result is reproducible
    regardless of the target device; the tensor is moved afterwards.
    """
    if seed is not None:
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed))
        noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
    else:
        noise = torch.randn(
            shape, device=device or torch.device("cpu"), dtype=torch.float32
        )
    scaled = noise.mul(float(noise_std))
    target_device = device if device is not None else torch.device("cpu")
    return scaled.to(device=target_device, dtype=dtype)