data-archetype committed on
Commit
8f5592f
·
verified ·
1 Parent(s): 6199cd7

Remove old capacitor_diffae package (renamed to fcdm_diffae)

Browse files
capacitor_diffae/__init__.py DELETED
@@ -1,33 +0,0 @@
1
- """CapacitorDiffAE: Standalone diffusion autoencoder with FCDM blocks.
2
-
3
- Capacitor DiffAE — a fast diffusion autoencoder with a 128-channel spatial
4
- bottleneck and a VP-parameterized diagonal Gaussian posterior. Built on FCDM
5
- (Fully Convolutional Diffusion Model) blocks with GRN and scale+gate AdaLN.
6
-
7
- Usage::
8
-
9
- from capacitor_diffae import CapacitorDiffAE, CapacitorDiffAEInferenceConfig
10
-
11
- model = CapacitorDiffAE.from_pretrained("path/to/weights", device="cuda")
12
-
13
- # Encode (returns posterior mode by default)
14
- latents = model.encode(images) # images: [B,3,H,W] in [-1,1]
15
-
16
- # Decode — PSNR-optimal (1 step, default)
17
- recon = model.decode(latents, height=H, width=W)
18
-
19
- # Decode — perceptual sharpness (10 steps + path-drop PDG)
20
- cfg = CapacitorDiffAEInferenceConfig(num_steps=10, pdg=True, pdg_strength=2.0)
21
- recon = model.decode(latents, height=H, width=W, inference_config=cfg)
22
- """
23
-
24
- from .config import CapacitorDiffAEConfig, CapacitorDiffAEInferenceConfig
25
- from .encoder import EncoderPosterior
26
- from .model import CapacitorDiffAE
27
-
28
- __all__ = [
29
- "CapacitorDiffAE",
30
- "CapacitorDiffAEConfig",
31
- "CapacitorDiffAEInferenceConfig",
32
- "EncoderPosterior",
33
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/adaln.py DELETED
@@ -1,50 +0,0 @@
1
- """Scale+Gate AdaLN (2-way) for FCDM decoder blocks."""
2
-
3
- from __future__ import annotations
4
-
5
- from torch import Tensor, nn
6
-
7
-
8
- class AdaLNScaleGateZeroProjector(nn.Module):
9
- """Packed 2-way AdaLN projection (SiLU -> Linear), zero-initialized.
10
-
11
- Outputs [B, 2*d_model] packed as (scale, gate).
12
- """
13
-
14
- def __init__(self, d_model: int, d_cond: int) -> None:
15
- super().__init__()
16
- self.d_model: int = int(d_model)
17
- self.d_cond: int = int(d_cond)
18
- self.act: nn.SiLU = nn.SiLU()
19
- self.proj: nn.Linear = nn.Linear(self.d_cond, 2 * self.d_model)
20
- nn.init.zeros_(self.proj.weight)
21
- nn.init.zeros_(self.proj.bias)
22
-
23
- def forward_activated(self, act_cond: Tensor) -> Tensor:
24
- """Return packed modulation for a pre-activated conditioning vector."""
25
- return self.proj(act_cond)
26
-
27
- def forward(self, cond: Tensor) -> Tensor:
28
- """Return packed modulation [B, 2*d_model]."""
29
- return self.forward_activated(self.act(cond))
30
-
31
-
32
- class AdaLNScaleGateZeroLowRankDelta(nn.Module):
33
- """Low-rank delta for 2-way AdaLN: down(d_cond -> rank) -> up(rank -> 2*d_model).
34
-
35
- Zero-initialized up projection preserves zero-output semantics at init.
36
- """
37
-
38
- def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
39
- super().__init__()
40
- self.d_model: int = int(d_model)
41
- self.d_cond: int = int(d_cond)
42
- self.rank: int = int(rank)
43
- self.down: nn.Linear = nn.Linear(self.d_cond, self.rank, bias=False)
44
- self.up: nn.Linear = nn.Linear(self.rank, 2 * self.d_model, bias=False)
45
- nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
46
- nn.init.zeros_(self.up.weight)
47
-
48
- def forward(self, act_cond: Tensor) -> Tensor:
49
- """Return packed delta modulation [B, 2*d_model]."""
50
- return self.up(self.down(act_cond))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/config.py DELETED
@@ -1,62 +0,0 @@
1
- """Frozen model architecture and user-tunable inference configuration."""
2
-
3
- from __future__ import annotations
4
-
5
- import json
6
- from dataclasses import asdict, dataclass
7
- from pathlib import Path
8
-
9
-
10
- @dataclass(frozen=True)
11
- class CapacitorDiffAEConfig:
12
- """Frozen model architecture config. Stored alongside weights as config.json."""
13
-
14
- in_channels: int = 3
15
- patch_size: int = 16
16
- model_dim: int = 896
17
- encoder_depth: int = 4
18
- decoder_depth: int = 8
19
- decoder_start_blocks: int = 2
20
- decoder_end_blocks: int = 2
21
- bottleneck_dim: int = 128
22
- mlp_ratio: float = 4.0
23
- depthwise_kernel_size: int = 7
24
- adaln_low_rank_rank: int = 128
25
- # Encoder posterior kind: "diagonal_gaussian" or "deterministic"
26
- bottleneck_posterior_kind: str = "diagonal_gaussian"
27
- # Post-bottleneck normalization: "channel_wise" or "disabled"
28
- bottleneck_norm_mode: str = "disabled"
29
- # VP diffusion schedule endpoints
30
- logsnr_min: float = -10.0
31
- logsnr_max: float = 10.0
32
- # Pixel-space noise std for VP diffusion initialization
33
- pixel_noise_std: float = 0.558
34
-
35
- def save(self, path: str | Path) -> None:
36
- """Save config as JSON."""
37
- p = Path(path)
38
- p.parent.mkdir(parents=True, exist_ok=True)
39
- p.write_text(json.dumps(asdict(self), indent=2) + "\n")
40
-
41
- @classmethod
42
- def load(cls, path: str | Path) -> CapacitorDiffAEConfig:
43
- """Load config from JSON."""
44
- data = json.loads(Path(path).read_text())
45
- return cls(**data)
46
-
47
-
48
- @dataclass
49
- class CapacitorDiffAEInferenceConfig:
50
- """User-tunable inference parameters with sensible defaults.
51
-
52
- PDG (Path-Drop Guidance) sharpens reconstructions by degrading conditioning
53
- in one pass and amplifying the difference. When enabled, uses 2 NFE per step.
54
- Recommended: ``pdg=True, pdg_strength=2.0, num_steps=10``.
55
- """
56
-
57
- num_steps: int = 1 # number of denoising steps (NFE)
58
- sampler: str = "ddim" # "ddim" or "dpmpp_2m"
59
- schedule: str = "linear" # "linear" or "cosine"
60
- pdg: bool = False # enable PDG for perceptual sharpening
61
- pdg_strength: float = 2.0 # CFG-like strength when pdg=True
62
- seed: int | None = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/decoder.py DELETED
@@ -1,169 +0,0 @@
1
- """Capacitor decoder: skip-concat topology with FCDM blocks and dual PDG.
2
-
3
- No outer RMSNorms (use_other_outer_rms_norms=False during training):
4
- norm_in, latent_norm, and norm_out are all absent.
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import torch
10
- from torch import Tensor, nn
11
-
12
- from .adaln import AdaLNScaleGateZeroLowRankDelta, AdaLNScaleGateZeroProjector
13
- from .fcdm_block import FCDMBlock
14
- from .straight_through_encoder import Patchify
15
- from .time_embed import SinusoidalTimeEmbeddingMLP
16
-
17
-
18
- class Decoder(nn.Module):
19
- """VP diffusion decoder conditioned on encoder latents and timestep.
20
-
21
- Architecture (skip-concat, 2+4+2 default):
22
- Patchify x_t -> Fuse with upsampled z
23
- -> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
24
- -> Conv1x1 -> PixelShuffle
25
-
26
- Dual PDG at inference:
27
- - Path drop: replace middle block output with ``path_drop_mask_feature``.
28
- - Token mask: replace a fraction of upsampled latent tokens with
29
- ``latent_mask_feature`` before fusion.
30
- """
31
-
32
- def __init__(
33
- self,
34
- in_channels: int,
35
- patch_size: int,
36
- model_dim: int,
37
- depth: int,
38
- start_block_count: int,
39
- end_block_count: int,
40
- bottleneck_dim: int,
41
- mlp_ratio: float,
42
- depthwise_kernel_size: int,
43
- adaln_low_rank_rank: int,
44
- ) -> None:
45
- super().__init__()
46
- self.patch_size = int(patch_size)
47
- self.model_dim = int(model_dim)
48
-
49
- # Input processing (no norm_in)
50
- self.patchify = Patchify(in_channels, patch_size, model_dim)
51
-
52
- # Latent conditioning path (no latent_norm)
53
- self.latent_up = nn.Conv2d(bottleneck_dim, model_dim, kernel_size=1, bias=True)
54
- self.fuse_in = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
55
-
56
- # Time embedding
57
- self.time_embed = SinusoidalTimeEmbeddingMLP(model_dim)
58
-
59
- # 2-way AdaLN: shared base projector + per-block low-rank deltas
60
- self.adaln_base = AdaLNScaleGateZeroProjector(
61
- d_model=model_dim, d_cond=model_dim
62
- )
63
- self.adaln_deltas = nn.ModuleList(
64
- [
65
- AdaLNScaleGateZeroLowRankDelta(
66
- d_model=model_dim, d_cond=model_dim, rank=adaln_low_rank_rank
67
- )
68
- for _ in range(depth)
69
- ]
70
- )
71
-
72
- # Block layout: start + middle + end
73
- middle_count = depth - start_block_count - end_block_count
74
- self._middle_start_idx = start_block_count
75
- self._end_start_idx = start_block_count + middle_count
76
-
77
- def _make_blocks(count: int) -> nn.ModuleList:
78
- return nn.ModuleList(
79
- [
80
- FCDMBlock(
81
- model_dim,
82
- mlp_ratio,
83
- depthwise_kernel_size=depthwise_kernel_size,
84
- use_external_adaln=True,
85
- )
86
- for _ in range(count)
87
- ]
88
- )
89
-
90
- self.start_blocks = _make_blocks(start_block_count)
91
- self.middle_blocks = _make_blocks(middle_count)
92
- self.fuse_skip = nn.Conv2d(2 * model_dim, model_dim, kernel_size=1, bias=True)
93
- self.end_blocks = _make_blocks(end_block_count)
94
-
95
- # Learned mask feature for path-drop PDG
96
- self.path_drop_mask_feature = nn.Parameter(torch.zeros((1, model_dim, 1, 1)))
97
-
98
- # Output head (no norm_out)
99
- self.out_proj = nn.Conv2d(
100
- model_dim, in_channels * (patch_size**2), kernel_size=1, bias=True
101
- )
102
- self.unpatchify = nn.PixelShuffle(patch_size)
103
-
104
- def _adaln_m_for_layer(self, cond: Tensor, layer_idx: int) -> Tensor:
105
- """Compute packed AdaLN modulation = shared_base + per-layer delta."""
106
- act = self.adaln_base.act(cond)
107
- base_m = self.adaln_base.forward_activated(act)
108
- delta_m = self.adaln_deltas[layer_idx](act)
109
- return base_m + delta_m
110
-
111
- def _run_blocks(
112
- self, blocks: nn.ModuleList, x: Tensor, cond: Tensor, start_index: int
113
- ) -> Tensor:
114
- """Run a group of decoder blocks with per-block AdaLN modulation."""
115
- for local_idx, block in enumerate(blocks):
116
- adaln_m = self._adaln_m_for_layer(cond, layer_idx=start_index + local_idx)
117
- x = block(x, adaln_m=adaln_m)
118
- return x
119
-
120
- def forward(
121
- self,
122
- x_t: Tensor,
123
- t: Tensor,
124
- latents: Tensor,
125
- *,
126
- drop_middle_blocks: bool = False,
127
- ) -> Tensor:
128
- """Single decoder forward pass.
129
-
130
- Args:
131
- x_t: Noised image [B, C, H, W].
132
- t: Timestep [B] in [0, 1].
133
- latents: Encoder latents [B, bottleneck_dim, h, w].
134
- drop_middle_blocks: Replace middle block output with mask feature (PDG).
135
-
136
- Returns:
137
- x0 prediction [B, C, H, W].
138
- """
139
- x_feat = self.patchify(x_t)
140
- z_up = self.latent_up(latents)
141
-
142
- fused = torch.cat([x_feat, z_up], dim=1)
143
- fused = self.fuse_in(fused)
144
-
145
- cond = self.time_embed(t.to(torch.float32).to(device=x_t.device))
146
-
147
- start_out = self._run_blocks(self.start_blocks, fused, cond, start_index=0)
148
-
149
- if drop_middle_blocks:
150
- middle_out = self.path_drop_mask_feature.to(
151
- device=x_t.device, dtype=x_t.dtype
152
- ).expand_as(start_out)
153
- else:
154
- middle_out = self._run_blocks(
155
- self.middle_blocks,
156
- start_out,
157
- cond,
158
- start_index=self._middle_start_idx,
159
- )
160
-
161
- skip_fused = torch.cat([start_out, middle_out], dim=1)
162
- skip_fused = self.fuse_skip(skip_fused)
163
-
164
- end_out = self._run_blocks(
165
- self.end_blocks, skip_fused, cond, start_index=self._end_start_idx
166
- )
167
-
168
- patches = self.out_proj(end_out)
169
- return self.unpatchify(patches)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/encoder.py DELETED
@@ -1,129 +0,0 @@
1
- """Capacitor encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.
2
-
3
- No input RMSNorm (use_other_outer_rms_norms=False during training).
4
- Post-bottleneck RMSNorm (affine=False) on the mean branch.
5
- Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- from dataclasses import dataclass
11
-
12
- import torch
13
- from torch import Tensor, nn
14
-
15
- from .fcdm_block import FCDMBlock
16
- from .norms import ChannelWiseRMSNorm
17
- from .straight_through_encoder import Patchify
18
-
19
-
20
- @dataclass(frozen=True)
21
- class EncoderPosterior:
22
- """VP-parameterized diagonal Gaussian posterior.
23
-
24
- mean: Clean signal branch mu [B, bottleneck_dim, h, w]
25
- logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w]
26
- """
27
-
28
- mean: Tensor
29
- logsnr: Tensor
30
-
31
- @property
32
- def alpha(self) -> Tensor:
33
- """VP signal coefficient: sqrt(sigmoid(logsnr))."""
34
- return torch.sigmoid(self.logsnr).sqrt()
35
-
36
- @property
37
- def sigma(self) -> Tensor:
38
- """VP noise coefficient: sqrt(sigmoid(-logsnr))."""
39
- return torch.sigmoid(-self.logsnr).sqrt()
40
-
41
- def mode(self) -> Tensor:
42
- """Posterior mode in token space: alpha * mean."""
43
- return self.alpha.to(dtype=self.mean.dtype) * self.mean
44
-
45
- def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
46
- """Sample from posterior: alpha * mean + sigma * eps."""
47
- eps = torch.randn_like(self.mean, generator=generator) # type: ignore[call-overload]
48
- alpha = self.alpha.to(dtype=self.mean.dtype)
49
- sigma = self.sigma.to(dtype=self.mean.dtype)
50
- return alpha * self.mean + sigma * eps
51
-
52
-
53
- class Encoder(nn.Module):
54
- """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].
55
-
56
- With diagonal_gaussian posterior, the to_bottleneck projection outputs
57
- 2 * bottleneck_dim channels, split into mean and logsnr. The default
58
- encode() returns the posterior mode: alpha * RMSNorm(mean).
59
- """
60
-
61
- def __init__(
62
- self,
63
- in_channels: int,
64
- patch_size: int,
65
- model_dim: int,
66
- depth: int,
67
- bottleneck_dim: int,
68
- mlp_ratio: float,
69
- depthwise_kernel_size: int,
70
- bottleneck_posterior_kind: str = "diagonal_gaussian",
71
- bottleneck_norm_mode: str = "disabled",
72
- ) -> None:
73
- super().__init__()
74
- self.bottleneck_dim = int(bottleneck_dim)
75
- self.bottleneck_posterior_kind = bottleneck_posterior_kind
76
- self.bottleneck_norm_mode = bottleneck_norm_mode
77
- self.patchify = Patchify(in_channels, patch_size, model_dim)
78
- self.blocks = nn.ModuleList(
79
- [
80
- FCDMBlock(
81
- model_dim,
82
- mlp_ratio,
83
- depthwise_kernel_size=depthwise_kernel_size,
84
- use_external_adaln=False,
85
- )
86
- for _ in range(depth)
87
- ]
88
- )
89
- out_dim = (
90
- 2 * bottleneck_dim
91
- if bottleneck_posterior_kind == "diagonal_gaussian"
92
- else bottleneck_dim
93
- )
94
- self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
95
- if bottleneck_norm_mode == "channel_wise":
96
- self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
97
- else:
98
- self.norm_out = nn.Identity()
99
-
100
- def encode_posterior(self, images: Tensor) -> EncoderPosterior:
101
- """Encode images and return the full posterior (mean + logsnr).
102
-
103
- Only valid when bottleneck_posterior_kind == "diagonal_gaussian".
104
- """
105
- z = self.patchify(images)
106
- for block in self.blocks:
107
- z = block(z)
108
- projection = self.to_bottleneck(z)
109
- mean, logsnr = projection.chunk(2, dim=1)
110
- mean = self.norm_out(mean)
111
- return EncoderPosterior(mean=mean, logsnr=logsnr)
112
-
113
- def forward(self, images: Tensor) -> Tensor:
114
- """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].
115
-
116
- Returns posterior mode (alpha * mean) for diagonal_gaussian,
117
- or deterministic latents otherwise.
118
- """
119
- z = self.patchify(images)
120
- for block in self.blocks:
121
- z = block(z)
122
- projection = self.to_bottleneck(z)
123
- if self.bottleneck_posterior_kind == "diagonal_gaussian":
124
- mean, logsnr = projection.chunk(2, dim=1)
125
- mean = self.norm_out(mean)
126
- alpha = torch.sigmoid(logsnr).sqrt().to(dtype=mean.dtype)
127
- return alpha * mean
128
- z = self.norm_out(projection)
129
- return z
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/fcdm_block.py DELETED
@@ -1,103 +0,0 @@
1
- """FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""
2
-
3
- from __future__ import annotations
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import Tensor, nn
8
-
9
- from .norms import ChannelWiseRMSNorm
10
-
11
-
12
- class GRN(nn.Module):
13
- """Global Response Normalization for NCHW tensors."""
14
-
15
- def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
16
- super().__init__()
17
- self.eps: float = float(eps)
18
- c = int(channels)
19
- self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
20
- self.beta = nn.Parameter(torch.zeros((1, c, 1, 1), dtype=torch.float32))
21
-
22
- def forward(self, x: Tensor) -> Tensor:
23
- g = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
24
- g_fp32 = g.to(dtype=torch.float32)
25
- n = (g_fp32 / (g_fp32.mean(dim=1, keepdim=True) + self.eps)).to(dtype=x.dtype)
26
- gamma = self.gamma.to(device=x.device, dtype=x.dtype)
27
- beta = self.beta.to(device=x.device, dtype=x.dtype)
28
- return gamma * (x * n) + beta + x
29
-
30
-
31
- class FCDMBlock(nn.Module):
32
- """ConvNeXt-style block with scale+gate AdaLN and GRN.
33
-
34
- Two modes:
35
- - Unconditioned (encoder): uses learned layer-scale for near-identity init.
36
- - External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate).
37
- The gate is applied raw (no tanh).
38
- """
39
-
40
- def __init__(
41
- self,
42
- channels: int,
43
- mlp_ratio: float,
44
- *,
45
- depthwise_kernel_size: int = 7,
46
- use_external_adaln: bool = False,
47
- norm_eps: float = 1e-6,
48
- layer_scale_init: float = 1e-3,
49
- ) -> None:
50
- super().__init__()
51
- self.channels: int = int(channels)
52
- self.mlp_ratio: float = float(mlp_ratio)
53
-
54
- self.dwconv = nn.Conv2d(
55
- channels,
56
- channels,
57
- kernel_size=depthwise_kernel_size,
58
- padding=depthwise_kernel_size // 2,
59
- stride=1,
60
- groups=channels,
61
- bias=True,
62
- )
63
- self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
64
- hidden = max(int(float(channels) * float(mlp_ratio)), 1)
65
- self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
66
- self.grn = GRN(hidden, eps=1e-6)
67
- self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)
68
-
69
- if not use_external_adaln:
70
- self.layer_scale = nn.Parameter(
71
- torch.full((channels,), float(layer_scale_init))
72
- )
73
- else:
74
- self.register_parameter("layer_scale", None)
75
-
76
- def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
77
- b, c, _, _ = x.shape
78
-
79
- if adaln_m is not None:
80
- m = adaln_m.to(device=x.device, dtype=x.dtype)
81
- scale, gate = m.chunk(2, dim=-1)
82
- else:
83
- scale = gate = None
84
-
85
- h = self.dwconv(x)
86
- h = self.norm(h)
87
-
88
- if scale is not None:
89
- h = h * (1.0 + scale.view(b, c, 1, 1))
90
-
91
- h = self.pwconv1(h)
92
- h = F.gelu(h)
93
- h = self.grn(h)
94
- h = self.pwconv2(h)
95
-
96
- if gate is not None:
97
- gate_view = gate.view(b, c, 1, 1)
98
- else:
99
- gate_view = self.layer_scale.view(1, c, 1, 1).to( # type: ignore[union-attr]
100
- device=h.device, dtype=h.dtype
101
- )
102
-
103
- return x + gate_view * h
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/model.py DELETED
@@ -1,364 +0,0 @@
1
- """CapacitorDiffAE: standalone HuggingFace-compatible diffusion autoencoder."""
2
-
3
- from __future__ import annotations
4
-
5
- from pathlib import Path
6
-
7
- import torch
8
- from torch import Tensor, nn
9
-
10
- from .config import CapacitorDiffAEConfig, CapacitorDiffAEInferenceConfig
11
- from .decoder import Decoder
12
- from .encoder import Encoder, EncoderPosterior
13
- from .samplers import run_ddim, run_dpmpp_2m
14
- from .vp_diffusion import get_schedule, make_initial_state, sample_noise
15
-
16
-
17
- def _resolve_model_dir(
18
- path_or_repo_id: str | Path,
19
- *,
20
- revision: str | None,
21
- cache_dir: str | Path | None,
22
- ) -> Path:
23
- """Resolve a local path or HuggingFace Hub repo ID to a local directory."""
24
- local = Path(path_or_repo_id)
25
- if local.is_dir():
26
- return local
27
- repo_id = str(path_or_repo_id)
28
- try:
29
- from huggingface_hub import snapshot_download
30
- except ImportError:
31
- raise ImportError(
32
- f"'{repo_id}' is not an existing local directory. "
33
- "To download from HuggingFace Hub, install huggingface_hub: "
34
- "pip install huggingface_hub"
35
- )
36
- cache_dir_str = str(cache_dir) if cache_dir is not None else None
37
- local_dir = snapshot_download(
38
- repo_id,
39
- revision=revision,
40
- cache_dir=cache_dir_str,
41
- )
42
- return Path(local_dir)
43
-
44
-
45
- class CapacitorDiffAE(nn.Module):
46
- """Standalone Capacitor DiffAE model for HuggingFace distribution.
47
-
48
- A diffusion autoencoder built on FCDM (Fully Convolutional Diffusion Model)
49
- blocks. Encodes images to compact 128-channel spatial latents via a
50
- VP-parameterized diagonal Gaussian posterior, and decodes them back via
51
- iterative VP diffusion with a skip-concat decoder.
52
-
53
- Usage::
54
-
55
- model = CapacitorDiffAE.from_pretrained("path/to/weights")
56
- model = model.to("cuda", dtype=torch.bfloat16)
57
-
58
- # Encode (returns posterior mode by default)
59
- latents = model.encode(images) # images: [B,3,H,W] in [-1,1]
60
-
61
- # Decode (1 step by default — PSNR-optimal)
62
- recon = model.decode(latents, height=H, width=W)
63
-
64
- # Reconstruct (encode + 1-step decode)
65
- recon = model.reconstruct(images)
66
- """
67
-
68
- _LATENT_NORM_EPS: float = 1e-4
69
-
70
- def __init__(self, config: CapacitorDiffAEConfig) -> None:
71
- super().__init__()
72
- self.config = config
73
-
74
- # Latent running stats for whitening/dewhitening
75
- self.register_buffer(
76
- "latent_norm_running_mean",
77
- torch.zeros((config.bottleneck_dim,), dtype=torch.float32),
78
- )
79
- self.register_buffer(
80
- "latent_norm_running_var",
81
- torch.ones((config.bottleneck_dim,), dtype=torch.float32),
82
- )
83
-
84
- self.encoder = Encoder(
85
- in_channels=config.in_channels,
86
- patch_size=config.patch_size,
87
- model_dim=config.model_dim,
88
- depth=config.encoder_depth,
89
- bottleneck_dim=config.bottleneck_dim,
90
- mlp_ratio=config.mlp_ratio,
91
- depthwise_kernel_size=config.depthwise_kernel_size,
92
- bottleneck_posterior_kind=config.bottleneck_posterior_kind,
93
- bottleneck_norm_mode=config.bottleneck_norm_mode,
94
- )
95
-
96
- self.decoder = Decoder(
97
- in_channels=config.in_channels,
98
- patch_size=config.patch_size,
99
- model_dim=config.model_dim,
100
- depth=config.decoder_depth,
101
- start_block_count=config.decoder_start_blocks,
102
- end_block_count=config.decoder_end_blocks,
103
- bottleneck_dim=config.bottleneck_dim,
104
- mlp_ratio=config.mlp_ratio,
105
- depthwise_kernel_size=config.depthwise_kernel_size,
106
- adaln_low_rank_rank=config.adaln_low_rank_rank,
107
- )
108
-
109
- @classmethod
110
- def from_pretrained(
111
- cls,
112
- path_or_repo_id: str | Path,
113
- *,
114
- dtype: torch.dtype = torch.bfloat16,
115
- device: str | torch.device = "cpu",
116
- revision: str | None = None,
117
- cache_dir: str | Path | None = None,
118
- ) -> CapacitorDiffAE:
119
- """Load a pretrained model from a local directory or HuggingFace Hub.
120
-
121
- The directory (or repo) should contain:
122
- - config.json: Model architecture config.
123
- - model.safetensors (preferred) or model.pt: Model weights.
124
-
125
- Args:
126
- path_or_repo_id: Local directory path or HuggingFace Hub repo ID.
127
- dtype: Load weights in this dtype (float32 or bfloat16).
128
- device: Target device.
129
- revision: Git revision for Hub downloads.
130
- cache_dir: Where to cache Hub downloads.
131
-
132
- Returns:
133
- Loaded model in eval mode.
134
- """
135
- model_dir = _resolve_model_dir(
136
- path_or_repo_id, revision=revision, cache_dir=cache_dir
137
- )
138
- config = CapacitorDiffAEConfig.load(model_dir / "config.json")
139
- model = cls(config)
140
-
141
- safetensors_path = model_dir / "model.safetensors"
142
- pt_path = model_dir / "model.pt"
143
-
144
- if safetensors_path.exists():
145
- try:
146
- from safetensors.torch import load_file
147
-
148
- state_dict = load_file(str(safetensors_path), device=str(device))
149
- except ImportError:
150
- raise ImportError(
151
- "safetensors package required to load .safetensors files. "
152
- "Install with: pip install safetensors"
153
- )
154
- elif pt_path.exists():
155
- state_dict = torch.load(
156
- str(pt_path), map_location=device, weights_only=True
157
- )
158
- else:
159
- raise FileNotFoundError(
160
- f"No model weights found in {model_dir}. "
161
- "Expected model.safetensors or model.pt."
162
- )
163
-
164
- model.load_state_dict(state_dict)
165
- model = model.to(dtype=dtype, device=torch.device(device))
166
- model.eval()
167
- return model
168
-
169
- def _latent_norm_stats(self) -> tuple[Tensor, Tensor]:
170
- """Return (mean, std) tensors for latent whitening, shaped [1,C,1,1]."""
171
- mean = self.latent_norm_running_mean.view(1, -1, 1, 1)
172
- var = self.latent_norm_running_var.view(1, -1, 1, 1)
173
- std = torch.sqrt(var.to(torch.float32) + self._LATENT_NORM_EPS)
174
- return mean.to(torch.float32), std
175
-
176
- def whiten(self, latents: Tensor) -> Tensor:
177
- """Whiten encoder latents using per-channel running stats.
178
-
179
- Use this before passing latents to a downstream latent-space
180
- diffusion model. The whitened latents have approximately zero mean
181
- and unit variance per channel.
182
-
183
- Args:
184
- latents: [B, bottleneck_dim, h, w] raw encoder output.
185
-
186
- Returns:
187
- Whitened latents [B, bottleneck_dim, h, w] in float32.
188
- """
189
- z = latents.to(torch.float32)
190
- mean, std = self._latent_norm_stats()
191
- return (z - mean.to(device=z.device)) / std.to(device=z.device)
192
-
193
- def dewhiten(self, latents: Tensor) -> Tensor:
194
- """Undo whitening to recover raw encoder latent scale.
195
-
196
- Use this before passing whitened latents back to ``decode()``.
197
-
198
- Args:
199
- latents: [B, bottleneck_dim, h, w] whitened latents.
200
-
201
- Returns:
202
- Dewhitened latents [B, bottleneck_dim, h, w] in float32.
203
- """
204
- z = latents.to(torch.float32)
205
- mean, std = self._latent_norm_stats()
206
- return z * std.to(device=z.device) + mean.to(device=z.device)
207
-
208
- def encode(self, images: Tensor) -> Tensor:
209
- """Encode images to whitened latents (posterior mode).
210
-
211
- Returns latents whitened using per-channel running stats, ready for
212
- use by downstream latent-space diffusion models.
213
-
214
- Args:
215
- images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
216
-
217
- Returns:
218
- Whitened latents [B, bottleneck_dim, H/patch, W/patch].
219
- """
220
- try:
221
- model_dtype = next(self.parameters()).dtype
222
- except StopIteration:
223
- model_dtype = torch.float32
224
- z = self.encoder(images.to(dtype=model_dtype))
225
- return self.whiten(z).to(dtype=model_dtype)
226
-
227
- def encode_posterior(self, images: Tensor) -> EncoderPosterior:
228
- """Encode images and return the full posterior (mean + logsnr).
229
-
230
- Args:
231
- images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
232
-
233
- Returns:
234
- EncoderPosterior with mean and logsnr tensors.
235
- """
236
- try:
237
- model_dtype = next(self.parameters()).dtype
238
- except StopIteration:
239
- model_dtype = torch.float32
240
- return self.encoder.encode_posterior(images.to(dtype=model_dtype))
241
-
242
- @torch.no_grad()
243
- def decode(
244
- self,
245
- latents: Tensor,
246
- height: int,
247
- width: int,
248
- *,
249
- inference_config: CapacitorDiffAEInferenceConfig | None = None,
250
- ) -> Tensor:
251
- """Decode whitened latents to images via VP diffusion.
252
-
253
- Latents are dewhitened internally before being passed to the decoder.
254
-
255
- Args:
256
- latents: [B, bottleneck_dim, h, w] whitened encoder latents.
257
- height: Output image height (divisible by patch_size).
258
- width: Output image width (divisible by patch_size).
259
- inference_config: Optional inference parameters.
260
-
261
- Returns:
262
- Reconstructed images [B, 3, H, W] in float32.
263
- """
264
- cfg = inference_config or CapacitorDiffAEInferenceConfig()
265
- config = self.config
266
- batch = int(latents.shape[0])
267
- device = latents.device
268
-
269
- try:
270
- model_dtype = next(self.parameters()).dtype
271
- except StopIteration:
272
- model_dtype = torch.float32
273
-
274
- # Dewhiten back to raw encoder scale for the decoder
275
- latents = self.dewhiten(latents).to(dtype=model_dtype)
276
-
277
- if height % config.patch_size != 0 or width % config.patch_size != 0:
278
- raise ValueError(
279
- f"height={height} and width={width} must be divisible by "
280
- f"patch_size={config.patch_size}"
281
- )
282
-
283
- shape = (batch, config.in_channels, height, width)
284
- noise = sample_noise(
285
- shape,
286
- noise_std=config.pixel_noise_std,
287
- seed=cfg.seed,
288
- device=torch.device("cpu"),
289
- dtype=torch.float32,
290
- )
291
-
292
- schedule = get_schedule(cfg.schedule, cfg.num_steps).to(device=device)
293
- initial_state = make_initial_state(
294
- noise=noise.to(device=device),
295
- t_start=schedule[0:1],
296
- logsnr_min=config.logsnr_min,
297
- logsnr_max=config.logsnr_max,
298
- )
299
-
300
- device_type = "cuda" if device.type == "cuda" else "cpu"
301
- with torch.autocast(device_type=device_type, enabled=False):
302
- latents_in = latents.to(device=device)
303
-
304
- def _forward_fn(
305
- x_t: Tensor,
306
- t: Tensor,
307
- latents: Tensor,
308
- *,
309
- drop_middle_blocks: bool = False,
310
- mask_latent_tokens: bool = False,
311
- ) -> Tensor:
312
- return self.decoder(
313
- x_t.to(dtype=model_dtype),
314
- t,
315
- latents.to(dtype=model_dtype),
316
- drop_middle_blocks=drop_middle_blocks,
317
- )
318
-
319
- pdg_mode = "path_drop" if cfg.pdg else "disabled"
320
-
321
- if cfg.sampler == "ddim":
322
- sampler_fn = run_ddim
323
- elif cfg.sampler == "dpmpp_2m":
324
- sampler_fn = run_dpmpp_2m
325
- else:
326
- raise ValueError(
327
- f"Unsupported sampler: {cfg.sampler!r}. Use 'ddim' or 'dpmpp_2m'."
328
- )
329
-
330
- result = sampler_fn(
331
- forward_fn=_forward_fn,
332
- initial_state=initial_state,
333
- schedule=schedule,
334
- latents=latents_in,
335
- logsnr_min=config.logsnr_min,
336
- logsnr_max=config.logsnr_max,
337
- pdg_mode=pdg_mode,
338
- pdg_strength=cfg.pdg_strength,
339
- device=device,
340
- )
341
-
342
- return result
343
-
344
- @torch.no_grad()
345
- def reconstruct(
346
- self,
347
- images: Tensor,
348
- *,
349
- inference_config: CapacitorDiffAEInferenceConfig | None = None,
350
- ) -> Tensor:
351
- """Encode then decode. Convenience wrapper.
352
-
353
- Args:
354
- images: [B, 3, H, W] in [-1, 1].
355
- inference_config: Optional inference parameters.
356
-
357
- Returns:
358
- Reconstructed images [B, 3, H, W] in float32.
359
- """
360
- latents = self.encode(images)
361
- _, _, h, w = images.shape
362
- return self.decode(
363
- latents, height=h, width=w, inference_config=inference_config
364
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/norms.py DELETED
@@ -1,39 +0,0 @@
1
- """Channel-wise RMSNorm for NCHW tensors."""
2
-
3
- from __future__ import annotations
4
-
5
- import torch
6
- from torch import Tensor, nn
7
-
8
-
9
- class ChannelWiseRMSNorm(nn.Module):
10
- """Channel-wise RMSNorm with float32 reduction for numerical stability.
11
-
12
- Normalizes across channels per spatial position. Supports optional
13
- per-channel affine weight and bias.
14
- """
15
-
16
- def __init__(self, channels: int, eps: float = 1e-6, affine: bool = True) -> None:
17
- super().__init__()
18
- self.channels: int = int(channels)
19
- self._eps: float = float(eps)
20
- if affine:
21
- self.weight = nn.Parameter(torch.ones(self.channels))
22
- self.bias = nn.Parameter(torch.zeros(self.channels))
23
- else:
24
- self.register_parameter("weight", None)
25
- self.register_parameter("bias", None)
26
-
27
- def forward(self, x: Tensor) -> Tensor:
28
- if x.dim() < 2:
29
- return x
30
- # Float32 accumulation for stability
31
- ms = torch.mean(torch.square(x), dim=1, keepdim=True, dtype=torch.float32)
32
- inv_rms = torch.rsqrt(ms + self._eps)
33
- y = x * inv_rms
34
- if self.weight is not None:
35
- shape = (1, -1) + (1,) * (x.dim() - 2)
36
- y = y * self.weight.view(shape).to(dtype=y.dtype)
37
- if self.bias is not None:
38
- y = y + self.bias.view(shape).to(dtype=y.dtype)
39
- return y.to(dtype=x.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/samplers.py DELETED
@@ -1,263 +0,0 @@
1
- """DDIM and DPM++2M samplers for VP diffusion with dual PDG support."""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import Protocol
6
-
7
- import torch
8
- from torch import Tensor
9
-
10
- from .vp_diffusion import (
11
- alpha_sigma_from_logsnr,
12
- broadcast_time_like,
13
- shifted_cosine_interpolated_logsnr_from_t,
14
- )
15
-
16
-
17
- class DecoderForwardFn(Protocol):
18
- """Callable that predicts x0 from (x_t, t, latents) with dual PDG flags."""
19
-
20
- def __call__(
21
- self,
22
- x_t: Tensor,
23
- t: Tensor,
24
- latents: Tensor,
25
- *,
26
- drop_middle_blocks: bool = False,
27
- mask_latent_tokens: bool = False,
28
- ) -> Tensor: ...
29
-
30
-
31
- def _reconstruct_eps_from_x0(
32
- *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
33
- ) -> Tensor:
34
- """Reconstruct eps_hat from (x_t, x0_hat) under VP parameterization.
35
-
36
- eps_hat = (x_t - alpha * x0_hat) / sigma. All float32.
37
- """
38
- alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
39
- sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
40
- x_t_f32 = x_t.to(torch.float32)
41
- x0_f32 = x0_hat.to(torch.float32)
42
- return (x_t_f32 - alpha_view * x0_f32) / sigma_view
43
-
44
-
45
- def _ddim_step(
46
- *,
47
- x0_hat: Tensor,
48
- eps_hat: Tensor,
49
- alpha_next: Tensor,
50
- sigma_next: Tensor,
51
- ref: Tensor,
52
- ) -> Tensor:
53
- """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat."""
54
- a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
55
- s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
56
- return a * x0_hat + s * eps_hat
57
-
58
-
59
- def _predict_with_pdg(
60
- forward_fn: DecoderForwardFn,
61
- state: Tensor,
62
- t_vec: Tensor,
63
- latents: Tensor,
64
- *,
65
- pdg_mode: str,
66
- pdg_strength: float,
67
- ) -> Tensor:
68
- """Run decoder forward with optional PDG guidance.
69
-
70
- Args:
71
- forward_fn: Decoder forward function.
72
- state: Current noised state [B, C, H, W].
73
- t_vec: Timestep vector [B].
74
- latents: Encoder latents.
75
- pdg_mode: "disabled", "path_drop", or "token_mask".
76
- pdg_strength: CFG-like strength for PDG.
77
-
78
- Returns:
79
- x0_hat prediction in float32.
80
- """
81
- if pdg_mode == "path_drop":
82
- x0_uncond = forward_fn(state, t_vec, latents, drop_middle_blocks=True).to(
83
- torch.float32
84
- )
85
- x0_cond = forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
86
- torch.float32
87
- )
88
- return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
89
- elif pdg_mode == "token_mask":
90
- x0_uncond = forward_fn(state, t_vec, latents, mask_latent_tokens=True).to(
91
- torch.float32
92
- )
93
- x0_cond = forward_fn(state, t_vec, latents, mask_latent_tokens=False).to(
94
- torch.float32
95
- )
96
- return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
97
- else:
98
- return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
99
- torch.float32
100
- )
101
-
102
-
103
- def run_ddim(
104
- *,
105
- forward_fn: DecoderForwardFn,
106
- initial_state: Tensor,
107
- schedule: Tensor,
108
- latents: Tensor,
109
- logsnr_min: float,
110
- logsnr_max: float,
111
- log_change_high: float = 0.0,
112
- log_change_low: float = 0.0,
113
- pdg_mode: str = "disabled",
114
- pdg_strength: float = 1.5,
115
- device: torch.device | None = None,
116
- ) -> Tensor:
117
- """Run DDIM sampling loop with dual PDG support.
118
-
119
- Args:
120
- forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
121
- initial_state: Starting noised state [B, C, H, W] in float32.
122
- schedule: Descending t-schedule [num_steps] in [0, 1].
123
- latents: Encoder latents [B, bottleneck_dim, h, w].
124
- logsnr_min, logsnr_max: VP schedule endpoints.
125
- log_change_high, log_change_low: Shifted-cosine schedule parameters.
126
- pdg_mode: "disabled", "path_drop", or "token_mask".
127
- pdg_strength: CFG-like strength for PDG.
128
- device: Target device.
129
-
130
- Returns:
131
- Denoised samples [B, C, H, W] in float32.
132
- """
133
- run_device = device or initial_state.device
134
- batch_size = int(initial_state.shape[0])
135
- state = initial_state.to(device=run_device, dtype=torch.float32)
136
-
137
- # Precompute logSNR, alpha, sigma for all schedule points
138
- lmb = shifted_cosine_interpolated_logsnr_from_t(
139
- schedule.to(device=run_device),
140
- logsnr_min=logsnr_min,
141
- logsnr_max=logsnr_max,
142
- log_change_high=log_change_high,
143
- log_change_low=log_change_low,
144
- )
145
- alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
146
-
147
- for i in range(int(schedule.numel()) - 1):
148
- t_i = schedule[i]
149
- a_t = alpha_sched[i].expand(batch_size)
150
- s_t = sigma_sched[i].expand(batch_size)
151
- a_next = alpha_sched[i + 1].expand(batch_size)
152
- s_next = sigma_sched[i + 1].expand(batch_size)
153
-
154
- # Model prediction with optional PDG
155
- t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
156
- x0_hat = _predict_with_pdg(
157
- forward_fn,
158
- state,
159
- t_vec,
160
- latents,
161
- pdg_mode=pdg_mode,
162
- pdg_strength=pdg_strength,
163
- )
164
-
165
- eps_hat = _reconstruct_eps_from_x0(
166
- x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t
167
- )
168
- state = _ddim_step(
169
- x0_hat=x0_hat,
170
- eps_hat=eps_hat,
171
- alpha_next=a_next,
172
- sigma_next=s_next,
173
- ref=state,
174
- )
175
-
176
- return state
177
-
178
-
179
- def run_dpmpp_2m(
180
- *,
181
- forward_fn: DecoderForwardFn,
182
- initial_state: Tensor,
183
- schedule: Tensor,
184
- latents: Tensor,
185
- logsnr_min: float,
186
- logsnr_max: float,
187
- log_change_high: float = 0.0,
188
- log_change_low: float = 0.0,
189
- pdg_mode: str = "disabled",
190
- pdg_strength: float = 1.5,
191
- device: torch.device | None = None,
192
- ) -> Tensor:
193
- """Run DPM++2M sampling loop with dual PDG support.
194
-
195
- Multi-step solver using exponential integrator formulation in half-lambda space.
196
- """
197
- run_device = device or initial_state.device
198
- batch_size = int(initial_state.shape[0])
199
- state = initial_state.to(device=run_device, dtype=torch.float32)
200
-
201
- # Precompute logSNR, alpha, sigma, half-lambda for all schedule points
202
- lmb = shifted_cosine_interpolated_logsnr_from_t(
203
- schedule.to(device=run_device),
204
- logsnr_min=logsnr_min,
205
- logsnr_max=logsnr_max,
206
- log_change_high=log_change_high,
207
- log_change_low=log_change_low,
208
- )
209
- alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
210
- half_lambda = 0.5 * lmb.to(torch.float32)
211
-
212
- x0_prev: Tensor | None = None
213
-
214
- for i in range(int(schedule.numel()) - 1):
215
- t_i = schedule[i]
216
- s_t = sigma_sched[i].expand(batch_size)
217
- a_next = alpha_sched[i + 1].expand(batch_size)
218
- s_next = sigma_sched[i + 1].expand(batch_size)
219
-
220
- # Model prediction with optional PDG
221
- t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
222
- x0_hat = _predict_with_pdg(
223
- forward_fn,
224
- state,
225
- t_vec,
226
- latents,
227
- pdg_mode=pdg_mode,
228
- pdg_strength=pdg_strength,
229
- )
230
-
231
- lam_t = half_lambda[i].expand(batch_size)
232
- lam_next = half_lambda[i + 1].expand(batch_size)
233
- h = (lam_next - lam_t).to(torch.float32)
234
- phi_1 = torch.expm1(-h)
235
-
236
- sigma_ratio = (s_next / s_t).to(torch.float32)
237
-
238
- if i == 0 or x0_prev is None:
239
- # First-order step
240
- state = (
241
- sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
242
- - broadcast_time_like(a_next, state).to(torch.float32)
243
- * broadcast_time_like(phi_1, state).to(torch.float32)
244
- * x0_hat
245
- )
246
- else:
247
- # Second-order step
248
- lam_prev = half_lambda[i - 1].expand(batch_size)
249
- h_0 = (lam_t - lam_prev).to(torch.float32)
250
- r0 = h_0 / h
251
- d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
252
- common = broadcast_time_like(a_next, state).to(
253
- torch.float32
254
- ) * broadcast_time_like(phi_1, state).to(torch.float32)
255
- state = (
256
- sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
257
- - common * x0_hat
258
- - 0.5 * common * d1_0
259
- )
260
-
261
- x0_prev = x0_hat
262
-
263
- return state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/straight_through_encoder.py DELETED
@@ -1,27 +0,0 @@
1
- """PixelUnshuffle-based patchifier (no residual conv path)."""
2
-
3
- from __future__ import annotations
4
-
5
- from torch import Tensor, nn
6
-
7
-
8
- class Patchify(nn.Module):
9
- """PixelUnshuffle(patch) -> Conv2d 1x1 projection.
10
-
11
- Converts [B, C, H, W] images into [B, out_channels, H/patch, W/patch] features.
12
- """
13
-
14
- def __init__(self, in_channels: int, patch: int, out_channels: int) -> None:
15
- super().__init__()
16
- self.patch = int(patch)
17
- self.unshuffle = nn.PixelUnshuffle(self.patch)
18
- in_after = in_channels * (self.patch * self.patch)
19
- self.proj = nn.Conv2d(in_after, out_channels, kernel_size=1, bias=True)
20
-
21
- def forward(self, x: Tensor) -> Tensor:
22
- if x.shape[2] % self.patch != 0 or x.shape[3] % self.patch != 0:
23
- raise ValueError(
24
- f"Input H={x.shape[2]} and W={x.shape[3]} must be divisible by patch={self.patch}"
25
- )
26
- y = self.unshuffle(x)
27
- return self.proj(y)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/time_embed.py DELETED
@@ -1,83 +0,0 @@
1
- """Sinusoidal timestep embedding with MLP projection."""
2
-
3
- from __future__ import annotations
4
-
5
- import math
6
-
7
- import torch
8
- from torch import Tensor, nn
9
-
10
-
11
- def _log_spaced_frequencies(
12
- half: int, max_period: float, *, device: torch.device | None = None
13
- ) -> Tensor:
14
- """Log-spaced frequencies for sinusoidal embedding."""
15
- return torch.exp(
16
- -math.log(max_period)
17
- * torch.arange(half, device=device, dtype=torch.float32)
18
- / max(float(half - 1), 1.0)
19
- )
20
-
21
-
22
- def sinusoidal_time_embedding(
23
- t: Tensor,
24
- dim: int,
25
- *,
26
- max_period: float = 10000.0,
27
- scale: float | None = None,
28
- freqs: Tensor | None = None,
29
- ) -> Tensor:
30
- """Sinusoidal timestep embedding (DDPM/DiT-style). Always float32."""
31
- t32 = t.to(torch.float32)
32
- if scale is not None:
33
- t32 = t32 * float(scale)
34
- half = dim // 2
35
- if freqs is not None:
36
- freqs = freqs.to(device=t32.device, dtype=torch.float32)
37
- else:
38
- freqs = _log_spaced_frequencies(half, max_period, device=t32.device)
39
- angles = t32[:, None] * freqs[None, :]
40
- return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
41
-
42
-
43
- class SinusoidalTimeEmbeddingMLP(nn.Module):
44
- """Sinusoidal time embedding followed by Linear -> SiLU -> Linear."""
45
-
46
- def __init__(
47
- self,
48
- dim: int,
49
- *,
50
- freq_dim: int = 256,
51
- hidden_mult: float = 1.0,
52
- time_scale: float = 1000.0,
53
- max_period: float = 10000.0,
54
- ) -> None:
55
- super().__init__()
56
- self.dim = int(dim)
57
- self.freq_dim = int(freq_dim)
58
- self.time_scale = float(time_scale)
59
- self.max_period = float(max_period)
60
- hidden_dim = max(int(round(int(dim) * float(hidden_mult))), 1)
61
-
62
- freqs = _log_spaced_frequencies(self.freq_dim // 2, self.max_period)
63
- self.register_buffer("freqs", freqs, persistent=True)
64
-
65
- self.proj_in = nn.Linear(self.freq_dim, hidden_dim)
66
- self.act = nn.SiLU()
67
- self.proj_out = nn.Linear(hidden_dim, self.dim)
68
-
69
- def forward(self, t: Tensor) -> Tensor:
70
- freqs: Tensor = self.freqs # type: ignore[assignment]
71
- emb_freq = sinusoidal_time_embedding(
72
- t.to(torch.float32),
73
- self.freq_dim,
74
- max_period=self.max_period,
75
- scale=self.time_scale,
76
- freqs=freqs,
77
- )
78
- dtype_in = self.proj_in.weight.dtype
79
- hidden = self.proj_in(emb_freq.to(dtype_in))
80
- hidden = self.act(hidden)
81
- if hidden.dtype != self.proj_out.weight.dtype:
82
- hidden = hidden.to(self.proj_out.weight.dtype)
83
- return self.proj_out(hidden)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
capacitor_diffae/vp_diffusion.py DELETED
@@ -1,151 +0,0 @@
1
- """VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""
2
-
3
- from __future__ import annotations
4
-
5
- import math
6
-
7
- import torch
8
- from torch import Tensor
9
-
10
-
11
- def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
12
- """Compute (alpha, sigma) from logSNR in float32.
13
-
14
- VP constraint: alpha^2 + sigma^2 = 1.
15
- """
16
- lmb32 = lmb.to(dtype=torch.float32)
17
- alpha = torch.sqrt(torch.sigmoid(lmb32))
18
- sigma = torch.sqrt(torch.sigmoid(-lmb32))
19
- return alpha, sigma
20
-
21
-
22
- def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
23
- """Broadcast [B] coefficient to match x for per-sample scaling."""
24
- view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
25
- return coeff.view(view_shape)
26
-
27
-
28
- def _cosine_interpolated_params(
29
- logsnr_min: float, logsnr_max: float
30
- ) -> tuple[float, float]:
31
- """Compute (a, b) for cosine-interpolated logSNR schedule.
32
-
33
- logsnr(t) = -2 * log(tan(a*t + b))
34
- logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
35
- """
36
- b = math.atan(math.exp(-0.5 * logsnr_max))
37
- a = math.atan(math.exp(-0.5 * logsnr_min)) - b
38
- return a, b
39
-
40
-
41
- def cosine_interpolated_logsnr_from_t(
42
- t: Tensor, *, logsnr_min: float, logsnr_max: float
43
- ) -> Tensor:
44
- """Map t in [0,1] to logSNR via cosine-interpolated schedule. Always float32."""
45
- a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
46
- t32 = t.to(dtype=torch.float32)
47
- a_t = torch.tensor(a, device=t32.device, dtype=torch.float32)
48
- b_t = torch.tensor(b, device=t32.device, dtype=torch.float32)
49
- u = a_t * t32 + b_t
50
- return -2.0 * torch.log(torch.tan(u))
51
-
52
-
53
- def shifted_cosine_interpolated_logsnr_from_t(
54
- t: Tensor,
55
- *,
56
- logsnr_min: float,
57
- logsnr_max: float,
58
- log_change_high: float = 0.0,
59
- log_change_low: float = 0.0,
60
- ) -> Tensor:
61
- """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.
62
-
63
- lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
64
- """
65
- base = cosine_interpolated_logsnr_from_t(
66
- t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
67
- )
68
- t32 = t.to(dtype=torch.float32)
69
- high = base + float(log_change_high)
70
- low = base + float(log_change_low)
71
- return (1.0 - t32) * high + t32 * low
72
-
73
-
74
- def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
75
- """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.
76
-
77
- ``num_steps`` is the number of function evaluations (NFE = decoder forward
78
- passes). Internally the schedule has ``num_steps + 1`` time points
79
- (including both endpoints).
80
-
81
- Args:
82
- schedule_type: "linear" or "cosine".
83
- num_steps: Number of decoder forward passes (NFE), >= 1.
84
-
85
- Returns:
86
- Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
87
- """
88
- # NOTE: the upstream training code (src/ode/time_schedules.py) uses a
89
- # different convention where num_steps counts schedule *points* (so NFE =
90
- # num_steps - 1). This export package corrects the off-by-one so that
91
- # num_steps means NFE directly. TODO: align the upstream convention.
92
- n = max(int(num_steps) + 1, 2)
93
- if schedule_type == "linear":
94
- base = torch.linspace(0.0, 1.0, n)
95
- elif schedule_type == "cosine":
96
- i = torch.arange(n, dtype=torch.float32)
97
- base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
98
- else:
99
- raise ValueError(
100
- f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
101
- )
102
- # Descending: high t (noisy) -> low t (clean)
103
- return torch.flip(base, dims=[0])
104
-
105
-
106
- def make_initial_state(
107
- *,
108
- noise: Tensor,
109
- t_start: Tensor,
110
- logsnr_min: float,
111
- logsnr_max: float,
112
- log_change_high: float = 0.0,
113
- log_change_low: float = 0.0,
114
- ) -> Tensor:
115
- """Construct VP initial state x_t0 = sigma_start * noise (since x0=0).
116
-
117
- All math in float32.
118
- """
119
- batch = int(noise.shape[0])
120
- lmb_start = shifted_cosine_interpolated_logsnr_from_t(
121
- t_start.expand(batch).to(dtype=torch.float32),
122
- logsnr_min=logsnr_min,
123
- logsnr_max=logsnr_max,
124
- log_change_high=log_change_high,
125
- log_change_low=log_change_low,
126
- )
127
- _alpha_start, sigma_start = alpha_sigma_from_logsnr(lmb_start)
128
- sigma_view = broadcast_time_like(sigma_start, noise)
129
- return sigma_view * noise.to(dtype=torch.float32)
130
-
131
-
132
- def sample_noise(
133
- shape: tuple[int, ...],
134
- *,
135
- noise_std: float = 1.0,
136
- seed: int | None = None,
137
- device: torch.device | None = None,
138
- dtype: torch.dtype = torch.float32,
139
- ) -> Tensor:
140
- """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
141
- if seed is None:
142
- noise = torch.randn(
143
- shape, device=device or torch.device("cpu"), dtype=torch.float32
144
- )
145
- else:
146
- gen = torch.Generator(device="cpu")
147
- gen.manual_seed(int(seed))
148
- noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
149
- noise = noise.mul(float(noise_std))
150
- target_device = device if device is not None else torch.device("cpu")
151
- return noise.to(device=target_device, dtype=dtype)