jiuhai commited on
Commit
4e1e978
·
verified ·
1 Parent(s): 4cd1d55

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. packages/ltx-core/src/ltx_core/__init__.py +0 -0
  2. packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-312.pyc +0 -0
  3. packages/ltx-core/src/ltx_core/__pycache__/types.cpython-312.pyc +0 -0
  4. packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-312.pyc +0 -0
  5. packages/ltx-core/src/ltx_core/components/__init__.py +10 -0
  6. packages/ltx-core/src/ltx_core/components/diffusion_steps.py +95 -0
  7. packages/ltx-core/src/ltx_core/components/guiders.py +364 -0
  8. packages/ltx-core/src/ltx_core/components/noisers.py +35 -0
  9. packages/ltx-core/src/ltx_core/components/patchifiers.py +348 -0
  10. packages/ltx-core/src/ltx_core/components/protocols.py +101 -0
  11. packages/ltx-core/src/ltx_core/components/schedulers.py +130 -0
  12. packages/ltx-core/src/ltx_core/conditioning/__init__.py +19 -0
  13. packages/ltx-core/src/ltx_core/conditioning/exceptions.py +4 -0
  14. packages/ltx-core/src/ltx_core/conditioning/item.py +20 -0
  15. packages/ltx-core/src/ltx_core/conditioning/mask_utils.py +210 -0
  16. packages/ltx-core/src/ltx_core/guidance/__init__.py +15 -0
  17. packages/ltx-core/src/ltx_core/guidance/perturbations.py +79 -0
  18. packages/ltx-core/src/ltx_core/loader/__init__.py +48 -0
  19. packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-312.pyc +0 -0
  20. packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-312.pyc +0 -0
  21. packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-312.pyc +0 -0
  22. packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-312.pyc +0 -0
  23. packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-312.pyc +0 -0
  24. packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-312.pyc +0 -0
  25. packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-312.pyc +0 -0
  26. packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-312.pyc +0 -0
  27. packages/ltx-core/src/ltx_core/loader/fuse_loras.py +153 -0
  28. packages/ltx-core/src/ltx_core/loader/kernels.py +72 -0
  29. packages/ltx-core/src/ltx_core/loader/module_ops.py +14 -0
  30. packages/ltx-core/src/ltx_core/loader/primitives.py +109 -0
  31. packages/ltx-core/src/ltx_core/loader/registry.py +84 -0
  32. packages/ltx-core/src/ltx_core/loader/sd_ops.py +127 -0
  33. packages/ltx-core/src/ltx_core/loader/sft_loader.py +66 -0
  34. packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +116 -0
  35. packages/ltx-core/src/ltx_core/model/__init__.py +8 -0
  36. packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-312.pyc +0 -0
  37. packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-312.pyc +0 -0
  38. packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py +29 -0
  39. packages/ltx-core/src/ltx_core/model/audio_vae/attention.py +71 -0
  40. packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py +508 -0
  41. packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py +110 -0
  42. packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py +110 -0
  43. packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py +200 -0
  44. packages/ltx-core/src/ltx_core/model/audio_vae/ops.py +73 -0
  45. packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py +176 -0
  46. packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py +106 -0
  47. packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py +575 -0
  48. packages/ltx-core/src/ltx_core/model/model_protocol.py +10 -0
  49. packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py +15 -0
  50. packages/ltx-core/src/ltx_core/model/upsampler/__init__.py +10 -0
packages/ltx-core/src/ltx_core/__init__.py ADDED
File without changes
packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (170 Bytes). View file
 
packages/ltx-core/src/ltx_core/__pycache__/types.cpython-312.pyc ADDED
Binary file (10 kB). View file
 
packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.68 kB). View file
 
packages/ltx-core/src/ltx_core/components/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Diffusion pipeline components.
3
+ Submodules:
4
+ diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep)
5
+ guiders - Guidance strategies (CFGGuider, STGGuider, APG variants)
6
+ noisers - Noise samplers (GaussianNoiser)
7
+ patchifiers - Latent patchification (VideoLatentPatchifier, AudioPatchifier)
8
+ protocols - Protocol definitions (Patchifier, etc.)
9
+ schedulers - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler)
10
+ """
packages/ltx-core/src/ltx_core/components/diffusion_steps.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.components.protocols import DiffusionStepProtocol
4
+ from ltx_core.utils import to_velocity
5
+
6
+
7
+ class EulerDiffusionStep(DiffusionStepProtocol):
8
+ """
9
+ First-order Euler method for diffusion sampling.
10
+ Takes a single step from the current noise level (sigma) to the next by
11
+ computing velocity from the denoised prediction and applying: sample + velocity * dt.
12
+ """
13
+
14
+ def step(
15
+ self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **_kwargs
16
+ ) -> torch.Tensor:
17
+ sigma = sigmas[step_index]
18
+ sigma_next = sigmas[step_index + 1]
19
+ dt = sigma_next - sigma
20
+ velocity = to_velocity(sample, sigma, denoised_sample)
21
+
22
+ return (sample.to(torch.float32) + velocity.to(torch.float32) * dt).to(sample.dtype)
23
+
24
+
25
+ class Res2sDiffusionStep(DiffusionStepProtocol):
26
+ """
27
+ Second-order diffusion step for res_2s sampling with SDE noise injection.
28
+ Used by the res_2s denoising loop. Advances the sample from the current
29
+ sigma to the next by mixing a deterministic update (from the denoised
30
+ prediction) with injected noise via ``get_sde_coeff``, producing
31
+ variance-preserving transitions.
32
+ """
33
+
34
+ @staticmethod
35
+ def get_sde_coeff(
36
+ sigma_next: torch.Tensor,
37
+ sigma_up: torch.Tensor | None = None,
38
+ sigma_down: torch.Tensor | None = None,
39
+ sigma_max: torch.Tensor | None = None,
40
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
41
+ """
42
+ Compute SDE coefficients (alpha_ratio, sigma_down, sigma_up) for the step.
43
+ Given either ``sigma_down`` or ``sigma_up``, returns the mixing
44
+ coefficients used for variance-preserving noise injection. If
45
+ ``sigma_up`` is provided, ``sigma_down`` and ``alpha_ratio`` are
46
+ derived; if ``sigma_down`` is provided, ``sigma_up`` and
47
+ ``alpha_ratio`` are derived.
48
+ """
49
+ if sigma_down is not None:
50
+ alpha_ratio = (1 - sigma_next) / (1 - sigma_down)
51
+ sigma_up = (sigma_next**2 - sigma_down**2 * alpha_ratio**2).clamp(min=0) ** 0.5
52
+ elif sigma_up is not None:
53
+ # Fallback to avoid sqrt(neg_num)
54
+ sigma_up.clamp_(max=sigma_next * 0.9999)
55
+ sigmax = sigma_max if sigma_max is not None else torch.ones_like(sigma_next)
56
+ sigma_signal = sigmax - sigma_next
57
+ sigma_residual = (sigma_next**2 - sigma_up**2).clamp(min=0) ** 0.5
58
+ alpha_ratio = sigma_signal + sigma_residual
59
+ sigma_down = sigma_residual / alpha_ratio
60
+ else:
61
+ alpha_ratio = torch.ones_like(sigma_next)
62
+ sigma_down = sigma_next
63
+ sigma_up = torch.zeros_like(sigma_next)
64
+
65
+ sigma_up = torch.nan_to_num(sigma_up if sigma_up is not None else torch.zeros_like(sigma_next), 0.0)
66
+ # Replace NaNs in sigma_down with corresponding sigma_next elements (float32)
67
+ nan_mask = torch.isnan(sigma_down)
68
+ sigma_down[nan_mask] = sigma_next[nan_mask].to(sigma_down.dtype)
69
+ alpha_ratio = torch.nan_to_num(alpha_ratio, 1.0)
70
+
71
+ return alpha_ratio, sigma_down, sigma_up
72
+
73
+ def step(
74
+ self,
75
+ sample: torch.Tensor,
76
+ denoised_sample: torch.Tensor,
77
+ sigmas: torch.Tensor,
78
+ step_index: int,
79
+ noise: torch.Tensor,
80
+ ) -> torch.Tensor:
81
+ """Advance one step with SDE noise injection via get_sde_coeff."""
82
+ sigma = sigmas[step_index]
83
+ sigma_next = sigmas[step_index + 1]
84
+ alpha_ratio, sigma_down, sigma_up = self.get_sde_coeff(sigma_next, sigma_up=sigma_next * 0.5)
85
+ output_dtype = denoised_sample.dtype
86
+ if torch.any(sigma_up == 0) or torch.any(sigma_next == 0):
87
+ return denoised_sample
88
+
89
+ # Extract epsilon prediction
90
+ eps_next = (sample - denoised_sample) / (sigma - sigma_next)
91
+ denoised_next = sample - sigma * eps_next
92
+
93
+ # Mix deterministic and stochastic components
94
+ x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise
95
+ return x_noised.to(output_dtype)
packages/ltx-core/src/ltx_core/components/guiders.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections.abc import Mapping, Sequence
3
+ from dataclasses import dataclass, field
4
+
5
+ import torch
6
+
7
+ from ltx_core.components.protocols import GuiderProtocol
8
+
9
+
10
+ @dataclass(frozen=True)
11
+ class CFGGuider(GuiderProtocol):
12
+ """
13
+ Classifier-free guidance (CFG) guider.
14
+ Computes the guidance delta as (scale - 1) * (cond - uncond), steering the
15
+ denoising process toward the conditioned prediction.
16
+ Attributes:
17
+ scale: Guidance strength. 1.0 means no guidance, higher values increase
18
+ adherence to the conditioning.
19
+ """
20
+
21
+ scale: float
22
+
23
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
24
+ return (self.scale - 1) * (cond - uncond)
25
+
26
+ def enabled(self) -> bool:
27
+ return self.scale != 1.0
28
+
29
+
30
+ @dataclass(frozen=True)
31
+ class CFGStarRescalingGuider(GuiderProtocol):
32
+ """
33
+ Calculates the CFG delta between conditioned and unconditioned samples.
34
+ To minimize offset in the denoising direction and move mostly along the
35
+ conditioning axis within the distribution, the unconditioned sample is
36
+ rescaled in accordance with the norm of the conditioned sample.
37
+ Attributes:
38
+ scale (float):
39
+ Global guidance strength. A value of 1.0 corresponds to no extra
40
+ guidance beyond the base model prediction. Values > 1.0 increase
41
+ the influence of the conditioned sample relative to the
42
+ unconditioned one.
43
+ """
44
+
45
+ scale: float
46
+
47
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
48
+ rescaled_neg = projection_coef(cond, uncond) * uncond
49
+ return (self.scale - 1) * (cond - rescaled_neg)
50
+
51
+ def enabled(self) -> bool:
52
+ return self.scale != 1.0
53
+
54
+
55
+ @dataclass(frozen=True)
56
+ class STGGuider(GuiderProtocol):
57
+ """
58
+ Calculates the STG delta between conditioned and perturbed denoised samples.
59
+ Perturbed samples are the result of the denoising process with perturbations,
60
+ e.g. attentions acting as passthrough for certain layers and modalities.
61
+ Attributes:
62
+ scale (float):
63
+ Global strength of the STG guidance. A value of 0.0 disables the
64
+ guidance. Larger values increase the correction applied in the
65
+ direction of (pos_denoised - perturbed_denoised).
66
+ """
67
+
68
+ scale: float
69
+
70
+ def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor:
71
+ return self.scale * (pos_denoised - perturbed_denoised)
72
+
73
+ def enabled(self) -> bool:
74
+ return self.scale != 0.0
75
+
76
+
77
+ @dataclass(frozen=True)
78
+ class LtxAPGGuider(GuiderProtocol):
79
+ """
80
+ Calculates the APG (adaptive projected guidance) delta between conditioned
81
+ and unconditioned samples.
82
+ To minimize offset in the denoising direction and move mostly along the
83
+ conditioning axis within the distribution, the (cond - uncond) delta is
84
+ decomposed into components parallel and orthogonal to the conditioned
85
+ sample. The `eta` parameter weights the parallel component, while `scale`
86
+ is applied to the orthogonal component. Optionally, a norm threshold can
87
+ be used to suppress guidance when the magnitude of the correction is small.
88
+ Attributes:
89
+ scale (float):
90
+ Strength applied to the component of the guidance that is orthogonal
91
+ to the conditioned sample. Controls how aggressively we move in
92
+ directions that change semantics but stay consistent with the
93
+ conditioning manifold.
94
+ eta (float):
95
+ Weight of the component of the guidance that is parallel to the
96
+ conditioned sample. A value of 1.0 keeps the full parallel
97
+ component; values in [0, 1] attenuate it, and values > 1.0 amplify
98
+ motion along the conditioning direction.
99
+ norm_threshold (float):
100
+ Minimum L2 norm of the guidance delta below which the guidance
101
+ can be reduced or ignored (depending on implementation).
102
+ This is useful for avoiding noisy or unstable updates when the
103
+ guidance signal is very small.
104
+ """
105
+
106
+ scale: float
107
+ eta: float = 1.0
108
+ norm_threshold: float = 0.0
109
+
110
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
111
+ guidance = cond - uncond
112
+ if self.norm_threshold > 0:
113
+ ones = torch.ones_like(guidance)
114
+ guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
115
+ scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm)
116
+ guidance = guidance * scale_factor
117
+ proj_coeff = projection_coef(guidance, cond)
118
+ g_parallel = proj_coeff * cond
119
+ g_orth = guidance - g_parallel
120
+ g_apg = g_parallel * self.eta + g_orth
121
+
122
+ return g_apg * (self.scale - 1)
123
+
124
+ def enabled(self) -> bool:
125
+ return self.scale != 1.0
126
+
127
+
128
+ @dataclass(frozen=False)
129
+ class LegacyStatefulAPGGuider(GuiderProtocol):
130
+ """
131
+ Calculates the APG (adaptive projected guidance) delta between conditioned
132
+ and unconditioned samples.
133
+ To minimize offset in the denoising direction and move mostly along the
134
+ conditioning axis within the distribution, the (cond - uncond) delta is
135
+ decomposed into components parallel and orthogonal to the conditioned
136
+ sample. The `eta` parameter weights the parallel component, while `scale`
137
+ is applied to the orthogonal component. Optionally, a norm threshold can
138
+ be used to suppress guidance when the magnitude of the correction is small.
139
+ Attributes:
140
+ scale (float):
141
+ Strength applied to the component of the guidance that is orthogonal
142
+ to the conditioned sample. Controls how aggressively we move in
143
+ directions that change semantics but stay consistent with the
144
+ conditioning manifold.
145
+ eta (float):
146
+ Weight of the component of the guidance that is parallel to the
147
+ conditioned sample. A value of 1.0 keeps the full parallel
148
+ component; values in [0, 1] attenuate it, and values > 1.0 amplify
149
+ motion along the conditioning direction.
150
+ norm_threshold (float):
151
+ Minimum L2 norm of the guidance delta below which the guidance
152
+ can be reduced or ignored (depending on implementation).
153
+ This is useful for avoiding noisy or unstable updates when the
154
+ guidance signal is very small.
155
+ momentum (float):
156
+ Exponential moving-average coefficient for accumulating guidance
157
+ over time. running_avg = momentum * running_avg + guidance
158
+ """
159
+
160
+ scale: float
161
+ eta: float
162
+ norm_threshold: float = 5.0
163
+ momentum: float = 0.0
164
+ # it is user's responsibility not to use same APGGuider for several denoisings or different modalities
165
+ # in order not to share accumulated average across different denoisings or modalities
166
+ running_avg: torch.Tensor | None = None
167
+
168
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
169
+ guidance = cond - uncond
170
+ if self.momentum != 0:
171
+ if self.running_avg is None:
172
+ self.running_avg = guidance.clone()
173
+ else:
174
+ self.running_avg = self.momentum * self.running_avg + guidance
175
+ guidance = self.running_avg
176
+
177
+ if self.norm_threshold > 0:
178
+ ones = torch.ones_like(guidance)
179
+ guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
180
+ scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm)
181
+ guidance = guidance * scale_factor
182
+
183
+ proj_coeff = projection_coef(guidance, cond)
184
+ g_parallel = proj_coeff * cond
185
+ g_orth = guidance - g_parallel
186
+ g_apg = g_parallel * self.eta + g_orth
187
+
188
+ return g_apg * self.scale
189
+
190
+ def enabled(self) -> bool:
191
+ return self.scale != 0.0
192
+
193
+
194
+ @dataclass(frozen=True)
195
+ class MultiModalGuiderParams:
196
+ """
197
+ Parameters for the multi-modal guider.
198
+ """
199
+
200
+ cfg_scale: float = 1.0
201
+ "CFG (Classifier-free guidance) scale controlling how strongly the model adheres to the prompt."
202
+ stg_scale: float = 0.0
203
+ "STG (Spatio-Temporal Guidance) scale controls how strongly the model reacts to the perturbation of the modality."
204
+ stg_blocks: list[int] | None = field(default_factory=list)
205
+ "Which transformer blocks to perturb for STG."
206
+ rescale_scale: float = 0.0
207
+ "Rescale scale controlling how strongly the model rescales the modality after applying other guidance."
208
+ modality_scale: float = 1.0
209
+ "Modality scale controlling how strongly the model reacts to the perturbation of the modality."
210
+ skip_step: int = 0
211
+ "Skip step controlling how often the model skips the step."
212
+
213
+
214
+ def _params_for_sigma_from_sorted_dict(
215
+ sigma: float, params_by_sigma: Sequence[tuple[float, MultiModalGuiderParams]]
216
+ ) -> MultiModalGuiderParams:
217
+ """
218
+ Return params for the given sigma from a sorted (sigma_upper_bound -> params) structure.
219
+ Keys are sorted descending (bin upper bounds). Bin i is (key_{i+1}, key_i].
220
+ Get all keys >= sigma; use last in list (smallest such key = upper bound of bin containing sigma),
221
+ or last entry in the sequence if list is empty (sigma above max key).
222
+ """
223
+ if not params_by_sigma:
224
+ raise ValueError("params_by_sigma must be non-empty")
225
+ sigma = float(sigma)
226
+ keys_desc = [k for k, _ in params_by_sigma]
227
+ keys_ge_sigma = [k for k in keys_desc if k >= sigma]
228
+ # sigma above all keys: use first bin (max key)
229
+ key = keys_ge_sigma[-1] if keys_ge_sigma else keys_desc[0]
230
+ return next(p for k, p in params_by_sigma if k == key)
231
+
232
+
233
+ @dataclass(frozen=True)
234
+ class MultiModalGuider:
235
+ """
236
+ Multi-modal guider with constant params per instance.
237
+ For sigma-dependent params, use MultiModalGuiderFactory.build_from_sigma(sigma) to
238
+ obtain a guider for each step.
239
+ """
240
+
241
+ params: MultiModalGuiderParams
242
+ negative_context: torch.Tensor | None = None
243
+
244
+ def calculate(
245
+ self,
246
+ cond: torch.Tensor,
247
+ uncond_text: torch.Tensor | float,
248
+ uncond_perturbed: torch.Tensor | float,
249
+ uncond_modality: torch.Tensor | float,
250
+ ) -> torch.Tensor:
251
+ """
252
+ The guider calculates the guidance delta as (scale - 1) * (cond - uncond) for cfg and modality cfg,
253
+ and as scale * (cond - uncond) for stg, steering the denoising process away from the unconditioned
254
+ prediction.
255
+ """
256
+ pred = (
257
+ cond
258
+ + (self.params.cfg_scale - 1) * (cond - uncond_text)
259
+ + self.params.stg_scale * (cond - uncond_perturbed)
260
+ + (self.params.modality_scale - 1) * (cond - uncond_modality)
261
+ )
262
+
263
+ if self.params.rescale_scale != 0:
264
+ factor = cond.std() / pred.std()
265
+ factor = self.params.rescale_scale * factor + (1 - self.params.rescale_scale)
266
+ pred = pred * factor
267
+
268
+ return pred
269
+
270
+ def do_unconditional_generation(self) -> bool:
271
+ """Returns True if the guider is doing unconditional generation."""
272
+ return not math.isclose(self.params.cfg_scale, 1.0)
273
+
274
+ def do_perturbed_generation(self) -> bool:
275
+ """Returns True if the guider is doing perturbed generation."""
276
+ return not math.isclose(self.params.stg_scale, 0.0)
277
+
278
+ def do_isolated_modality_generation(self) -> bool:
279
+ """Returns True if the guider is doing isolated modality generation."""
280
+ return not math.isclose(self.params.modality_scale, 1.0)
281
+
282
+ def should_skip_step(self, step: int) -> bool:
283
+ """Returns True if the guider should skip the step."""
284
+ if self.params.skip_step == 0:
285
+ return False
286
+ return step % (self.params.skip_step + 1) != 0
287
+
288
+
289
+ @dataclass(frozen=True)
290
+ class MultiModalGuiderFactory:
291
+ """
292
+ Factory that creates a MultiModalGuider for a given sigma.
293
+ Single source of truth: _params_by_sigma (schedule). Use constant() for
294
+ one params for all sigma, from_dict() for sigma-binned params.
295
+ """
296
+
297
+ negative_context: torch.Tensor | None = None
298
+ _params_by_sigma: tuple[tuple[float, MultiModalGuiderParams], ...] = ()
299
+
300
+ @classmethod
301
+ def constant(
302
+ cls,
303
+ params: MultiModalGuiderParams,
304
+ negative_context: torch.Tensor | None = None,
305
+ ) -> "MultiModalGuiderFactory":
306
+ """Build a factory with constant params (same guider for all sigma)."""
307
+ return cls(
308
+ negative_context=negative_context,
309
+ _params_by_sigma=((float("inf"), params),),
310
+ )
311
+
312
+ @classmethod
313
+ def from_dict(
314
+ cls,
315
+ sigma_to_params: Mapping[float, MultiModalGuiderParams],
316
+ negative_context: torch.Tensor | None = None,
317
+ ) -> "MultiModalGuiderFactory":
318
+ """
319
+ Build a factory from a dict of sigma_value -> MultiModalGuiderParams.
320
+ Keys are sorted descending and used for bin lookup in params(sigma).
321
+ """
322
+ if not sigma_to_params:
323
+ raise ValueError("sigma_to_params must be non-empty")
324
+ sorted_items = tuple(sorted(sigma_to_params.items(), key=lambda x: x[0], reverse=True))
325
+ return cls(negative_context=negative_context, _params_by_sigma=sorted_items)
326
+
327
+ def params(self, sigma: float | torch.Tensor) -> MultiModalGuiderParams:
328
+ """Return params effective for the given sigma (getter; single source of truth)."""
329
+ sigma_val = float(sigma.item() if isinstance(sigma, torch.Tensor) else sigma)
330
+ return _params_for_sigma_from_sorted_dict(sigma_val, self._params_by_sigma)
331
+
332
+ def build_from_sigma(self, sigma: float | torch.Tensor) -> MultiModalGuider:
333
+ """Return a MultiModalGuider with params effective for the given sigma."""
334
+ return MultiModalGuider(
335
+ params=self.params(sigma),
336
+ negative_context=self.negative_context,
337
+ )
338
+
339
+
340
+ def create_multimodal_guider_factory(
341
+ params: MultiModalGuiderParams | MultiModalGuiderFactory,
342
+ negative_context: torch.Tensor | None = None,
343
+ ) -> MultiModalGuiderFactory:
344
+ """
345
+ Create or return a MultiModalGuiderFactory. Pass constant params for a
346
+ single-params factory (uses MultiModalGuiderFactory.constant), or an existing
347
+ MultiModalGuiderFactory. When given a factory, returns it as-is unless
348
+ negative_context is provided. For sigma-dependent params use
349
+ MultiModalGuiderFactory.from_dict(...) and pass that as params.
350
+ """
351
+ if isinstance(params, MultiModalGuiderFactory):
352
+ if negative_context is not None and params.negative_context is not negative_context:
353
+ return MultiModalGuiderFactory.from_dict(dict(params._params_by_sigma), negative_context=negative_context)
354
+ return params
355
+ return MultiModalGuiderFactory.constant(params, negative_context=negative_context)
356
+
357
+
358
+ def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor:
359
+ batch_size = to_project.shape[0]
360
+ positive_flat = to_project.reshape(batch_size, -1)
361
+ negative_flat = project_onto.reshape(batch_size, -1)
362
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
363
+ squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
364
+ return dot_product / squared_norm
packages/ltx-core/src/ltx_core/components/noisers.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import replace
2
+ from typing import Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.types import LatentState
7
+
8
+
9
+ class Noiser(Protocol):
10
+ """Protocol for adding noise to a latent state during diffusion."""
11
+
12
+ def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ...
13
+
14
+
15
+ class GaussianNoiser(Noiser):
16
+ """Adds Gaussian noise to a latent state, scaled by the denoise mask."""
17
+
18
+ def __init__(self, generator: torch.Generator):
19
+ super().__init__()
20
+
21
+ self.generator = generator
22
+
23
+ def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
24
+ noise = torch.randn(
25
+ *latent_state.latent.shape,
26
+ device=latent_state.latent.device,
27
+ dtype=latent_state.latent.dtype,
28
+ generator=self.generator,
29
+ )
30
+ scaled_mask = latent_state.denoise_mask * noise_scale
31
+ latent = noise * scaled_mask + latent_state.latent * (1 - scaled_mask)
32
+ return replace(
33
+ latent_state,
34
+ latent=latent.to(latent_state.latent.dtype),
35
+ )
packages/ltx-core/src/ltx_core/components/patchifiers.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import einops
5
+ import torch
6
+
7
+ from ltx_core.components.protocols import Patchifier
8
+ from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
9
+
10
+
11
+ class VideoLatentPatchifier(Patchifier):
12
+ def __init__(self, patch_size: int):
13
+ # Patch sizes for video latents.
14
+ self._patch_size = (
15
+ 1, # temporal dimension
16
+ patch_size, # height dimension
17
+ patch_size, # width dimension
18
+ )
19
+
20
+ @property
21
+ def patch_size(self) -> Tuple[int, int, int]:
22
+ return self._patch_size
23
+
24
+ def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
25
+ return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)
26
+
27
+ def patchify(
28
+ self,
29
+ latents: torch.Tensor,
30
+ ) -> torch.Tensor:
31
+ latents = einops.rearrange(
32
+ latents,
33
+ "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
34
+ p1=self._patch_size[0],
35
+ p2=self._patch_size[1],
36
+ p3=self._patch_size[2],
37
+ )
38
+
39
+ return latents
40
+
41
+ def unpatchify(
42
+ self,
43
+ latents: torch.Tensor,
44
+ output_shape: VideoLatentShape,
45
+ ) -> torch.Tensor:
46
+ assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"
47
+
48
+ patch_grid_frames = output_shape.frames // self._patch_size[0]
49
+ patch_grid_height = output_shape.height // self._patch_size[1]
50
+ patch_grid_width = output_shape.width // self._patch_size[2]
51
+
52
+ latents = einops.rearrange(
53
+ latents,
54
+ "b (f h w) (c p q) -> b c f (h p) (w q)",
55
+ f=patch_grid_frames,
56
+ h=patch_grid_height,
57
+ w=patch_grid_width,
58
+ p=self._patch_size[1],
59
+ q=self._patch_size[2],
60
+ )
61
+
62
+ return latents
63
+
64
+ def get_patch_grid_bounds(
65
+ self,
66
+ output_shape: AudioLatentShape | VideoLatentShape,
67
+ device: Optional[torch.device] = None,
68
+ ) -> torch.Tensor:
69
+ """
70
+ Return the per-dimension bounds [inclusive start, exclusive end) for every
71
+ patch produced by `patchify`. The bounds are expressed in the original
72
+ video grid coordinates: frame/time, height, and width.
73
+ The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
74
+ - axis 1 (size 3) enumerates (frame/time, height, width) dimensions
75
+ - axis 3 (size 2) stores `[start, end)` indices within each dimension
76
+ Args:
77
+ output_shape: Video grid description containing frames, height, and width.
78
+ device: Device of the latent tensor.
79
+ """
80
+ if not isinstance(output_shape, VideoLatentShape):
81
+ raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")
82
+
83
+ frames = output_shape.frames
84
+ height = output_shape.height
85
+ width = output_shape.width
86
+ batch_size = output_shape.batch
87
+
88
+ # Validate inputs to ensure positive dimensions
89
+ assert frames > 0, f"frames must be positive, got {frames}"
90
+ assert height > 0, f"height must be positive, got {height}"
91
+ assert width > 0, f"width must be positive, got {width}"
92
+ assert batch_size > 0, f"batch_size must be positive, got {batch_size}"
93
+
94
+ # Generate grid coordinates for each dimension (frame, height, width)
95
+ # We use torch.arange to create the starting coordinates for each patch.
96
+ # indexing='ij' ensures the dimensions are in the order (frame, height, width).
97
+ grid_coords = torch.meshgrid(
98
+ torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
99
+ torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
100
+ torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
101
+ indexing="ij",
102
+ )
103
+
104
+ # Stack the grid coordinates to create the start coordinates tensor.
105
+ # Shape becomes (3, grid_f, grid_h, grid_w)
106
+ patch_starts = torch.stack(grid_coords, dim=0)
107
+
108
+ # Create a tensor containing the size of a single patch:
109
+ # (frame_patch_size, height_patch_size, width_patch_size).
110
+ # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
111
+ patch_size_delta = torch.tensor(
112
+ self._patch_size,
113
+ device=patch_starts.device,
114
+ dtype=patch_starts.dtype,
115
+ ).view(3, 1, 1, 1)
116
+
117
+ # Calculate end coordinates: start + patch_size
118
+ # Shape becomes (3, grid_f, grid_h, grid_w)
119
+ patch_ends = patch_starts + patch_size_delta
120
+
121
+ # Stack start and end coordinates together along the last dimension
122
+ # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
123
+ latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)
124
+
125
+ # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
126
+ # Final Shape: (batch_size, 3, num_patches, 2)
127
+ latent_coords = einops.repeat(
128
+ latent_coords,
129
+ "c f h w bounds -> b c (f h w) bounds",
130
+ b=batch_size,
131
+ bounds=2,
132
+ )
133
+
134
+ return latent_coords
135
+
136
+
137
+ def get_pixel_coords(
138
+ latent_coords: torch.Tensor,
139
+ scale_factors: SpatioTemporalScaleFactors,
140
+ causal_fix: bool = False,
141
+ ) -> torch.Tensor:
142
+ """
143
+ Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
144
+ each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
145
+ Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
146
+ Args:
147
+ latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
148
+ scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
149
+ per axis.
150
+ causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
151
+ that treat frame zero differently still yield non-negative timestamps.
152
+ """
153
+ # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
154
+ broadcast_shape = [1] * latent_coords.ndim
155
+ broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
156
+ scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
157
+
158
+ # Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
159
+ pixel_coords = latent_coords * scale_tensor
160
+
161
+ if causal_fix:
162
+ # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
163
+ # Shift and clamp to keep the first-frame timestamps causal and non-negative.
164
+ pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
165
+
166
+ return pixel_coords
167
+
168
+
169
+ class AudioPatchifier(Patchifier):
170
+ def __init__(
171
+ self,
172
+ patch_size: int,
173
+ sample_rate: int = 16000,
174
+ hop_length: int = 160,
175
+ audio_latent_downsample_factor: int = 4,
176
+ is_causal: bool = True,
177
+ shift: int = 0,
178
+ ):
179
+ """
180
+ Patchifier tailored for spectrogram/audio latents.
181
+ Args:
182
+ patch_size: Number of mel bins combined into a single patch. This
183
+ controls the resolution along the frequency axis.
184
+ sample_rate: Original waveform sampling rate. Used to map latent
185
+ indices back to seconds so downstream consumers can align audio
186
+ and video cues.
187
+ hop_length: Window hop length used for the spectrogram. Determines
188
+ how many real-time samples separate two consecutive latent frames.
189
+ audio_latent_downsample_factor: Ratio between spectrogram frames and
190
+ latent frames; compensates for additional downsampling inside the
191
+ VAE encoder.
192
+ is_causal: When True, timing is shifted to account for causal
193
+ receptive fields so timestamps do not peek into the future.
194
+ shift: Integer offset applied to the latent indices. Enables
195
+ constructing overlapping windows from the same latent sequence.
196
+ """
197
+ self.hop_length = hop_length
198
+ self.sample_rate = sample_rate
199
+ self.audio_latent_downsample_factor = audio_latent_downsample_factor
200
+ self.is_causal = is_causal
201
+ self.shift = shift
202
+ self._patch_size = (1, patch_size, patch_size)
203
+
204
+ @property
205
+ def patch_size(self) -> Tuple[int, int, int]:
206
+ return self._patch_size
207
+
208
+ def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
209
+ return tgt_shape.frames
210
+
211
+ def _get_audio_latent_time_in_sec(
212
+ self,
213
+ start_latent: int,
214
+ end_latent: int,
215
+ dtype: torch.dtype,
216
+ device: Optional[torch.device] = None,
217
+ ) -> torch.Tensor:
218
+ """
219
+ Converts latent indices into real-time seconds while honoring causal
220
+ offsets and the configured hop length.
221
+ Args:
222
+ start_latent: Inclusive start index inside the latent sequence. This
223
+ sets the first timestamp returned.
224
+ end_latent: Exclusive end index. Determines how many timestamps get
225
+ generated.
226
+ dtype: Floating-point dtype used for the returned tensor, allowing
227
+ callers to control precision.
228
+ device: Target device for the timestamp tensor. When omitted the
229
+ computation occurs on CPU to avoid surprising GPU allocations.
230
+ """
231
+ if device is None:
232
+ device = torch.device("cpu")
233
+
234
+ audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
235
+
236
+ audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
237
+
238
+ if self.is_causal:
239
+ # Frame offset for causal alignment.
240
+ # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
241
+ causal_offset = 1
242
+ audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)
243
+
244
+ return audio_mel_frame * self.hop_length / self.sample_rate
245
+
246
+ def _compute_audio_timings(
247
+ self,
248
+ batch_size: int,
249
+ num_steps: int,
250
+ device: Optional[torch.device] = None,
251
+ ) -> torch.Tensor:
252
+ """
253
+ Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.
254
+ This helper method underpins `get_patch_grid_bounds` for the audio patchifier.
255
+ Args:
256
+ batch_size: Number of sequences to broadcast the timings over.
257
+ num_steps: Number of latent frames (time steps) to convert into timestamps.
258
+ device: Device on which the resulting tensor should reside.
259
+ """
260
+ resolved_device = device
261
+ if resolved_device is None:
262
+ resolved_device = torch.device("cpu")
263
+
264
+ start_timings = self._get_audio_latent_time_in_sec(
265
+ self.shift,
266
+ num_steps + self.shift,
267
+ torch.float32,
268
+ resolved_device,
269
+ )
270
+ start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
271
+
272
+ end_timings = self._get_audio_latent_time_in_sec(
273
+ self.shift + 1,
274
+ num_steps + self.shift + 1,
275
+ torch.float32,
276
+ resolved_device,
277
+ )
278
+ end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
279
+
280
+ return torch.stack([start_timings, end_timings], dim=-1)
281
+
282
+ def patchify(
283
+ self,
284
+ audio_latents: torch.Tensor,
285
+ ) -> torch.Tensor:
286
+ """
287
+ Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`
288
+ to derive timestamps for each latent frame based on the configured hop
289
+ length and downsampling.
290
+ Args:
291
+ audio_latents: Latent tensor to patchify.
292
+ Returns:
293
+ Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the
294
+ corresponding timing metadata when needed.
295
+ """
296
+ audio_latents = einops.rearrange(
297
+ audio_latents,
298
+ "b c t f -> b t (c f)",
299
+ )
300
+
301
+ return audio_latents
302
+
303
+ def unpatchify(
304
+ self,
305
+ audio_latents: torch.Tensor,
306
+ output_shape: AudioLatentShape,
307
+ ) -> torch.Tensor:
308
+ """
309
+ Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
310
+ Use `get_patch_grid_bounds` to recompute the timestamps that describe each
311
+ frame's position in real time.
312
+ Args:
313
+ audio_latents: Latent tensor to unpatchify.
314
+ output_shape: Shape of the unpatched output tensor.
315
+ Returns:
316
+ Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
317
+ metadata associated with the restored latents.
318
+ """
319
+ # audio_latents shape: (batch, time, freq * channels)
320
+ audio_latents = einops.rearrange(
321
+ audio_latents,
322
+ "b t (c f) -> b c t f",
323
+ c=output_shape.channels,
324
+ f=output_shape.mel_bins,
325
+ )
326
+
327
+ return audio_latents
328
+
329
+ def get_patch_grid_bounds(
330
+ self,
331
+ output_shape: AudioLatentShape | VideoLatentShape,
332
+ device: Optional[torch.device] = None,
333
+ ) -> torch.Tensor:
334
+ """
335
+ Return the temporal bounds `[inclusive start, exclusive end)` for every
336
+ patch emitted by `patchify`. For audio this corresponds to timestamps in
337
+ seconds aligned with the original spectrogram grid.
338
+ The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
339
+ - axis 1 (size 1) represents the temporal dimension
340
+ - axis 3 (size 2) stores the `[start, end)` timestamps per patch
341
+ Args:
342
+ output_shape: Audio grid specification describing the number of time steps.
343
+ device: Target device for the returned tensor.
344
+ """
345
+ if not isinstance(output_shape, AudioLatentShape):
346
+ raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")
347
+
348
+ return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
packages/ltx-core/src/ltx_core/components/protocols.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.types import AudioLatentShape, VideoLatentShape
6
+
7
+
8
+ class Patchifier(Protocol):
9
+ """
10
+ Protocol for patchifiers that convert latent tensors into patches and assemble them back.
11
+ """
12
+
13
+ def patchify(
14
+ self,
15
+ latents: torch.Tensor,
16
+ ) -> torch.Tensor:
17
+ ...
18
+ """
19
+ Convert latent tensors into flattened patch tokens.
20
+ Args:
21
+ latents: Latent tensor to patchify.
22
+ Returns:
23
+ Flattened patch tokens tensor.
24
+ """
25
+
26
+ def unpatchify(
27
+ self,
28
+ latents: torch.Tensor,
29
+ output_shape: AudioLatentShape | VideoLatentShape,
30
+ ) -> torch.Tensor:
31
+ """
32
+ Converts latent tensors between spatio-temporal formats and flattened sequence representations.
33
+ Args:
34
+ latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
35
+ output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
36
+ VideoLatentShape.
37
+ Returns:
38
+ Dense latent tensor restored from the flattened representation.
39
+ """
40
+
41
+ @property
42
+ def patch_size(self) -> Tuple[int, int, int]:
43
+ ...
44
+ """
45
+ Returns the patch size as a tuple of (temporal, height, width) dimensions
46
+ """
47
+
48
+ def get_patch_grid_bounds(
49
+ self,
50
+ output_shape: AudioLatentShape | VideoLatentShape,
51
+ device: torch.device | None = None,
52
+ ) -> torch.Tensor:
53
+ ...
54
+ """
55
+ Compute metadata describing where each latent patch resides within the
56
+ grid specified by `output_shape`.
57
+ Args:
58
+ output_shape: Target grid layout for the patches.
59
+ device: Target device for the returned tensor.
60
+ Returns:
61
+ Tensor containing patch coordinate metadata such as spatial or temporal intervals.
62
+ """
63
+
64
+
65
+ class SchedulerProtocol(Protocol):
66
+ """
67
+ Protocol for schedulers that provide a sigmas schedule tensor for a
68
+ given number of steps. Device is cpu.
69
+ """
70
+
71
+ def execute(self, steps: int, **kwargs) -> torch.FloatTensor: ...
72
+
73
+
74
+ class GuiderProtocol(Protocol):
75
+ """
76
+ Protocol for guiders that compute a delta tensor given conditioning inputs.
77
+ The returned delta should be added to the conditional output (cond), enabling
78
+ multiple guiders to be chained together by accumulating their deltas.
79
+ """
80
+
81
+ scale: float
82
+
83
+ def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: ...
84
+
85
+ def enabled(self) -> bool:
86
+ """
87
+ Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale
88
+ is 1.0.
89
+ """
90
+ ...
91
+
92
+
93
+ class DiffusionStepProtocol(Protocol):
94
+ """
95
+ Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor,
96
+ current denoised sample tensor, and sigmas tensor.
97
+ """
98
+
99
+ def step(
100
+ self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **kwargs
101
+ ) -> torch.Tensor: ...
packages/ltx-core/src/ltx_core/components/schedulers.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import lru_cache
3
+
4
+ import numpy
5
+ import scipy
6
+ import torch
7
+
8
+ from ltx_core.components.protocols import SchedulerProtocol
9
+
10
+ BASE_SHIFT_ANCHOR = 1024
11
+ MAX_SHIFT_ANCHOR = 4096
12
+
13
+
14
+ class LTX2Scheduler(SchedulerProtocol):
15
+ """
16
+ Default scheduler for LTX-2 diffusion sampling.
17
+ Generates a sigma schedule with token-count-dependent shifting and optional
18
+ stretching to a terminal value.
19
+ """
20
+
21
+ def execute(
22
+ self,
23
+ steps: int,
24
+ latent: torch.Tensor | None = None,
25
+ max_shift: float = 2.05,
26
+ base_shift: float = 0.95,
27
+ stretch: bool = True,
28
+ terminal: float = 0.1,
29
+ default_number_of_tokens: int = MAX_SHIFT_ANCHOR,
30
+ **_kwargs,
31
+ ) -> torch.FloatTensor:
32
+ tokens = math.prod(latent.shape[2:]) if latent is not None else default_number_of_tokens
33
+ sigmas = torch.linspace(1.0, 0.0, steps + 1)
34
+
35
+ x1 = BASE_SHIFT_ANCHOR
36
+ x2 = MAX_SHIFT_ANCHOR
37
+ mm = (max_shift - base_shift) / (x2 - x1)
38
+ b = base_shift - mm * x1
39
+ sigma_shift = (tokens) * mm + b
40
+
41
+ power = 1
42
+ sigmas = torch.where(
43
+ sigmas != 0,
44
+ math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
45
+ 0,
46
+ )
47
+
48
+ # Stretch sigmas so that its final value matches the given terminal value.
49
+ if stretch:
50
+ non_zero_mask = sigmas != 0
51
+ non_zero_sigmas = sigmas[non_zero_mask]
52
+ one_minus_z = 1.0 - non_zero_sigmas
53
+ scale_factor = one_minus_z[-1] / (1.0 - terminal)
54
+ stretched = 1.0 - (one_minus_z / scale_factor)
55
+ sigmas[non_zero_mask] = stretched
56
+
57
+ return sigmas.to(torch.float32)
58
+
59
+
60
+ class LinearQuadraticScheduler(SchedulerProtocol):
61
+ """
62
+ Scheduler with linear steps followed by quadratic steps.
63
+ Produces a sigma schedule that transitions linearly up to a threshold,
64
+ then follows a quadratic curve for the remaining steps.
65
+ """
66
+
67
+ def execute(
68
+ self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs
69
+ ) -> torch.FloatTensor:
70
+ if steps == 1:
71
+ return torch.FloatTensor([1.0, 0.0])
72
+
73
+ if linear_steps is None:
74
+ linear_steps = steps // 2
75
+ linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
76
+ threshold_noise_step_diff = linear_steps - threshold_noise * steps
77
+ quadratic_steps = steps - linear_steps
78
+ quadratic_sigma_schedule = []
79
+ if quadratic_steps > 0:
80
+ quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
81
+ linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
82
+ const = quadratic_coef * (linear_steps**2)
83
+ quadratic_sigma_schedule = [
84
+ quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, steps)
85
+ ]
86
+ sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
87
+ sigma_schedule = [1.0 - x for x in sigma_schedule]
88
+ return torch.FloatTensor(sigma_schedule)
89
+
90
+
91
+ class BetaScheduler(SchedulerProtocol):
92
+ """
93
+ Scheduler using a beta distribution to sample timesteps.
94
+ Based on: https://arxiv.org/abs/2407.12173
95
+ """
96
+
97
+ shift = 2.37
98
+ timesteps_length = 10000
99
+
100
+ def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6) -> torch.FloatTensor:
101
+ """
102
+ Execute the beta scheduler.
103
+ Args:
104
+ steps: The number of steps to execute the scheduler for.
105
+ alpha: The alpha parameter for the beta distribution.
106
+ beta: The beta parameter for the beta distribution.
107
+ Warnings:
108
+ The number of steps within `sigmas` theoretically might be less than `steps+1`,
109
+ because of the deduplication of the identical timesteps
110
+ Returns:
111
+ A tensor of sigmas.
112
+ """
113
+ model_sampling_sigmas = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length)
114
+ total_timesteps = len(model_sampling_sigmas) - 1
115
+ ts = 1 - numpy.linspace(0, 1, steps, endpoint=False)
116
+ ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps).tolist()
117
+ ts = list(dict.fromkeys(ts))
118
+
119
+ sigmas = [float(model_sampling_sigmas[int(t)]) for t in ts] + [0.0]
120
+ return torch.FloatTensor(sigmas)
121
+
122
+
123
+ @lru_cache(maxsize=5)
124
+ def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor:
125
+ timesteps = torch.arange(1, timesteps_length + 1, 1) / timesteps_length
126
+ return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in timesteps])
127
+
128
+
129
+ def flux_time_shift(mu: float, sigma: float, t: float) -> float:
130
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
packages/ltx-core/src/ltx_core/conditioning/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conditioning utilities: latent state, tools, and conditioning types."""
2
+
3
+ from ltx_core.conditioning.exceptions import ConditioningError
4
+ from ltx_core.conditioning.item import ConditioningItem
5
+ from ltx_core.conditioning.types import (
6
+ ConditioningItemAttentionStrengthWrapper,
7
+ VideoConditionByKeyframeIndex,
8
+ VideoConditionByLatentIndex,
9
+ VideoConditionByReferenceLatent,
10
+ )
11
+
12
+ __all__ = [
13
+ "ConditioningError",
14
+ "ConditioningItem",
15
+ "ConditioningItemAttentionStrengthWrapper",
16
+ "VideoConditionByKeyframeIndex",
17
+ "VideoConditionByLatentIndex",
18
+ "VideoConditionByReferenceLatent",
19
+ ]
packages/ltx-core/src/ltx_core/conditioning/exceptions.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ class ConditioningError(Exception):
2
+ """
3
+ Class for conditioning-related errors.
4
+ """
packages/ltx-core/src/ltx_core/conditioning/item.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol
2
+
3
+ from ltx_core.tools import LatentTools
4
+ from ltx_core.types import LatentState
5
+
6
+
7
+ class ConditioningItem(Protocol):
8
+ """Protocol for conditioning items that modify latent state during diffusion."""
9
+
10
+ def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
11
+ """
12
+ Apply the conditioning to the latent state.
13
+ Args:
14
+ latent_state: The latent state to apply the conditioning to. This is state always patchified.
15
+ Returns:
16
+ The latent state after the conditioning has been applied.
17
+ IMPORTANT: If the conditioning needs to add extra tokens to the latent, it should add them to the end of the
18
+ latent.
19
+ """
20
+ ...
packages/ltx-core/src/ltx_core/conditioning/mask_utils.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities for building 2D self-attention masks for conditioning items."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import torch
8
+
9
+ if TYPE_CHECKING:
10
+ from ltx_core.types import LatentState
11
+
12
+
13
+ def resolve_cross_mask(
14
+ attention_mask: float | int | torch.Tensor,
15
+ num_new_tokens: int,
16
+ batch_size: int,
17
+ device: torch.device,
18
+ dtype: torch.dtype,
19
+ ) -> torch.Tensor:
20
+ """Convert an attention_mask (scalar or tensor) to a (B, M) cross_mask tensor.
21
+ Args:
22
+ attention_mask: Scalar value applied uniformly, 1D tensor of shape (M,)
23
+ broadcast across batch, or 2D tensor of shape (B, M).
24
+ num_new_tokens: Number of new conditioning tokens M.
25
+ batch_size: Batch size B.
26
+ device: Device for the output tensor.
27
+ dtype: Data type for the output tensor.
28
+ Returns:
29
+ Cross-mask tensor of shape (B, M).
30
+ """
31
+ if isinstance(attention_mask, (int, float)):
32
+ return torch.full(
33
+ (batch_size, num_new_tokens),
34
+ fill_value=float(attention_mask),
35
+ device=device,
36
+ dtype=dtype,
37
+ )
38
+ mask = attention_mask.to(device=device, dtype=dtype)
39
+
40
+ # Handle scalar (0-D) tensor like a Python scalar.
41
+ if mask.dim() == 0:
42
+ return torch.full(
43
+ (batch_size, num_new_tokens),
44
+ fill_value=float(mask.item()),
45
+ device=device,
46
+ dtype=dtype,
47
+ )
48
+
49
+ if mask.dim() == 1:
50
+ if mask.shape[0] != num_new_tokens:
51
+ raise ValueError(
52
+ f"1-D attention_mask length must equal num_new_tokens ({num_new_tokens}), got shape {tuple(mask.shape)}"
53
+ )
54
+ mask = mask.unsqueeze(0).expand(batch_size, -1)
55
+ elif mask.dim() == 2:
56
+ b, m = mask.shape
57
+ if m != num_new_tokens:
58
+ raise ValueError(
59
+ f"2-D attention_mask second dimension must equal num_new_tokens ({num_new_tokens}), "
60
+ f"got shape {tuple(mask.shape)}"
61
+ )
62
+ if b not in (batch_size, 1):
63
+ raise ValueError(
64
+ f"2-D attention_mask batch dimension must equal batch_size ({batch_size}) or 1, "
65
+ f"got shape {tuple(mask.shape)}"
66
+ )
67
+ if b == 1 and batch_size > 1:
68
+ mask = mask.expand(batch_size, -1)
69
+ else:
70
+ raise ValueError(
71
+ f"attention_mask tensor must be 0-D, 1-D, or 2-D, got {mask.dim()}-D with shape {tuple(mask.shape)}"
72
+ )
73
+ return mask
74
+
75
+
76
+ def update_attention_mask(
77
+ latent_state: LatentState,
78
+ attention_mask: float | torch.Tensor | None,
79
+ num_noisy_tokens: int,
80
+ num_new_tokens: int,
81
+ batch_size: int,
82
+ device: torch.device,
83
+ dtype: torch.dtype,
84
+ ) -> torch.Tensor | None:
85
+ """Build or update the self-attention mask for newly appended conditioning tokens.
86
+ If *attention_mask* is ``None`` and no existing mask is present, returns
87
+ ``None``. If *attention_mask* is ``None`` but an existing mask is present,
88
+ the mask is expanded with full attention (1s) for the new tokens so that
89
+ its dimensions stay consistent with the growing latent sequence. Otherwise,
90
+ resolves *attention_mask* to a per-token cross-mask and expands the 2-D
91
+ attention mask via :func:`build_attention_mask`.
92
+ Args:
93
+ latent_state: Current latent state (provides the existing mask and total
94
+ existing-token count).
95
+ attention_mask: Per-token attention weight. Scalar, 1-D ``(M,)``, 2-D
96
+ ``(B, M)`` tensor, or ``None`` (no-op).
97
+ num_noisy_tokens: Number of original noisy tokens (from
98
+ ``latent_tools.target_shape.token_count()``).
99
+ num_new_tokens: Number of new conditioning tokens being appended.
100
+ batch_size: Batch size.
101
+ device: Device for the output tensor.
102
+ dtype: Data type for the output tensor.
103
+ Returns:
104
+ Updated attention mask of shape ``(B, N+M, N+M)``, or ``None`` if no
105
+ masking is needed.
106
+ """
107
+ if attention_mask is None:
108
+ if latent_state.attention_mask is None:
109
+ return None
110
+ # Existing mask present but no new mask requested: pad with 1s (full
111
+ # attention) so the mask dimensions stay consistent with the growing
112
+ # latent sequence.
113
+ cross_mask = torch.ones(batch_size, num_new_tokens, device=device, dtype=dtype)
114
+ return build_attention_mask(
115
+ existing_mask=latent_state.attention_mask,
116
+ num_noisy_tokens=num_noisy_tokens,
117
+ num_new_tokens=num_new_tokens,
118
+ num_existing_tokens=latent_state.latent.shape[1],
119
+ cross_mask=cross_mask,
120
+ device=device,
121
+ dtype=dtype,
122
+ )
123
+
124
+ cross_mask = resolve_cross_mask(attention_mask, num_new_tokens, batch_size, device, dtype)
125
+ return build_attention_mask(
126
+ existing_mask=latent_state.attention_mask,
127
+ num_noisy_tokens=num_noisy_tokens,
128
+ num_new_tokens=num_new_tokens,
129
+ num_existing_tokens=latent_state.latent.shape[1],
130
+ cross_mask=cross_mask,
131
+ device=device,
132
+ dtype=dtype,
133
+ )
134
+
135
+
136
+ def build_attention_mask(
137
+ existing_mask: torch.Tensor | None,
138
+ num_noisy_tokens: int,
139
+ num_new_tokens: int,
140
+ num_existing_tokens: int,
141
+ cross_mask: torch.Tensor,
142
+ device: torch.device,
143
+ dtype: torch.dtype,
144
+ ) -> torch.Tensor:
145
+ """
146
+ Expand the attention mask to include newly appended conditioning tokens.
147
+ Each conditioning item appends M new reference tokens to the sequence. This function
148
+ builds a (B, N+M, N+M) attention mask with the following block structure:
149
+ noisy prev_ref new_ref
150
+ (N_noisy) (N-N_noisy) (M)
151
+ ┌───────────┬───────────┬───────────┐
152
+ noisy │ │ │ │
153
+ (N_noisy) │ existing │ existing │ cross │
154
+ │ │ │ │
155
+ ├───────────┼───────────┼───────────┤
156
+ prev_ref │ │ │ │
157
+ (N-N_noisy)│ existing │ existing │ 0 │
158
+ │ │ │ │
159
+ ├───────────┼───────────┼───────────┤
160
+ new_ref │ │ │ │
161
+ (M) │ cross │ 0 │ 1 │
162
+ │ │ │ │
163
+ └───────────┴───────────┴───────────┘
164
+ Where:
165
+ - **existing**: preserved from the previous mask (or 1.0 if first conditioning)
166
+ - **cross**: values from *cross_mask* (shape B, M), in [0, 1]
167
+ - **0**: no attention between different reference groups
168
+ Args:
169
+ existing_mask: Current attention mask of shape (B, N, N), or None if no mask exists yet.
170
+ When None, the top-left NxN block is filled with 1s (full attention between all
171
+ existing tokens including any prior reference tokens that had no mask).
172
+ num_noisy_tokens: Number of original noisy tokens (always at positions [0:num_noisy_tokens]).
173
+ num_new_tokens: Number of new conditioning tokens M being appended.
174
+ num_existing_tokens: Total number of current tokens N (noisy + any prior conditioning tokens).
175
+ cross_mask: Per-token attention weight of shape (B, M) controlling attention between
176
+ new reference tokens and noisy tokens. Values in [0, 1].
177
+ device: Device for the output tensor.
178
+ dtype: Data type for the output tensor.
179
+ Returns:
180
+ Attention mask of shape (B, N+M, N+M) with values in [0, 1].
181
+ """
182
+ batch_size = cross_mask.shape[0]
183
+ total = num_existing_tokens + num_new_tokens
184
+
185
+ # Start with zeros
186
+ mask = torch.zeros((batch_size, total, total), device=device, dtype=dtype)
187
+
188
+ # Top-left: preserve existing mask or fill with 1s for noisy tokens
189
+ if existing_mask is not None:
190
+ mask[:, :num_existing_tokens, :num_existing_tokens] = existing_mask
191
+ else:
192
+ mask[:, :num_existing_tokens, :num_existing_tokens] = 1.0
193
+
194
+ # Bottom-right: new reference tokens fully attend to themselves
195
+ mask[:, num_existing_tokens:, num_existing_tokens:] = 1.0
196
+
197
+ # Cross-attention between noisy tokens and new reference tokens
198
+ # cross_mask shape: (B, M) -> broadcast to (B, N_noisy, M) and (B, M, N_noisy)
199
+
200
+ # Noisy tokens attending to new reference tokens: [0:N_noisy, N:N+M]
201
+ # Each column j in this block gets cross_mask[:, j]
202
+ mask[:, :num_noisy_tokens, num_existing_tokens:] = cross_mask.unsqueeze(1)
203
+
204
+ # New reference tokens attending to noisy tokens: [N:N+M, 0:N_noisy]
205
+ # Each row i in this block gets cross_mask[:, i]
206
+ mask[:, num_existing_tokens:, :num_noisy_tokens] = cross_mask.unsqueeze(2)
207
+
208
+ # [N_noisy:N, N:N+M] and [N:N+M, N_noisy:N] remain 0 (no cross-ref attention)
209
+
210
+ return mask
packages/ltx-core/src/ltx_core/guidance/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Guidance and perturbation utilities for attention manipulation."""
2
+
3
+ from ltx_core.guidance.perturbations import (
4
+ BatchedPerturbationConfig,
5
+ Perturbation,
6
+ PerturbationConfig,
7
+ PerturbationType,
8
+ )
9
+
10
+ __all__ = [
11
+ "BatchedPerturbationConfig",
12
+ "Perturbation",
13
+ "PerturbationConfig",
14
+ "PerturbationType",
15
+ ]
packages/ltx-core/src/ltx_core/guidance/perturbations.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+
4
+ import torch
5
+ from torch._prims_common import DeviceLikeType
6
+
7
+
8
+ class PerturbationType(Enum):
9
+ """Types of attention perturbations for STG (Spatio-Temporal Guidance)."""
10
+
11
+ SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
12
+ SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
13
+ SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
14
+ SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"
15
+
16
+
17
+ @dataclass(frozen=True)
18
+ class Perturbation:
19
+ """A single perturbation specifying which attention type to skip and in which blocks."""
20
+
21
+ type: PerturbationType
22
+ blocks: list[int] | None # None means all blocks
23
+
24
+ def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
25
+ if self.type != perturbation_type:
26
+ return False
27
+
28
+ if self.blocks is None:
29
+ return True
30
+
31
+ return block in self.blocks
32
+
33
+
34
+ @dataclass(frozen=True)
35
+ class PerturbationConfig:
36
+ """Configuration holding a list of perturbations for a single sample."""
37
+
38
+ perturbations: list[Perturbation] | None
39
+
40
+ def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
41
+ if self.perturbations is None:
42
+ return False
43
+
44
+ return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
45
+
46
+ @staticmethod
47
+ def empty() -> "PerturbationConfig":
48
+ return PerturbationConfig([])
49
+
50
+
51
+ @dataclass(frozen=True)
52
+ class BatchedPerturbationConfig:
53
+ """Perturbation configurations for a batch, with utilities for generating attention masks."""
54
+
55
+ perturbations: list[PerturbationConfig]
56
+
57
+ def mask(
58
+ self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
59
+ ) -> torch.Tensor:
60
+ mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype)
61
+ for batch_idx, perturbation in enumerate(self.perturbations):
62
+ if perturbation.is_perturbed(perturbation_type, block):
63
+ mask[batch_idx] = 0
64
+
65
+ return mask
66
+
67
+ def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
68
+ mask = self.mask(perturbation_type, block, values.device, values.dtype)
69
+ return mask.view(mask.numel(), *([1] * len(values.shape[1:])))
70
+
71
+ def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
72
+ return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
73
+
74
+ def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
75
+ return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations)
76
+
77
+ @staticmethod
78
+ def empty(batch_size: int) -> "BatchedPerturbationConfig":
79
+ return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
packages/ltx-core/src/ltx_core/loader/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Loader utilities for model weights, LoRAs, and safetensor operations."""
2
+
3
+ from ltx_core.loader.fuse_loras import apply_loras
4
+ from ltx_core.loader.module_ops import ModuleOps
5
+ from ltx_core.loader.primitives import (
6
+ LoRAAdaptableProtocol,
7
+ LoraPathStrengthAndSDOps,
8
+ LoraStateDictWithStrength,
9
+ ModelBuilderProtocol,
10
+ StateDict,
11
+ StateDictLoader,
12
+ )
13
+ from ltx_core.loader.registry import DummyRegistry, Registry, StateDictRegistry
14
+ from ltx_core.loader.sd_ops import (
15
+ LTXV_LORA_COMFY_RENAMING_MAP,
16
+ ContentMatching,
17
+ ContentReplacement,
18
+ KeyValueOperation,
19
+ KeyValueOperationResult,
20
+ SDKeyValueOperation,
21
+ SDOps,
22
+ )
23
+ from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader, SafetensorsStateDictLoader
24
+ from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
25
+
26
+ __all__ = [
27
+ "LTXV_LORA_COMFY_RENAMING_MAP",
28
+ "ContentMatching",
29
+ "ContentReplacement",
30
+ "DummyRegistry",
31
+ "KeyValueOperation",
32
+ "KeyValueOperationResult",
33
+ "LoRAAdaptableProtocol",
34
+ "LoraPathStrengthAndSDOps",
35
+ "LoraStateDictWithStrength",
36
+ "ModelBuilderProtocol",
37
+ "ModuleOps",
38
+ "Registry",
39
+ "SDKeyValueOperation",
40
+ "SDOps",
41
+ "SafetensorsModelStateDictLoader",
42
+ "SafetensorsStateDictLoader",
43
+ "SingleGPUModelBuilder",
44
+ "StateDict",
45
+ "StateDictLoader",
46
+ "StateDictRegistry",
47
+ "apply_loras",
48
+ ]
packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.33 kB). View file
 
packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-312.pyc ADDED
Binary file (7.41 kB). View file
 
packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-312.pyc ADDED
Binary file (955 Bytes). View file
 
packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-312.pyc ADDED
Binary file (5.37 kB). View file
 
packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-312.pyc ADDED
Binary file (5.68 kB). View file
 
packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-312.pyc ADDED
Binary file (6.81 kB). View file
 
packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-312.pyc ADDED
Binary file (4.36 kB). View file
 
packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-312.pyc ADDED
Binary file (8.84 kB). View file
 
packages/ltx-core/src/ltx_core/loader/fuse_loras.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
4
+ from ltx_core.quantization.fp8_cast import calculate_weight_float8
5
+ from ltx_core.quantization.fp8_scaled_mm import quantize_weight_to_fp8_per_tensor
6
+
7
+
8
+ def apply_loras(
9
+ model_sd: StateDict,
10
+ lora_sd_and_strengths: list[LoraStateDictWithStrength],
11
+ dtype: torch.dtype | None = None,
12
+ destination_sd: StateDict | None = None,
13
+ ) -> StateDict:
14
+ sd = {}
15
+ if destination_sd is not None:
16
+ sd = destination_sd.sd
17
+ size = 0
18
+ device = torch.device("meta")
19
+ inner_dtypes = set()
20
+ for key, weight in model_sd.sd.items():
21
+ if weight is None:
22
+ continue
23
+ # Skip scale keys - they are handled together with their weight keys
24
+ if key.endswith(".weight_scale"):
25
+ continue
26
+ device = weight.device
27
+ target_dtype = dtype if dtype is not None else weight.dtype
28
+ deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16
29
+
30
+ scale_key = key.replace(".weight", ".weight_scale") if key.endswith(".weight") else None
31
+ is_scaled_fp8 = scale_key is not None and scale_key in model_sd.sd
32
+
33
+ deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
34
+ fused = _fuse_deltas(deltas, weight, key, sd, target_dtype, device, is_scaled_fp8, scale_key, model_sd)
35
+
36
+ sd.update(fused)
37
+ for tensor in fused.values():
38
+ inner_dtypes.add(tensor.dtype)
39
+ size += tensor.nbytes
40
+
41
+ if destination_sd is not None:
42
+ return destination_sd
43
+ return StateDict(sd, device, size, inner_dtypes)
44
+
45
+
46
+ def _prepare_deltas(
47
+ lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device
48
+ ) -> torch.Tensor | None:
49
+ deltas = []
50
+ prefix = key[: -len(".weight")]
51
+ key_a = f"{prefix}.lora_A.weight"
52
+ key_b = f"{prefix}.lora_B.weight"
53
+ for lsd, coef in lora_sd_and_strengths:
54
+ if key_a not in lsd.sd or key_b not in lsd.sd:
55
+ continue
56
+ a = lsd.sd[key_a].to(device=device)
57
+ b = lsd.sd[key_b].to(device=device)
58
+ product = torch.matmul(b * coef, a)
59
+ del a, b
60
+ deltas.append(product.to(dtype=dtype))
61
+ if len(deltas) == 0:
62
+ return None
63
+ elif len(deltas) == 1:
64
+ return deltas[0]
65
+ return torch.sum(torch.stack(deltas, dim=0), dim=0)
66
+
67
+
68
+ def _fuse_deltas(
69
+ deltas: torch.Tensor | None,
70
+ weight: torch.Tensor,
71
+ key: str,
72
+ sd: dict[str, torch.Tensor],
73
+ target_dtype: torch.dtype,
74
+ device: torch.device,
75
+ is_scaled_fp8: bool,
76
+ scale_key: str | None,
77
+ model_sd: StateDict,
78
+ ) -> dict[str, torch.Tensor]:
79
+ if deltas is None:
80
+ if key in sd:
81
+ return {}
82
+ fused = _copy_weight_without_lora(weight, key, target_dtype, device, is_scaled_fp8, scale_key, model_sd)
83
+ elif weight.dtype == torch.float8_e4m3fn:
84
+ if is_scaled_fp8:
85
+ fused = _fuse_delta_with_scaled_fp8(deltas, weight, key, scale_key, model_sd)
86
+ else:
87
+ fused = _fuse_delta_with_cast_fp8(deltas, weight, key, target_dtype, device)
88
+ elif weight.dtype == torch.bfloat16:
89
+ fused = _fuse_delta_with_bfloat16(deltas, weight, key, target_dtype)
90
+ else:
91
+ raise ValueError(f"Unsupported dtype: {weight.dtype}")
92
+
93
+ return fused
94
+
95
+
96
+ def _copy_weight_without_lora(
97
+ weight: torch.Tensor,
98
+ key: str,
99
+ target_dtype: torch.dtype,
100
+ device: torch.device,
101
+ is_scaled_fp8: bool,
102
+ scale_key: str | None,
103
+ model_sd: StateDict,
104
+ ) -> dict[str, torch.Tensor]:
105
+ """Copy original weight (and scale if applicable) when no LoRA affects this key."""
106
+ result = {key: weight.clone().to(dtype=target_dtype, device=device)}
107
+ if is_scaled_fp8:
108
+ result[scale_key] = model_sd.sd[scale_key].clone()
109
+ return result
110
+
111
+
112
+ def _fuse_delta_with_scaled_fp8(
113
+ deltas: torch.Tensor,
114
+ weight: torch.Tensor,
115
+ key: str,
116
+ scale_key: str,
117
+ model_sd: StateDict,
118
+ ) -> dict[str, torch.Tensor]:
119
+ """Dequantize scaled FP8 weight, add LoRA delta, and re-quantize."""
120
+ weight_scale = model_sd.sd[scale_key]
121
+
122
+ original_weight = weight.t().to(torch.float32) * weight_scale
123
+
124
+ new_weight = original_weight + deltas.to(torch.float32)
125
+
126
+ new_fp8_weight, new_weight_scale = quantize_weight_to_fp8_per_tensor(new_weight)
127
+ return {key: new_fp8_weight, scale_key: new_weight_scale}
128
+
129
+
130
+ def _fuse_delta_with_cast_fp8(
131
+ deltas: torch.Tensor,
132
+ weight: torch.Tensor,
133
+ key: str,
134
+ target_dtype: torch.dtype,
135
+ device: torch.device,
136
+ ) -> dict[str, torch.Tensor]:
137
+ """Fuse LoRA delta with cast-only FP8 weight (no scale factor)."""
138
+ if str(device).startswith("cuda"):
139
+ deltas = calculate_weight_float8(deltas, weight)
140
+ else:
141
+ deltas.add_(weight.to(dtype=deltas.dtype, device=device))
142
+ return {key: deltas.to(dtype=target_dtype)}
143
+
144
+
145
+ def _fuse_delta_with_bfloat16(
146
+ deltas: torch.Tensor,
147
+ weight: torch.Tensor,
148
+ key: str,
149
+ target_dtype: torch.dtype,
150
+ ) -> dict[str, torch.Tensor]:
151
+ """Fuse LoRA delta with bfloat16 weight."""
152
+ deltas.add_(weight)
153
+ return {key: deltas.to(dtype=target_dtype)}
packages/ltx-core/src/ltx_core/loader/kernels.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa: ANN001, ANN201, ERA001, N803, N806
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def fused_add_round_kernel(
8
+ x_ptr,
9
+ output_ptr, # contents will be added to the output
10
+ seed,
11
+ n_elements,
12
+ EXPONENT_BIAS,
13
+ MANTISSA_BITS,
14
+ BLOCK_SIZE: tl.constexpr,
15
+ ):
16
+ """
17
+ A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding
18
+ and add them to bfloat16 output weights. Might be used to upcast original model weights
19
+ and to further add them to precalculated deltas coming from LoRAs.
20
+ """
21
+ # Get program ID and compute offsets
22
+ pid = tl.program_id(axis=0)
23
+ block_start = pid * BLOCK_SIZE
24
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
25
+ mask = offsets < n_elements
26
+
27
+ # Load data
28
+ x = tl.load(x_ptr + offsets, mask=mask)
29
+ rand_vals = tl.rand(seed, offsets) - 0.5
30
+
31
+ x = tl.cast(x, tl.float16)
32
+ delta = tl.load(output_ptr + offsets, mask=mask)
33
+ delta = tl.cast(delta, tl.float16)
34
+ x = x + delta
35
+
36
+ x_bits = tl.cast(x, tl.int16, bitcast=True)
37
+
38
+ # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
39
+ # normal numbers and -14 for subnormals.
40
+ fp16_exponent_bits = (x_bits & 0x7C00) >> 10
41
+ fp16_normals = fp16_exponent_bits > 0
42
+ fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)
43
+
44
+ # Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
45
+ exponent = fp16_exponent + EXPONENT_BIAS
46
+ MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
47
+ exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
48
+ exponent = tl.where(exponent < 0, 0, exponent)
49
+
50
+ # Normal ULP exponent, expressed as an fp16 exponent field:
51
+ # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
52
+ # Simplifies to: fp16_exponent - MANTISSA_BITS + 15
53
+ # See https://en.wikipedia.org/wiki/Unit_in_the_last_place
54
+ eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))
55
+
56
+ # Calculate epsilon in the target dtype
57
+ eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)
58
+
59
+ # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
60
+ # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
61
+ # 16 - EXPONENT_BIAS - MANTISSA_BITS
62
+ eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
63
+ eps = tl.where(exponent > 0, eps_normal, eps_subnormal)
64
+
65
+ # Apply zero mask to epsilon
66
+ eps = tl.where(x == 0, 0.0, eps)
67
+
68
+ # Apply stochastic rounding
69
+ output = tl.cast(x + rand_vals * eps, tl.bfloat16)
70
+
71
+ # Store the result
72
+ tl.store(output_ptr + offsets, output, mask=mask)
packages/ltx-core/src/ltx_core/loader/module_ops.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, NamedTuple
2
+
3
+ import torch
4
+
5
+
6
+ class ModuleOps(NamedTuple):
7
+ """
8
+ Defines a named operation for matching and mutating PyTorch modules.
9
+ Used to selectively transform modules in a model (e.g., replacing layers with quantized versions).
10
+ """
11
+
12
+ name: str
13
+ matcher: Callable[[torch.nn.Module], bool]
14
+ mutator: Callable[[torch.nn.Module], torch.nn.Module]
packages/ltx-core/src/ltx_core/loader/primitives.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import NamedTuple, Protocol
3
+
4
+ import torch
5
+
6
+ from ltx_core.loader.module_ops import ModuleOps
7
+ from ltx_core.loader.sd_ops import SDOps
8
+ from ltx_core.model.model_protocol import ModelType
9
+
10
+
11
+ @dataclass(frozen=True)
12
+ class StateDict:
13
+ """
14
+ Immutable container for a PyTorch state dictionary.
15
+ Contains:
16
+ - sd: Dictionary of tensors (weights, buffers, etc.)
17
+ - device: Device where tensors are stored
18
+ - size: Total memory footprint in bytes
19
+ - dtype: Set of tensor dtypes present
20
+ """
21
+
22
+ sd: dict
23
+ device: torch.device
24
+ size: int
25
+ dtype: set[torch.dtype]
26
+
27
+ def footprint(self) -> tuple[int, torch.device]:
28
+ return self.size, self.device
29
+
30
+
31
+ class StateDictLoader(Protocol):
32
+ """
33
+ Protocol for loading state dictionaries from various sources.
34
+ Implementations must provide:
35
+ - metadata: Extract model metadata from a single path
36
+ - load: Load state dict from path(s) and apply SDOps transformations
37
+ """
38
+
39
+ def metadata(self, path: str) -> dict:
40
+ """
41
+ Load metadata from path
42
+ """
43
+
44
+ def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
45
+ """
46
+ Load state dict from path or paths (for sharded model storage) and apply sd_ops
47
+ """
48
+
49
+
50
+ class ModelBuilderProtocol(Protocol[ModelType]):
51
+ """
52
+ Protocol for building PyTorch models from configuration dictionaries.
53
+ Implementations must provide:
54
+ - meta_model: Create a model from configuration dictionary and apply module operations
55
+ - build: Create and initialize a model from state dictionary and apply dtype transformations
56
+ """
57
+
58
+ def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
59
+ """
60
+ Create a model on the meta device from a configuration dictionary.
61
+ This decouples model creation from weight loading, allowing the model
62
+ architecture to be instantiated without allocating memory for parameters.
63
+ Args:
64
+ config: Model configuration dictionary.
65
+ module_ops: Optional list of module operations to apply (e.g., quantization).
66
+ Returns:
67
+ Model instance on meta device (no actual memory allocated for parameters).
68
+ """
69
+ ...
70
+
71
+ def build(self, dtype: torch.dtype | None = None) -> ModelType:
72
+ """
73
+ Build the model
74
+ Args:
75
+ dtype: Target dtype for the model, if None, uses the dtype of the model_path model
76
+ Returns:
77
+ Model instance
78
+ """
79
+ ...
80
+
81
+
82
+ class LoRAAdaptableProtocol(Protocol):
83
+ """
84
+ Protocol for models that can be adapted with LoRAs.
85
+ Implementations must provide:
86
+ - lora: Add a LoRA to the model
87
+ """
88
+
89
+ def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
90
+ pass
91
+
92
+
93
+ class LoraPathStrengthAndSDOps(NamedTuple):
94
+ """
95
+ Tuple containing a LoRA path, strength, and SDOps for applying to the LoRA state dict.
96
+ """
97
+
98
+ path: str
99
+ strength: float
100
+ sd_ops: SDOps
101
+
102
+
103
+ class LoraStateDictWithStrength(NamedTuple):
104
+ """
105
+ Tuple containing a LoRA state dict and strength for applying to the model.
106
+ """
107
+
108
+ state_dict: StateDict
109
+ strength: float
packages/ltx-core/src/ltx_core/loader/registry.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import threading
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import Protocol
6
+
7
+ from ltx_core.loader.primitives import StateDict
8
+ from ltx_core.loader.sd_ops import SDOps
9
+
10
+
11
+ class Registry(Protocol):
12
+ """
13
+ Protocol for managing state dictionaries in a registry.
14
+ It is used to store state dictionaries and reuse them later without loading them again.
15
+ Implementations must provide:
16
+ - add: Add a state dictionary to the registry
17
+ - pop: Remove a state dictionary from the registry
18
+ - get: Retrieve a state dictionary from the registry
19
+ - clear: Clear all state dictionaries from the registry
20
+ """
21
+
22
+ def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ...
23
+
24
+ def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...
25
+
26
+ def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...
27
+
28
+ def clear(self) -> None: ...
29
+
30
+
31
+ class DummyRegistry(Registry):
32
+ """
33
+ Dummy registry that does not store state dictionaries.
34
+ """
35
+
36
+ def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
37
+ pass
38
+
39
+ def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
40
+ pass
41
+
42
+ def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
43
+ pass
44
+
45
+ def clear(self) -> None:
46
+ pass
47
+
48
+
49
+ @dataclass
50
+ class StateDictRegistry(Registry):
51
+ """
52
+ Registry that stores state dictionaries in a dictionary.
53
+ """
54
+
55
+ _state_dicts: dict[str, StateDict] = field(default_factory=dict)
56
+ _lock: threading.Lock = field(default_factory=threading.Lock)
57
+
58
+ def _generate_id(self, paths: list[str], sd_ops: SDOps) -> str:
59
+ m = hashlib.sha256()
60
+ parts = [str(Path(p).resolve()) for p in paths]
61
+ if sd_ops is not None:
62
+ parts.append(sd_ops.name)
63
+ m.update("\0".join(parts).encode("utf-8"))
64
+ return m.hexdigest()
65
+
66
+ def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
67
+ sd_id = self._generate_id(paths, sd_ops)
68
+ with self._lock:
69
+ if sd_id in self._state_dicts:
70
+ raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
71
+ self._state_dicts[sd_id] = state_dict
72
+ return sd_id
73
+
74
+ def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
75
+ with self._lock:
76
+ return self._state_dicts.pop(self._generate_id(paths, sd_ops), None)
77
+
78
+ def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
79
+ with self._lock:
80
+ return self._state_dicts.get(self._generate_id(paths, sd_ops), None)
81
+
82
+ def clear(self) -> None:
83
+ with self._lock:
84
+ self._state_dicts.clear()
packages/ltx-core/src/ltx_core/loader/sd_ops.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, replace
2
+ from typing import NamedTuple, Protocol
3
+
4
+ import torch
5
+
6
+
7
+ @dataclass(frozen=True, slots=True)
8
+ class ContentReplacement:
9
+ """
10
+ Represents a content replacement operation.
11
+ Used to replace a specific content with a replacement in a state dict key.
12
+ """
13
+
14
+ content: str
15
+ replacement: str
16
+
17
+
18
+ @dataclass(frozen=True, slots=True)
19
+ class ContentMatching:
20
+ """
21
+ Represents a content matching operation.
22
+ Used to match a specific prefix and suffix in a state dict key.
23
+ """
24
+
25
+ prefix: str = ""
26
+ suffix: str = ""
27
+
28
+
29
+ class KeyValueOperationResult(NamedTuple):
30
+ """
31
+ Represents the result of a key-value operation.
32
+ Contains the new key and value after the operation has been applied.
33
+ """
34
+
35
+ new_key: str
36
+ new_value: torch.Tensor
37
+
38
+
39
+ class KeyValueOperation(Protocol):
40
+ """
41
+ Protocol for key-value operations.
42
+ Used to apply operations to a specific key and value in a state dict.
43
+ """
44
+
45
+ def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
46
+
47
+
48
+ @dataclass(frozen=True, slots=True)
49
+ class SDKeyValueOperation:
50
+ """
51
+ Represents a key-value operation.
52
+ Used to apply operations to a specific key and value in a state dict.
53
+ """
54
+
55
+ key_matcher: ContentMatching
56
+ kv_operation: KeyValueOperation
57
+
58
+
59
+ @dataclass(frozen=True, slots=True)
60
+ class SDOps:
61
+ """Immutable class representing state dict key operations."""
62
+
63
+ name: str
64
+ mapping: tuple[
65
+ ContentReplacement | ContentMatching | SDKeyValueOperation, ...
66
+ ] = () # Immutable tuple of (key, value) pairs
67
+
68
+ def with_replacement(self, content: str, replacement: str) -> "SDOps":
69
+ """Create a new SDOps instance with the specified replacement added to the mapping."""
70
+
71
+ new_mapping = (*self.mapping, ContentReplacement(content, replacement))
72
+ return replace(self, mapping=new_mapping)
73
+
74
+ def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps":
75
+ """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping."""
76
+
77
+ new_mapping = (*self.mapping, ContentMatching(prefix, suffix))
78
+ return replace(self, mapping=new_mapping)
79
+
80
+ def with_kv_operation(
81
+ self,
82
+ operation: KeyValueOperation,
83
+ key_prefix: str = "",
84
+ key_suffix: str = "",
85
+ ) -> "SDOps":
86
+ """Create a new SDOps instance with the specified value operation added to the mapping."""
87
+ key_matcher = ContentMatching(key_prefix, key_suffix)
88
+ sd_kv_operation = SDKeyValueOperation(key_matcher, operation)
89
+ new_mapping = (*self.mapping, sd_kv_operation)
90
+ return replace(self, mapping=new_mapping)
91
+
92
+ def apply_to_key(self, key: str) -> str | None:
93
+ """Apply the mapping to the given name."""
94
+ matchers = [content for content in self.mapping if isinstance(content, ContentMatching)]
95
+ valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers)
96
+ if not valid:
97
+ return None
98
+
99
+ for replacement in self.mapping:
100
+ if not isinstance(replacement, ContentReplacement):
101
+ continue
102
+ if replacement.content in key:
103
+ key = key.replace(replacement.content, replacement.replacement)
104
+ return key
105
+
106
+ def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
107
+ """Apply the value operation to the given name and associated value."""
108
+ for operation in self.mapping:
109
+ if not isinstance(operation, SDKeyValueOperation):
110
+ continue
111
+ if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix):
112
+ return operation.kv_operation(key, value)
113
+ return [KeyValueOperationResult(key, value)]
114
+
115
+
116
+ # Predefined SDOps instances
117
+ LTXV_LORA_COMFY_RENAMING_MAP = (
118
+ SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
119
+ )
120
+
121
+ LTXV_LORA_COMFY_TARGET_MAP = (
122
+ SDOps("LTXV_LORA_COMFY_TARGET_MAP")
123
+ .with_matching()
124
+ .with_replacement("diffusion_model.", "")
125
+ .with_replacement(".lora_A.weight", ".weight")
126
+ .with_replacement(".lora_B.weight", ".weight")
127
+ )
packages/ltx-core/src/ltx_core/loader/sft_loader.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import safetensors
4
+ import torch
5
+
6
+ from ltx_core.loader.primitives import StateDict, StateDictLoader
7
+ from ltx_core.loader.sd_ops import SDOps
8
+
9
+
10
+ class SafetensorsStateDictLoader(StateDictLoader):
11
+ """
12
+ Loads weights from safetensors files without metadata support.
13
+ Use this for loading raw weight files. For model files that include
14
+ configuration metadata, use SafetensorsModelStateDictLoader instead.
15
+ """
16
+
17
+ def metadata(self, path: str) -> dict:
18
+ raise NotImplementedError("Not implemented")
19
+
20
+ def load(self, path: str | list[str], sd_ops: SDOps, device: torch.device | None = None) -> StateDict:
21
+ """
22
+ Load state dict from path or paths (for sharded model storage) and apply sd_ops
23
+ """
24
+ sd = {}
25
+ size = 0
26
+ dtype = set()
27
+ device = device or torch.device("cpu")
28
+ model_paths = path if isinstance(path, list) else [path]
29
+ for shard_path in model_paths:
30
+ with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
31
+ safetensor_keys = f.keys()
32
+ for name in safetensor_keys:
33
+ expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
34
+ if expected_name is None:
35
+ continue
36
+ value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
37
+ key_value_pairs = ((expected_name, value),)
38
+ if sd_ops is not None:
39
+ key_value_pairs = sd_ops.apply_to_key_value(expected_name, value)
40
+ for key, value in key_value_pairs:
41
+ size += value.nbytes
42
+ dtype.add(value.dtype)
43
+ sd[key] = value
44
+
45
+ return StateDict(sd=sd, device=device, size=size, dtype=dtype)
46
+
47
+
48
+ class SafetensorsModelStateDictLoader(StateDictLoader):
49
+ """
50
+ Loads weights and configuration metadata from safetensors model files.
51
+ Unlike SafetensorsStateDictLoader, this loader can read model configuration
52
+ from the safetensors file metadata via the metadata() method.
53
+ """
54
+
55
+ def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
56
+ self.weight_loader = weight_loader if weight_loader is not None else SafetensorsStateDictLoader()
57
+
58
+ def metadata(self, path: str) -> dict:
59
+ with safetensors.safe_open(path, framework="pt") as f:
60
+ meta = f.metadata()
61
+ if meta is None or "config" not in meta:
62
+ return {}
63
+ return json.loads(meta["config"])
64
+
65
+ def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
66
+ return self.weight_loader.load(path, sd_ops, device)
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass, field, replace
3
+ from typing import Generic
4
+
5
+ import torch
6
+
7
+ from ltx_core.loader.fuse_loras import apply_loras
8
+ from ltx_core.loader.module_ops import ModuleOps
9
+ from ltx_core.loader.primitives import (
10
+ LoRAAdaptableProtocol,
11
+ LoraPathStrengthAndSDOps,
12
+ LoraStateDictWithStrength,
13
+ ModelBuilderProtocol,
14
+ StateDict,
15
+ StateDictLoader,
16
+ )
17
+ from ltx_core.loader.registry import DummyRegistry, Registry
18
+ from ltx_core.loader.sd_ops import SDOps
19
+ from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
20
+ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
21
+
22
+ logger: logging.Logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
27
+ """
28
+ Builder for PyTorch models residing on a single GPU.
29
+ Attributes:
30
+ model_class_configurator: Class responsible for constructing the model from a config dict.
31
+ model_path: Path (or tuple of shard paths) to the model's `.safetensors` checkpoint(s).
32
+ model_sd_ops: Optional state-dict operations applied when loading the model weights.
33
+ module_ops: Sequence of module-level mutations applied to the meta model before weight loading.
34
+ loras: Sequence of LoRA adapters (path, strength, optional sd_ops) to fuse into the model.
35
+ model_loader: Strategy for loading state dicts from disk. Defaults to
36
+ :class:`SafetensorsModelStateDictLoader`.
37
+ registry: Cache for already-loaded state dicts. Defaults to :class:`DummyRegistry` (no caching).
38
+ lora_load_device: Device used when loading LoRA weight tensors from disk. Defaults to
39
+ ``torch.device("cpu")``, which keeps LoRA weights in CPU memory and transfers them to
40
+ the target GPU sequentially during fusion, reducing peak GPU memory usage compared to
41
+ loading all LoRA weights directly onto the GPU at once.
42
+ """
43
+
44
+ model_class_configurator: type[ModelConfigurator[ModelType]]
45
+ model_path: str | tuple[str, ...]
46
+ model_sd_ops: SDOps | None = None
47
+ module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple)
48
+ loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple)
49
+ model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader)
50
+ registry: Registry = field(default_factory=DummyRegistry)
51
+ lora_load_device: torch.device = field(default_factory=lambda: torch.device("cpu"))
52
+
53
+ def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder":
54
+ return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops)))
55
+
56
+ def model_config(self) -> dict:
57
+ first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path
58
+ return self.model_loader.metadata(first_shard_path)
59
+
60
+ def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType:
61
+ with torch.device("meta"):
62
+ model = self.model_class_configurator.from_config(config)
63
+ for module_op in module_ops:
64
+ if module_op.matcher(model):
65
+ model = module_op.mutator(model)
66
+ return model
67
+
68
+ def load_sd(
69
+ self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None
70
+ ) -> StateDict:
71
+ state_dict = registry.get(paths, sd_ops)
72
+ if state_dict is None:
73
+ state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device)
74
+ registry.add(paths, sd_ops=sd_ops, state_dict=state_dict)
75
+ return state_dict
76
+
77
+ def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType:
78
+ uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"]
79
+ uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"]
80
+ if uninitialized_params or uninitialized_buffers:
81
+ logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}")
82
+ return meta_model
83
+ retval = meta_model.to(device)
84
+ return retval
85
+
86
+ def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType:
87
+ device = torch.device("cuda") if device is None else device
88
+ config = self.model_config()
89
+ meta_model = self.meta_model(config, self.module_ops)
90
+ model_paths = list(self.model_path) if isinstance(self.model_path, tuple) else [self.model_path]
91
+ model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)
92
+
93
+ lora_strengths = [lora.strength for lora in self.loras]
94
+ if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
95
+ sd = model_state_dict.sd
96
+ if dtype is not None:
97
+ sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()}
98
+ meta_model.load_state_dict(sd, strict=False, assign=True)
99
+ return self._return_model(meta_model, device)
100
+
101
+ lora_state_dicts = [
102
+ self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=self.lora_load_device)
103
+ for lora in self.loras
104
+ ]
105
+ lora_sd_and_strengths = [
106
+ LoraStateDictWithStrength(sd, strength)
107
+ for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
108
+ ]
109
+ final_sd = apply_loras(
110
+ model_sd=model_state_dict,
111
+ lora_sd_and_strengths=lora_sd_and_strengths,
112
+ dtype=dtype,
113
+ destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
114
+ )
115
+ meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
116
+ return self._return_model(meta_model, device)
packages/ltx-core/src/ltx_core/model/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Model definitions for LTX-2."""
2
+
3
+ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
4
+
5
+ __all__ = [
6
+ "ModelConfigurator",
7
+ "ModelType",
8
+ ]
packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (358 Bytes). View file
 
packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-312.pyc ADDED
Binary file (807 Bytes). View file
 
packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio VAE model components."""
2
+
3
+ from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio, encode_audio
4
+ from ltx_core.model.audio_vae.model_configurator import (
5
+ AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
6
+ AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
7
+ VOCODER_COMFY_KEYS_FILTER,
8
+ AudioDecoderConfigurator,
9
+ AudioEncoderConfigurator,
10
+ VocoderConfigurator,
11
+ )
12
+ from ltx_core.model.audio_vae.ops import AudioProcessor
13
+ from ltx_core.model.audio_vae.vocoder import Vocoder, VocoderWithBWE
14
+
15
+ __all__ = [
16
+ "AUDIO_VAE_DECODER_COMFY_KEYS_FILTER",
17
+ "AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER",
18
+ "VOCODER_COMFY_KEYS_FILTER",
19
+ "AudioDecoder",
20
+ "AudioDecoderConfigurator",
21
+ "AudioEncoder",
22
+ "AudioEncoderConfigurator",
23
+ "AudioProcessor",
24
+ "Vocoder",
25
+ "VocoderConfigurator",
26
+ "VocoderWithBWE",
27
+ "decode_audio",
28
+ "encode_audio",
29
+ ]
packages/ltx-core/src/ltx_core/model/audio_vae/attention.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
6
+
7
+
8
+ class AttentionType(Enum):
9
+ """Enum for specifying the attention mechanism type."""
10
+
11
+ VANILLA = "vanilla"
12
+ LINEAR = "linear"
13
+ NONE = "none"
14
+
15
+
16
+ class AttnBlock(torch.nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_channels: int,
20
+ norm_type: NormType = NormType.GROUP,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.in_channels = in_channels
24
+
25
+ self.norm = build_normalization_layer(in_channels, normtype=norm_type)
26
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
27
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
28
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
29
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ h_ = x
33
+ h_ = self.norm(h_)
34
+ q = self.q(h_)
35
+ k = self.k(h_)
36
+ v = self.v(h_)
37
+
38
+ # compute attention
39
+ b, c, h, w = q.shape
40
+ q = q.reshape(b, c, h * w).contiguous()
41
+ q = q.permute(0, 2, 1).contiguous() # b,hw,c
42
+ k = k.reshape(b, c, h * w).contiguous() # b,c,hw
43
+ w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
44
+ w_ = w_ * (int(c) ** (-0.5))
45
+ w_ = torch.nn.functional.softmax(w_, dim=2)
46
+
47
+ # attend to values
48
+ v = v.reshape(b, c, h * w).contiguous()
49
+ w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
50
+ h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
51
+ h_ = h_.reshape(b, c, h, w).contiguous()
52
+
53
+ h_ = self.proj_out(h_)
54
+
55
+ return x + h_
56
+
57
+
58
+ def make_attn(
59
+ in_channels: int,
60
+ attn_type: AttentionType = AttentionType.VANILLA,
61
+ norm_type: NormType = NormType.GROUP,
62
+ ) -> torch.nn.Module:
63
+ match attn_type:
64
+ case AttentionType.VANILLA:
65
+ return AttnBlock(in_channels, norm_type=norm_type)
66
+ case AttentionType.NONE:
67
+ return torch.nn.Identity()
68
+ case AttentionType.LINEAR:
69
+ raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
70
+ case _:
71
+ raise ValueError(f"Unknown attention type: {attn_type}")
packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from ltx_core.components.patchifiers import AudioPatchifier
7
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
8
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
9
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
10
+ from ltx_core.model.audio_vae.downsample import build_downsampling_path
11
+ from ltx_core.model.audio_vae.ops import AudioProcessor, PerChannelStatistics
12
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
13
+ from ltx_core.model.audio_vae.upsample import build_upsampling_path
14
+ from ltx_core.model.audio_vae.vocoder import Vocoder
15
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
16
+ from ltx_core.types import Audio, AudioLatentShape
17
+
18
+ LATENT_DOWNSAMPLE_FACTOR = 4
19
+
20
+
21
+ def build_mid_block(
22
+ channels: int,
23
+ temb_channels: int,
24
+ dropout: float,
25
+ norm_type: NormType,
26
+ causality_axis: CausalityAxis,
27
+ attn_type: AttentionType,
28
+ add_attention: bool,
29
+ ) -> torch.nn.Module:
30
+ """Build the middle block with two ResNet blocks and optional attention."""
31
+ mid = torch.nn.Module()
32
+ mid.block_1 = ResnetBlock(
33
+ in_channels=channels,
34
+ out_channels=channels,
35
+ temb_channels=temb_channels,
36
+ dropout=dropout,
37
+ norm_type=norm_type,
38
+ causality_axis=causality_axis,
39
+ )
40
+ mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity()
41
+ mid.block_2 = ResnetBlock(
42
+ in_channels=channels,
43
+ out_channels=channels,
44
+ temb_channels=temb_channels,
45
+ dropout=dropout,
46
+ norm_type=norm_type,
47
+ causality_axis=causality_axis,
48
+ )
49
+ return mid
50
+
51
+
52
+ def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
53
+ """Run features through the middle block."""
54
+ features = mid.block_1(features, temb=None)
55
+ features = mid.attn_1(features)
56
+ return mid.block_2(features, temb=None)
57
+
58
+
59
+ class AudioEncoder(torch.nn.Module):
60
+ """
61
+ Encoder that compresses audio spectrograms into latent representations.
62
+ The encoder uses a series of downsampling blocks with residual connections,
63
+ attention mechanisms, and configurable causal convolutions.
64
+ """
65
+
66
+ def __init__( # noqa: PLR0913
67
+ self,
68
+ *,
69
+ ch: int,
70
+ ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
71
+ num_res_blocks: int,
72
+ attn_resolutions: Set[int],
73
+ dropout: float = 0.0,
74
+ resamp_with_conv: bool = True,
75
+ in_channels: int,
76
+ resolution: int,
77
+ z_channels: int,
78
+ double_z: bool = True,
79
+ attn_type: AttentionType = AttentionType.VANILLA,
80
+ mid_block_add_attention: bool = True,
81
+ norm_type: NormType = NormType.GROUP,
82
+ causality_axis: CausalityAxis = CausalityAxis.WIDTH,
83
+ sample_rate: int = 16000,
84
+ mel_hop_length: int = 160,
85
+ n_fft: int = 1024,
86
+ is_causal: bool = True,
87
+ mel_bins: int = 64,
88
+ **_ignore_kwargs,
89
+ ) -> None:
90
+ """
91
+ Initialize the Encoder.
92
+ Args:
93
+ Arguments are configuration parameters, loaded from the audio VAE checkpoint config
94
+ (audio_vae.model.params.ddconfig):
95
+ ch: Base number of feature channels used in the first convolution layer.
96
+ ch_mult: Multiplicative factors for the number of channels at each resolution level.
97
+ num_res_blocks: Number of residual blocks to use at each resolution level.
98
+ attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
99
+ resolution: Input spatial resolution of the spectrogram (height, width).
100
+ z_channels: Number of channels in the latent representation.
101
+ norm_type: Normalization layer type to use within the network (e.g., group, batch).
102
+ causality_axis: Axis along which convolutions should be causal (e.g., time axis).
103
+ sample_rate: Audio sample rate in Hz for the input signals.
104
+ mel_hop_length: Hop length used when computing the mel spectrogram.
105
+ n_fft: FFT size used to compute the spectrogram.
106
+ mel_bins: Number of mel-frequency bins in the input spectrogram.
107
+ in_channels: Number of channels in the input spectrogram tensor.
108
+ double_z: If True, predict both mean and log-variance (doubling latent channels).
109
+ is_causal: If True, use causal convolutions suitable for streaming setups.
110
+ dropout: Dropout probability used in residual and mid blocks.
111
+ attn_type: Type of attention mechanism to use in attention blocks.
112
+ resamp_with_conv: If True, perform resolution changes using strided convolutions.
113
+ mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
114
+ """
115
+ super().__init__()
116
+
117
+ self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
118
+ self.sample_rate = sample_rate
119
+ self.mel_hop_length = mel_hop_length
120
+ self.n_fft = n_fft
121
+ self.is_causal = is_causal
122
+ self.mel_bins = mel_bins
123
+
124
+ self.patchifier = AudioPatchifier(
125
+ patch_size=1,
126
+ audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
127
+ sample_rate=sample_rate,
128
+ hop_length=mel_hop_length,
129
+ is_causal=is_causal,
130
+ )
131
+
132
+ self.ch = ch
133
+ self.temb_ch = 0
134
+ self.num_resolutions = len(ch_mult)
135
+ self.num_res_blocks = num_res_blocks
136
+ self.resolution = resolution
137
+ self.in_channels = in_channels
138
+ self.z_channels = z_channels
139
+ self.double_z = double_z
140
+ self.norm_type = norm_type
141
+ self.causality_axis = causality_axis
142
+ self.attn_type = attn_type
143
+
144
+ # downsampling
145
+ self.conv_in = make_conv2d(
146
+ in_channels,
147
+ self.ch,
148
+ kernel_size=3,
149
+ stride=1,
150
+ causality_axis=self.causality_axis,
151
+ )
152
+
153
+ self.non_linearity = torch.nn.SiLU()
154
+
155
+ self.down, block_in = build_downsampling_path(
156
+ ch=ch,
157
+ ch_mult=ch_mult,
158
+ num_resolutions=self.num_resolutions,
159
+ num_res_blocks=num_res_blocks,
160
+ resolution=resolution,
161
+ temb_channels=self.temb_ch,
162
+ dropout=dropout,
163
+ norm_type=self.norm_type,
164
+ causality_axis=self.causality_axis,
165
+ attn_type=self.attn_type,
166
+ attn_resolutions=attn_resolutions,
167
+ resamp_with_conv=resamp_with_conv,
168
+ )
169
+
170
+ self.mid = build_mid_block(
171
+ channels=block_in,
172
+ temb_channels=self.temb_ch,
173
+ dropout=dropout,
174
+ norm_type=self.norm_type,
175
+ causality_axis=self.causality_axis,
176
+ attn_type=self.attn_type,
177
+ add_attention=mid_block_add_attention,
178
+ )
179
+
180
+ self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
181
+ self.conv_out = make_conv2d(
182
+ block_in,
183
+ 2 * z_channels if double_z else z_channels,
184
+ kernel_size=3,
185
+ stride=1,
186
+ causality_axis=self.causality_axis,
187
+ )
188
+
189
+ def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
190
+ """
191
+ Encode audio spectrogram into latent representations.
192
+ Args:
193
+ spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
194
+ Returns:
195
+ Encoded latent representation of shape (batch, channels, frames, mel_bins)
196
+ """
197
+ h = self.conv_in(spectrogram)
198
+ h = self._run_downsampling_path(h)
199
+ h = run_mid_block(self.mid, h)
200
+ h = self._finalize_output(h)
201
+
202
+ return self._normalize_latents(h)
203
+
204
+ def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
205
+ for level in range(self.num_resolutions):
206
+ stage = self.down[level]
207
+ for block_idx in range(self.num_res_blocks):
208
+ h = stage.block[block_idx](h, temb=None)
209
+ if stage.attn:
210
+ h = stage.attn[block_idx](h)
211
+
212
+ if level != self.num_resolutions - 1:
213
+ h = stage.downsample(h)
214
+
215
+ return h
216
+
217
+ def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
218
+ h = self.norm_out(h)
219
+ h = self.non_linearity(h)
220
+ return self.conv_out(h)
221
+
222
+ def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
223
+ """
224
+ Normalize encoder latents using per-channel statistics.
225
+ When the encoder is configured with ``double_z=True``, the final
226
+ convolution produces twice the number of latent channels, typically
227
+ interpreted as two concatenated tensors along the channel dimension
228
+ (e.g., mean and variance or other auxiliary parameters).
229
+ This method intentionally uses only the first half of the channels
230
+ (the "mean" component) as input to the patchifier and normalization
231
+ logic. The remaining channels are left unchanged by this method and
232
+ are expected to be consumed elsewhere in the VAE pipeline.
233
+ If ``double_z=False``, the encoder output already contains only the
234
+ mean latents and the chunking operation simply returns that tensor.
235
+ """
236
+ means = torch.chunk(latent_output, 2, dim=1)[0]
237
+ latent_shape = AudioLatentShape(
238
+ batch=means.shape[0],
239
+ channels=means.shape[1],
240
+ frames=means.shape[2],
241
+ mel_bins=means.shape[3],
242
+ )
243
+ latent_patched = self.patchifier.patchify(means)
244
+ latent_normalized = self.per_channel_statistics.normalize(latent_patched)
245
+ return self.patchifier.unpatchify(latent_normalized, latent_shape)
246
+
247
+
248
+ def encode_audio(
249
+ audio: Audio,
250
+ audio_encoder: AudioEncoder,
251
+ audio_processor: AudioProcessor | None = None,
252
+ ) -> torch.Tensor:
253
+ """Encode audio waveform into latent representation.
254
+ Args:
255
+ audio: Audio container with waveform tensor of shape (batch, channels, samples) and sampling rate.
256
+ audio_encoder: Audio encoder model
257
+ audio_processor: Audio processor model (optional, if not provided, it will be created from the audio encoder)
258
+ """
259
+ dtype = next(audio_encoder.parameters()).dtype
260
+ device = next(audio_encoder.parameters()).device
261
+
262
+ if audio_processor is None:
263
+ audio_processor = AudioProcessor(
264
+ target_sample_rate=audio_encoder.sample_rate,
265
+ mel_bins=audio_encoder.mel_bins,
266
+ mel_hop_length=audio_encoder.mel_hop_length,
267
+ n_fft=audio_encoder.n_fft,
268
+ ).to(device=device)
269
+
270
+ mel_spectrogram = audio_processor.waveform_to_mel(audio.to(device=device))
271
+
272
+ latent = audio_encoder(mel_spectrogram.to(dtype=dtype))
273
+ return latent
274
+
275
+
276
+ class AudioDecoder(torch.nn.Module):
277
+ """
278
+ Symmetric decoder that reconstructs audio spectrograms from latent features.
279
+ The decoder mirrors the encoder structure with configurable channel multipliers,
280
+ attention resolutions, and causal convolutions.
281
+ """
282
+
283
+ def __init__( # noqa: PLR0913
284
+ self,
285
+ *,
286
+ ch: int,
287
+ out_ch: int,
288
+ ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
289
+ num_res_blocks: int,
290
+ attn_resolutions: Set[int],
291
+ resolution: int,
292
+ z_channels: int,
293
+ norm_type: NormType = NormType.GROUP,
294
+ causality_axis: CausalityAxis = CausalityAxis.WIDTH,
295
+ dropout: float = 0.0,
296
+ mid_block_add_attention: bool = True,
297
+ sample_rate: int = 16000,
298
+ mel_hop_length: int = 160,
299
+ is_causal: bool = True,
300
+ mel_bins: int | None = None,
301
+ ) -> None:
302
+ """
303
+ Initialize the Decoder.
304
+ Args:
305
+ Arguments are configuration parameters, loaded from the audio VAE checkpoint config
306
+ (audio_vae.model.params.ddconfig):
307
+ - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
308
+ - resolution, z_channels
309
+ - norm_type, causality_axis
310
+ """
311
+ super().__init__()
312
+
313
+ # Internal behavioural defaults that are not driven by the checkpoint.
314
+ resamp_with_conv = True
315
+ attn_type = AttentionType.VANILLA
316
+
317
+ # Per-channel statistics for denormalizing latents
318
+ self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
319
+ self.sample_rate = sample_rate
320
+ self.mel_hop_length = mel_hop_length
321
+ self.is_causal = is_causal
322
+ self.mel_bins = mel_bins
323
+ self.patchifier = AudioPatchifier(
324
+ patch_size=1,
325
+ audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
326
+ sample_rate=sample_rate,
327
+ hop_length=mel_hop_length,
328
+ is_causal=is_causal,
329
+ )
330
+
331
+ self.ch = ch
332
+ self.temb_ch = 0
333
+ self.num_resolutions = len(ch_mult)
334
+ self.num_res_blocks = num_res_blocks
335
+ self.resolution = resolution
336
+ self.out_ch = out_ch
337
+ self.give_pre_end = False
338
+ self.tanh_out = False
339
+ self.norm_type = norm_type
340
+ self.z_channels = z_channels
341
+ self.channel_multipliers = ch_mult
342
+ self.attn_resolutions = attn_resolutions
343
+ self.causality_axis = causality_axis
344
+ self.attn_type = attn_type
345
+
346
+ base_block_channels = ch * self.channel_multipliers[-1]
347
+ base_resolution = resolution // (2 ** (self.num_resolutions - 1))
348
+ self.z_shape = (1, z_channels, base_resolution, base_resolution)
349
+
350
+ self.conv_in = make_conv2d(
351
+ z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
352
+ )
353
+ self.non_linearity = torch.nn.SiLU()
354
+ self.mid = build_mid_block(
355
+ channels=base_block_channels,
356
+ temb_channels=self.temb_ch,
357
+ dropout=dropout,
358
+ norm_type=self.norm_type,
359
+ causality_axis=self.causality_axis,
360
+ attn_type=self.attn_type,
361
+ add_attention=mid_block_add_attention,
362
+ )
363
+ self.up, final_block_channels = build_upsampling_path(
364
+ ch=ch,
365
+ ch_mult=ch_mult,
366
+ num_resolutions=self.num_resolutions,
367
+ num_res_blocks=num_res_blocks,
368
+ resolution=resolution,
369
+ temb_channels=self.temb_ch,
370
+ dropout=dropout,
371
+ norm_type=self.norm_type,
372
+ causality_axis=self.causality_axis,
373
+ attn_type=self.attn_type,
374
+ attn_resolutions=attn_resolutions,
375
+ resamp_with_conv=resamp_with_conv,
376
+ initial_block_channels=base_block_channels,
377
+ )
378
+
379
+ self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
380
+ self.conv_out = make_conv2d(
381
+ final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
382
+ )
383
+
384
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
385
+ """
386
+ Decode latent features back to audio spectrograms.
387
+ Args:
388
+ sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
389
+ Returns:
390
+ Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
391
+ """
392
+ sample, target_shape = self._denormalize_latents(sample)
393
+
394
+ h = self.conv_in(sample)
395
+ h = run_mid_block(self.mid, h)
396
+ h = self._run_upsampling_path(h)
397
+ h = self._finalize_output(h)
398
+
399
+ return self._adjust_output_shape(h, target_shape)
400
+
401
+ def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
402
+ latent_shape = AudioLatentShape(
403
+ batch=sample.shape[0],
404
+ channels=sample.shape[1],
405
+ frames=sample.shape[2],
406
+ mel_bins=sample.shape[3],
407
+ )
408
+
409
+ sample_patched = self.patchifier.patchify(sample)
410
+ sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
411
+ sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
412
+
413
+ target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
414
+ if self.causality_axis != CausalityAxis.NONE:
415
+ target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
416
+
417
+ target_shape = AudioLatentShape(
418
+ batch=latent_shape.batch,
419
+ channels=self.out_ch,
420
+ frames=target_frames,
421
+ mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
422
+ )
423
+
424
+ return sample, target_shape
425
+
426
+ def _adjust_output_shape(
427
+ self,
428
+ decoded_output: torch.Tensor,
429
+ target_shape: AudioLatentShape,
430
+ ) -> torch.Tensor:
431
+ """
432
+ Adjust output shape to match target dimensions for variable-length audio.
433
+ This function handles the common case where decoded audio spectrograms need to be
434
+ resized to match a specific target shape.
435
+ Args:
436
+ decoded_output: Tensor of shape (batch, channels, time, frequency)
437
+ target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
438
+ Returns:
439
+ Tensor adjusted to match target_shape exactly
440
+ """
441
+ # Current output shape: (batch, channels, time, frequency)
442
+ _, _, current_time, current_freq = decoded_output.shape
443
+ target_channels = target_shape.channels
444
+ target_time = target_shape.frames
445
+ target_freq = target_shape.mel_bins
446
+
447
+ # Step 1: Crop first to avoid exceeding target dimensions
448
+ decoded_output = decoded_output[
449
+ :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
450
+ ]
451
+
452
+ # Step 2: Calculate padding needed for time and frequency dimensions
453
+ time_padding_needed = target_time - decoded_output.shape[2]
454
+ freq_padding_needed = target_freq - decoded_output.shape[3]
455
+
456
+ # Step 3: Apply padding if needed
457
+ if time_padding_needed > 0 or freq_padding_needed > 0:
458
+ # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
459
+ # For audio: pad_left/right = frequency, pad_top/bottom = time
460
+ padding = (
461
+ 0,
462
+ max(freq_padding_needed, 0), # frequency padding (left, right)
463
+ 0,
464
+ max(time_padding_needed, 0), # time padding (top, bottom)
465
+ )
466
+ decoded_output = F.pad(decoded_output, padding)
467
+
468
+ # Step 4: Final safety crop to ensure exact target shape
469
+ decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
470
+
471
+ return decoded_output
472
+
473
+ def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
474
+ for level in reversed(range(self.num_resolutions)):
475
+ stage = self.up[level]
476
+ for block_idx, block in enumerate(stage.block):
477
+ h = block(h, temb=None)
478
+ if stage.attn:
479
+ h = stage.attn[block_idx](h)
480
+
481
+ if level != 0 and hasattr(stage, "upsample"):
482
+ h = stage.upsample(h)
483
+
484
+ return h
485
+
486
+ def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
487
+ if self.give_pre_end:
488
+ return h
489
+
490
+ h = self.norm_out(h)
491
+ h = self.non_linearity(h)
492
+ h = self.conv_out(h)
493
+ return torch.tanh(h) if self.tanh_out else h
494
+
495
+
496
+ def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> Audio:
497
+ """
498
+ Decode an audio latent representation using the provided audio decoder and vocoder.
499
+ Args:
500
+ latent: Input audio latent tensor.
501
+ audio_decoder: Model to decode the latent to waveform features.
502
+ vocoder: Model to convert decoded features to audio waveform.
503
+ Returns:
504
+ Decoded audio with waveform and sampling rate.
505
+ """
506
+ decoded_audio = audio_decoder(latent)
507
+ waveform = vocoder(decoded_audio).squeeze(0).float()
508
+ return Audio(waveform=waveform, sampling_rate=vocoder.output_sampling_rate)
packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
5
+
6
+
7
+ class CausalConv2d(torch.nn.Module):
8
+ """
9
+ A causal 2D convolution.
10
+ This layer ensures that the output at time `t` only depends on inputs
11
+ at time `t` and earlier. It achieves this by applying asymmetric padding
12
+ to the time dimension (width) before the convolution.
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ in_channels: int,
18
+ out_channels: int,
19
+ kernel_size: int | tuple[int, int],
20
+ stride: int = 1,
21
+ dilation: int | tuple[int, int] = 1,
22
+ groups: int = 1,
23
+ bias: bool = True,
24
+ causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
25
+ ) -> None:
26
+ super().__init__()
27
+
28
+ self.causality_axis = causality_axis
29
+
30
+ # Ensure kernel_size and dilation are tuples
31
+ kernel_size = torch.nn.modules.utils._pair(kernel_size)
32
+ dilation = torch.nn.modules.utils._pair(dilation)
33
+
34
+ # Calculate padding dimensions
35
+ pad_h = (kernel_size[0] - 1) * dilation[0]
36
+ pad_w = (kernel_size[1] - 1) * dilation[1]
37
+
38
+ # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
39
+ match self.causality_axis:
40
+ case CausalityAxis.NONE:
41
+ self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
42
+ case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
43
+ self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
44
+ case CausalityAxis.HEIGHT:
45
+ self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
46
+ case _:
47
+ raise ValueError(f"Invalid causality_axis: {causality_axis}")
48
+
49
+ # The internal convolution layer uses no padding, as we handle it manually
50
+ self.conv = torch.nn.Conv2d(
51
+ in_channels,
52
+ out_channels,
53
+ kernel_size,
54
+ stride=stride,
55
+ padding=0,
56
+ dilation=dilation,
57
+ groups=groups,
58
+ bias=bias,
59
+ )
60
+
61
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+ # Apply causal padding before convolution
63
+ x = F.pad(x, self.padding)
64
+ return self.conv(x)
65
+
66
+
67
+ def make_conv2d(
68
+ in_channels: int,
69
+ out_channels: int,
70
+ kernel_size: int | tuple[int, int],
71
+ stride: int = 1,
72
+ padding: tuple[int, int, int, int] | None = None,
73
+ dilation: int = 1,
74
+ groups: int = 1,
75
+ bias: bool = True,
76
+ causality_axis: CausalityAxis | None = None,
77
+ ) -> torch.nn.Module:
78
+ """
79
+ Create a 2D convolution layer that can be either causal or non-causal.
80
+ Args:
81
+ in_channels: Number of input channels
82
+ out_channels: Number of output channels
83
+ kernel_size: Size of the convolution kernel
84
+ stride: Convolution stride
85
+ padding: Padding (if None, will be calculated based on causal flag)
86
+ dilation: Dilation rate
87
+ groups: Number of groups for grouped convolution
88
+ bias: Whether to use bias
89
+ causality_axis: Dimension along which to apply causality.
90
+ Returns:
91
+ Either a regular Conv2d or CausalConv2d layer
92
+ """
93
+ if causality_axis is not None:
94
+ # For causal convolution, padding is handled internally by CausalConv2d
95
+ return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
96
+ else:
97
+ # For non-causal convolution, use symmetric padding if not specified
98
+ if padding is None:
99
+ padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)
100
+
101
+ return torch.nn.Conv2d(
102
+ in_channels,
103
+ out_channels,
104
+ kernel_size,
105
+ stride,
106
+ padding,
107
+ dilation,
108
+ groups,
109
+ bias,
110
+ )
packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
6
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
7
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
8
+ from ltx_core.model.common.normalization import NormType
9
+
10
+
11
+ class Downsample(torch.nn.Module):
12
+ """
13
+ A downsampling layer that can use either a strided convolution
14
+ or average pooling. Supports standard and causal padding for the
15
+ convolutional mode.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ in_channels: int,
21
+ with_conv: bool,
22
+ causality_axis: CausalityAxis = CausalityAxis.WIDTH,
23
+ ) -> None:
24
+ super().__init__()
25
+ self.with_conv = with_conv
26
+ self.causality_axis = causality_axis
27
+
28
+ if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
29
+ raise ValueError("causality is only supported when `with_conv=True`.")
30
+
31
+ if self.with_conv:
32
+ # Do time downsampling here
33
+ # no asymmetric padding in torch conv, must do it ourselves
34
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ if self.with_conv:
38
+ # Padding tuple is in the order: (left, right, top, bottom).
39
+ match self.causality_axis:
40
+ case CausalityAxis.NONE:
41
+ pad = (0, 1, 0, 1)
42
+ case CausalityAxis.WIDTH:
43
+ pad = (2, 0, 0, 1)
44
+ case CausalityAxis.HEIGHT:
45
+ pad = (0, 1, 2, 0)
46
+ case CausalityAxis.WIDTH_COMPATIBILITY:
47
+ pad = (1, 0, 0, 1)
48
+ case _:
49
+ raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
50
+
51
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
52
+ x = self.conv(x)
53
+ else:
54
+ # This branch is only taken if with_conv=False, which implies causality_axis is NONE.
55
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
56
+
57
+ return x
58
+
59
+
60
+ def build_downsampling_path( # noqa: PLR0913
61
+ *,
62
+ ch: int,
63
+ ch_mult: Tuple[int, ...],
64
+ num_resolutions: int,
65
+ num_res_blocks: int,
66
+ resolution: int,
67
+ temb_channels: int,
68
+ dropout: float,
69
+ norm_type: NormType,
70
+ causality_axis: CausalityAxis,
71
+ attn_type: AttentionType,
72
+ attn_resolutions: Set[int],
73
+ resamp_with_conv: bool,
74
+ ) -> tuple[torch.nn.ModuleList, int]:
75
+ """Build the downsampling path with residual blocks, attention, and downsampling layers."""
76
+ down_modules = torch.nn.ModuleList()
77
+ curr_res = resolution
78
+ in_ch_mult = (1, *tuple(ch_mult))
79
+ block_in = ch
80
+
81
+ for i_level in range(num_resolutions):
82
+ block = torch.nn.ModuleList()
83
+ attn = torch.nn.ModuleList()
84
+ block_in = ch * in_ch_mult[i_level]
85
+ block_out = ch * ch_mult[i_level]
86
+
87
+ for _ in range(num_res_blocks):
88
+ block.append(
89
+ ResnetBlock(
90
+ in_channels=block_in,
91
+ out_channels=block_out,
92
+ temb_channels=temb_channels,
93
+ dropout=dropout,
94
+ norm_type=norm_type,
95
+ causality_axis=causality_axis,
96
+ )
97
+ )
98
+ block_in = block_out
99
+ if curr_res in attn_resolutions:
100
+ attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))
101
+
102
+ down = torch.nn.Module()
103
+ down.block = block
104
+ down.attn = attn
105
+ if i_level != num_resolutions - 1:
106
+ down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
107
+ curr_res = curr_res // 2
108
+ down_modules.append(down)
109
+
110
+ return down_modules, block_in
packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps
4
+ from ltx_core.model.audio_vae.attention import AttentionType
5
+ from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder
6
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
7
+ from ltx_core.model.audio_vae.vocoder import MelSTFT, Vocoder, VocoderWithBWE
8
+ from ltx_core.model.common.normalization import NormType
9
+ from ltx_core.model.model_protocol import ModelConfigurator
10
+ from ltx_core.utils import check_config_value
11
+
12
+
13
+ def _vocoder_from_config(
14
+ cfg: dict,
15
+ apply_final_activation: bool = True,
16
+ output_sampling_rate: int | None = None,
17
+ ) -> Vocoder:
18
+ """Instantiate a Vocoder from a flat config dict.
19
+ Args:
20
+ cfg: Vocoder config dict (keys match Vocoder constructor args).
21
+ apply_final_activation: Whether to apply tanh/clamp at the output.
22
+ output_sampling_rate: Explicit override for the output sample rate.
23
+ When None, reads from cfg["output_sampling_rate"] (default 24000).
24
+ """
25
+ return Vocoder(
26
+ resblock_kernel_sizes=cfg.get("resblock_kernel_sizes", [3, 7, 11]),
27
+ upsample_rates=cfg.get("upsample_rates", [6, 5, 2, 2, 2]),
28
+ upsample_kernel_sizes=cfg.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]),
29
+ resblock_dilation_sizes=cfg.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]),
30
+ upsample_initial_channel=cfg.get("upsample_initial_channel", 1024),
31
+ resblock=cfg.get("resblock", "1"),
32
+ output_sampling_rate=(
33
+ output_sampling_rate if output_sampling_rate is not None else cfg.get("output_sampling_rate", 24000)
34
+ ),
35
+ activation=cfg.get("activation", "snake"),
36
+ use_tanh_at_final=cfg.get("use_tanh_at_final", True),
37
+ apply_final_activation=apply_final_activation,
38
+ use_bias_at_final=cfg.get("use_bias_at_final", True),
39
+ )
40
+
41
+
42
+ class VocoderConfigurator(ModelConfigurator[Vocoder]):
43
+ """Configurator that auto-detects the checkpoint format.
44
+ Returns a plain Vocoder for pre-ltx-2.3 checkpoints (flat config) or a
45
+ VocoderWithBWE for ltx-2.3+ checkpoints (nested "vocoder" + "bwe" config).
46
+ """
47
+
48
+ @classmethod
49
+ def from_config(cls: type[Vocoder], config: dict) -> Vocoder | VocoderWithBWE:
50
+ cfg = config.get("vocoder", {})
51
+
52
+ if "bwe" not in cfg:
53
+ check_config_value(cfg, "resblock", "1")
54
+ check_config_value(cfg, "stereo", True)
55
+ return _vocoder_from_config(cfg)
56
+
57
+ vocoder_cfg = cfg.get("vocoder", {})
58
+ bwe_cfg = cfg["bwe"]
59
+
60
+ check_config_value(vocoder_cfg, "resblock", "AMP1")
61
+ check_config_value(vocoder_cfg, "stereo", True)
62
+ check_config_value(vocoder_cfg, "activation", "snakebeta")
63
+ check_config_value(bwe_cfg, "resblock", "AMP1")
64
+ check_config_value(bwe_cfg, "stereo", True)
65
+ check_config_value(bwe_cfg, "activation", "snakebeta")
66
+
67
+ vocoder = _vocoder_from_config(
68
+ vocoder_cfg,
69
+ output_sampling_rate=bwe_cfg["input_sampling_rate"],
70
+ )
71
+ bwe_generator = _vocoder_from_config(
72
+ bwe_cfg,
73
+ apply_final_activation=False,
74
+ output_sampling_rate=bwe_cfg["output_sampling_rate"],
75
+ )
76
+ mel_stft = MelSTFT(
77
+ filter_length=bwe_cfg["n_fft"],
78
+ hop_length=bwe_cfg["hop_length"],
79
+ win_length=bwe_cfg["n_fft"],
80
+ n_mel_channels=bwe_cfg["num_mels"],
81
+ )
82
+ return VocoderWithBWE(
83
+ vocoder=vocoder,
84
+ bwe_generator=bwe_generator,
85
+ mel_stft=mel_stft,
86
+ input_sampling_rate=bwe_cfg["input_sampling_rate"],
87
+ output_sampling_rate=bwe_cfg["output_sampling_rate"],
88
+ hop_length=bwe_cfg["hop_length"],
89
+ )
90
+
91
+
92
+ def _strip_vocoder_prefix(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
93
+ """Strip the leading 'vocoder.' prefix exactly once.
94
+ Uses removeprefix instead of str.replace so that BWE keys like
95
+ 'vocoder.vocoder.conv_pre' become 'vocoder.conv_pre' (not 'conv_pre').
96
+ Works identically for legacy keys like 'vocoder.conv_pre' → 'conv_pre'.
97
+ """
98
+ return [KeyValueOperationResult(key.removeprefix("vocoder."), value)]
99
+
100
+
101
+ VOCODER_COMFY_KEYS_FILTER = (
102
+ SDOps("VOCODER_COMFY_KEYS_FILTER")
103
+ .with_matching(prefix="vocoder.")
104
+ .with_kv_operation(operation=_strip_vocoder_prefix, key_prefix="vocoder.")
105
+ )
106
+
107
+
108
+ class AudioDecoderConfigurator(ModelConfigurator[AudioDecoder]):
109
+ @classmethod
110
+ def from_config(cls: type[AudioDecoder], config: dict) -> AudioDecoder:
111
+ audio_vae_cfg = config.get("audio_vae", {})
112
+ model_cfg = audio_vae_cfg.get("model", {})
113
+ model_params = model_cfg.get("params", {})
114
+ ddconfig = model_params.get("ddconfig", {})
115
+ preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
116
+ stft_cfg = preprocessing_cfg.get("stft", {})
117
+ mel_cfg = preprocessing_cfg.get("mel", {})
118
+ variables_cfg = audio_vae_cfg.get("variables", {})
119
+
120
+ sample_rate = model_params.get("sampling_rate", 16000)
121
+ mel_hop_length = stft_cfg.get("hop_length", 160)
122
+ is_causal = stft_cfg.get("causal", True)
123
+ mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")
124
+
125
+ return AudioDecoder(
126
+ ch=ddconfig.get("ch", 128),
127
+ out_ch=ddconfig.get("out_ch", 2),
128
+ ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
129
+ num_res_blocks=ddconfig.get("num_res_blocks", 2),
130
+ attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
131
+ resolution=ddconfig.get("resolution", 256),
132
+ z_channels=ddconfig.get("z_channels", 8),
133
+ norm_type=NormType(ddconfig.get("norm_type", "pixel")),
134
+ causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
135
+ dropout=ddconfig.get("dropout", 0.0),
136
+ mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
137
+ sample_rate=sample_rate,
138
+ mel_hop_length=mel_hop_length,
139
+ is_causal=is_causal,
140
+ mel_bins=mel_bins,
141
+ )
142
+
143
+
144
+ class AudioEncoderConfigurator(ModelConfigurator[AudioEncoder]):
145
+ @classmethod
146
+ def from_config(cls: type[AudioEncoder], config: dict) -> AudioEncoder:
147
+ audio_vae_cfg = config.get("audio_vae", {})
148
+ model_cfg = audio_vae_cfg.get("model", {})
149
+ model_params = model_cfg.get("params", {})
150
+ ddconfig = model_params.get("ddconfig", {})
151
+ preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
152
+ stft_cfg = preprocessing_cfg.get("stft", {})
153
+ mel_cfg = preprocessing_cfg.get("mel", {})
154
+ variables_cfg = audio_vae_cfg.get("variables", {})
155
+
156
+ sample_rate = model_params.get("sampling_rate", 16000)
157
+ mel_hop_length = stft_cfg.get("hop_length", 160)
158
+ n_fft = stft_cfg.get("filter_length", 1024)
159
+ is_causal = stft_cfg.get("causal", True)
160
+ mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")
161
+
162
+ return AudioEncoder(
163
+ ch=ddconfig.get("ch", 128),
164
+ ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
165
+ num_res_blocks=ddconfig.get("num_res_blocks", 2),
166
+ attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
167
+ resolution=ddconfig.get("resolution", 256),
168
+ z_channels=ddconfig.get("z_channels", 8),
169
+ double_z=ddconfig.get("double_z", True),
170
+ dropout=ddconfig.get("dropout", 0.0),
171
+ resamp_with_conv=ddconfig.get("resamp_with_conv", True),
172
+ in_channels=ddconfig.get("in_channels", 2),
173
+ attn_type=AttentionType(ddconfig.get("attn_type", "vanilla")),
174
+ mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
175
+ norm_type=NormType(ddconfig.get("norm_type", "pixel")),
176
+ causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
177
+ sample_rate=sample_rate,
178
+ mel_hop_length=mel_hop_length,
179
+ n_fft=n_fft,
180
+ is_causal=is_causal,
181
+ mel_bins=mel_bins,
182
+ )
183
+
184
+
185
+ AUDIO_VAE_DECODER_COMFY_KEYS_FILTER = (
186
+ SDOps("AUDIO_VAE_DECODER_COMFY_KEYS_FILTER")
187
+ .with_matching(prefix="audio_vae.decoder.")
188
+ .with_matching(prefix="audio_vae.per_channel_statistics.")
189
+ .with_replacement("audio_vae.decoder.", "")
190
+ .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
191
+ )
192
+
193
+
194
+ AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER = (
195
+ SDOps("AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER")
196
+ .with_matching(prefix="audio_vae.encoder.")
197
+ .with_matching(prefix="audio_vae.per_channel_statistics.")
198
+ .with_replacement("audio_vae.encoder.", "")
199
+ .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
200
+ )
packages/ltx-core/src/ltx_core/model/audio_vae/ops.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from torch import nn
4
+
5
+ from ltx_core.types import Audio
6
+
7
+
8
+ class AudioProcessor(nn.Module):
9
+ """Converts audio waveforms to log-mel spectrograms with optional resampling."""
10
+
11
+ def __init__(
12
+ self,
13
+ target_sample_rate: int,
14
+ mel_bins: int,
15
+ mel_hop_length: int,
16
+ n_fft: int,
17
+ ) -> None:
18
+ super().__init__()
19
+ self.target_sample_rate = target_sample_rate
20
+ self.mel_transform = torchaudio.transforms.MelSpectrogram(
21
+ sample_rate=target_sample_rate,
22
+ n_fft=n_fft,
23
+ win_length=n_fft,
24
+ hop_length=mel_hop_length,
25
+ f_min=0.0,
26
+ f_max=target_sample_rate / 2.0,
27
+ n_mels=mel_bins,
28
+ window_fn=torch.hann_window,
29
+ center=True,
30
+ pad_mode="reflect",
31
+ power=1.0,
32
+ mel_scale="slaney",
33
+ norm="slaney",
34
+ )
35
+
36
+ def resample_audio(self, audio: Audio) -> Audio:
37
+ """Resample audio to the processor's target sample rate if needed."""
38
+ if audio.sampling_rate == self.target_sample_rate:
39
+ return audio
40
+ resampled = torchaudio.functional.resample(audio.waveform, audio.sampling_rate, self.target_sample_rate)
41
+ resampled = resampled.to(device=audio.waveform.device, dtype=audio.waveform.dtype)
42
+ return Audio(waveform=resampled, sampling_rate=self.target_sample_rate)
43
+
44
+ def waveform_to_mel(
45
+ self,
46
+ audio: Audio,
47
+ ) -> torch.Tensor:
48
+ """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
49
+ waveform = self.resample_audio(audio).waveform
50
+
51
+ mel = self.mel_transform(waveform)
52
+ mel = torch.log(torch.clamp(mel, min=1e-5))
53
+
54
+ mel = mel.to(device=waveform.device, dtype=waveform.dtype)
55
+ return mel.permute(0, 1, 3, 2).contiguous()
56
+
57
+
58
+ class PerChannelStatistics(nn.Module):
59
+ """
60
+ Per-channel statistics for normalizing and denormalizing the latent representation.
61
+ This statics is computed over the entire dataset and stored in model's checkpoint under AudioVAE state_dict.
62
+ """
63
+
64
+ def __init__(self, latent_channels: int = 128) -> None:
65
+ super().__init__()
66
+ self.register_buffer("std-of-means", torch.empty(latent_channels))
67
+ self.register_buffer("mean-of-means", torch.empty(latent_channels))
68
+
69
+ def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
70
+ return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
71
+
72
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
73
+ return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
6
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
7
+ from ltx_core.model.common.normalization import NormType, build_normalization_layer
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
+
12
+ class ResBlock1(torch.nn.Module):
13
+ def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
14
+ super(ResBlock1, self).__init__()
15
+ self.convs1 = torch.nn.ModuleList(
16
+ [
17
+ torch.nn.Conv1d(
18
+ channels,
19
+ channels,
20
+ kernel_size,
21
+ 1,
22
+ dilation=dilation[0],
23
+ padding="same",
24
+ ),
25
+ torch.nn.Conv1d(
26
+ channels,
27
+ channels,
28
+ kernel_size,
29
+ 1,
30
+ dilation=dilation[1],
31
+ padding="same",
32
+ ),
33
+ torch.nn.Conv1d(
34
+ channels,
35
+ channels,
36
+ kernel_size,
37
+ 1,
38
+ dilation=dilation[2],
39
+ padding="same",
40
+ ),
41
+ ]
42
+ )
43
+
44
+ self.convs2 = torch.nn.ModuleList(
45
+ [
46
+ torch.nn.Conv1d(
47
+ channels,
48
+ channels,
49
+ kernel_size,
50
+ 1,
51
+ dilation=1,
52
+ padding="same",
53
+ ),
54
+ torch.nn.Conv1d(
55
+ channels,
56
+ channels,
57
+ kernel_size,
58
+ 1,
59
+ dilation=1,
60
+ padding="same",
61
+ ),
62
+ torch.nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ 1,
67
+ dilation=1,
68
+ padding="same",
69
+ ),
70
+ ]
71
+ )
72
+
73
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
74
+ for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):
75
+ xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
76
+ xt = conv1(xt)
77
+ xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
78
+ xt = conv2(xt)
79
+ x = xt + x
80
+ return x
81
+
82
+
83
+ class ResBlock2(torch.nn.Module):
84
+ def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
85
+ super(ResBlock2, self).__init__()
86
+ self.convs = torch.nn.ModuleList(
87
+ [
88
+ torch.nn.Conv1d(
89
+ channels,
90
+ channels,
91
+ kernel_size,
92
+ 1,
93
+ dilation=dilation[0],
94
+ padding="same",
95
+ ),
96
+ torch.nn.Conv1d(
97
+ channels,
98
+ channels,
99
+ kernel_size,
100
+ 1,
101
+ dilation=dilation[1],
102
+ padding="same",
103
+ ),
104
+ ]
105
+ )
106
+
107
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
108
+ for conv in self.convs:
109
+ xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
110
+ xt = conv(xt)
111
+ x = xt + x
112
+ return x
113
+
114
+
115
+ class ResnetBlock(torch.nn.Module):
116
+ def __init__(
117
+ self,
118
+ *,
119
+ in_channels: int,
120
+ out_channels: int | None = None,
121
+ conv_shortcut: bool = False,
122
+ dropout: float = 0.0,
123
+ temb_channels: int = 512,
124
+ norm_type: NormType = NormType.GROUP,
125
+ causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
126
+ ) -> None:
127
+ super().__init__()
128
+ self.causality_axis = causality_axis
129
+
130
+ if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
131
+ raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
132
+ self.in_channels = in_channels
133
+ out_channels = in_channels if out_channels is None else out_channels
134
+ self.out_channels = out_channels
135
+ self.use_conv_shortcut = conv_shortcut
136
+
137
+ self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
138
+ self.non_linearity = torch.nn.SiLU()
139
+ self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
140
+ if temb_channels > 0:
141
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
142
+ self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
143
+ self.dropout = torch.nn.Dropout(dropout)
144
+ self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
145
+ if self.in_channels != self.out_channels:
146
+ if self.use_conv_shortcut:
147
+ self.conv_shortcut = make_conv2d(
148
+ in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
149
+ )
150
+ else:
151
+ self.nin_shortcut = make_conv2d(
152
+ in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
153
+ )
154
+
155
+ def forward(
156
+ self,
157
+ x: torch.Tensor,
158
+ temb: torch.Tensor | None = None,
159
+ ) -> torch.Tensor:
160
+ h = x
161
+ h = self.norm1(h)
162
+ h = self.non_linearity(h)
163
+ h = self.conv1(h)
164
+
165
+ if temb is not None:
166
+ h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
167
+
168
+ h = self.norm2(h)
169
+ h = self.non_linearity(h)
170
+ h = self.dropout(h)
171
+ h = self.conv2(h)
172
+
173
+ if self.in_channels != self.out_channels:
174
+ x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
175
+
176
+ return x + h
packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set, Tuple
2
+
3
+ import torch
4
+
5
+ from ltx_core.model.audio_vae.attention import AttentionType, make_attn
6
+ from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
7
+ from ltx_core.model.audio_vae.causality_axis import CausalityAxis
8
+ from ltx_core.model.audio_vae.resnet import ResnetBlock
9
+ from ltx_core.model.common.normalization import NormType
10
+
11
+
12
+ class Upsample(torch.nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_channels: int,
16
+ with_conv: bool,
17
+ causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
18
+ ) -> None:
19
+ super().__init__()
20
+ self.with_conv = with_conv
21
+ self.causality_axis = causality_axis
22
+ if self.with_conv:
23
+ self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
27
+ if self.with_conv:
28
+ x = self.conv(x)
29
+ # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
30
+ # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
31
+ # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
32
+ # So the output elements rely on the following windows:
33
+ # 0: [-,-,0]
34
+ # 1: [-,0,0]
35
+ # 2: [0,0,1]
36
+ # 3: [0,1,1]
37
+ # 4: [1,1,2]
38
+ # 5: [1,2,2]
39
+ # Notice that the first and second elements in the output rely only on the first element in the input,
40
+ # while all other elements rely on two elements in the input.
41
+ # So we can drop the first element to undo the padding (rather than the last element).
42
+ # This is a no-op for non-causal convolutions.
43
+ match self.causality_axis:
44
+ case CausalityAxis.NONE:
45
+ pass # x remains unchanged
46
+ case CausalityAxis.HEIGHT:
47
+ x = x[:, :, 1:, :]
48
+ case CausalityAxis.WIDTH:
49
+ x = x[:, :, :, 1:]
50
+ case CausalityAxis.WIDTH_COMPATIBILITY:
51
+ pass # x remains unchanged
52
+ case _:
53
+ raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
54
+
55
+ return x
56
+
57
+
58
+ def build_upsampling_path( # noqa: PLR0913
59
+ *,
60
+ ch: int,
61
+ ch_mult: Tuple[int, ...],
62
+ num_resolutions: int,
63
+ num_res_blocks: int,
64
+ resolution: int,
65
+ temb_channels: int,
66
+ dropout: float,
67
+ norm_type: NormType,
68
+ causality_axis: CausalityAxis,
69
+ attn_type: AttentionType,
70
+ attn_resolutions: Set[int],
71
+ resamp_with_conv: bool,
72
+ initial_block_channels: int,
73
+ ) -> tuple[torch.nn.ModuleList, int]:
74
+ """Build the upsampling path with residual blocks, attention, and upsampling layers."""
75
+ up_modules = torch.nn.ModuleList()
76
+ block_in = initial_block_channels
77
+ curr_res = resolution // (2 ** (num_resolutions - 1))
78
+
79
+ for level in reversed(range(num_resolutions)):
80
+ stage = torch.nn.Module()
81
+ stage.block = torch.nn.ModuleList()
82
+ stage.attn = torch.nn.ModuleList()
83
+ block_out = ch * ch_mult[level]
84
+
85
+ for _ in range(num_res_blocks + 1):
86
+ stage.block.append(
87
+ ResnetBlock(
88
+ in_channels=block_in,
89
+ out_channels=block_out,
90
+ temb_channels=temb_channels,
91
+ dropout=dropout,
92
+ norm_type=norm_type,
93
+ causality_axis=causality_axis,
94
+ )
95
+ )
96
+ block_in = block_out
97
+ if curr_res in attn_resolutions:
98
+ stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))
99
+
100
+ if level != 0:
101
+ stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
102
+ curr_res *= 2
103
+
104
+ up_modules.insert(0, stage)
105
+
106
+ return up_modules, block_in
packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+
4
+ import einops
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from ltx_core.model.audio_vae.resnet import LRELU_SLOPE, ResBlock1
10
+
11
+
12
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
18
+ # Adopted from https://github.com/NVIDIA/BigVGAN
19
+ # ---------------------------------------------------------------------------
20
+
21
+
22
+ def _sinc(x: torch.Tensor) -> torch.Tensor:
23
+ return torch.where(
24
+ x == 0,
25
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
26
+ torch.sin(math.pi * x) / math.pi / x,
27
+ )
28
+
29
+
30
+ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
31
+ even = kernel_size % 2 == 0
32
+ half_size = kernel_size // 2
33
+ delta_f = 4 * half_width
34
+ amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
35
+ if amplitude > 50.0:
36
+ beta = 0.1102 * (amplitude - 8.7)
37
+ elif amplitude >= 21.0:
38
+ beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
39
+ else:
40
+ beta = 0.0
41
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
42
+ time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
43
+ if cutoff == 0:
44
+ filter_ = torch.zeros_like(time)
45
+ else:
46
+ filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
47
+ filter_ /= filter_.sum()
48
+ return filter_.view(1, 1, kernel_size)
49
+
50
+
51
+ class LowPassFilter1d(nn.Module):
52
+ def __init__(
53
+ self,
54
+ cutoff: float = 0.5,
55
+ half_width: float = 0.6,
56
+ stride: int = 1,
57
+ padding: bool = True,
58
+ padding_mode: str = "replicate",
59
+ kernel_size: int = 12,
60
+ ) -> None:
61
+ super().__init__()
62
+ if cutoff < -0.0:
63
+ raise ValueError("Minimum cutoff must be larger than zero.")
64
+ if cutoff > 0.5:
65
+ raise ValueError("A cutoff above 0.5 does not make sense.")
66
+ self.kernel_size = kernel_size
67
+ self.even = kernel_size % 2 == 0
68
+ self.pad_left = kernel_size // 2 - int(self.even)
69
+ self.pad_right = kernel_size // 2
70
+ self.stride = stride
71
+ self.padding = padding
72
+ self.padding_mode = padding_mode
73
+ self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))
74
+
75
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
76
+ _, n_channels, _ = x.shape
77
+ if self.padding:
78
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
79
+ return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels)
80
+
81
+
82
+ class UpSample1d(nn.Module):
83
+ def __init__(
84
+ self,
85
+ ratio: int = 2,
86
+ kernel_size: int | None = None,
87
+ persistent: bool = True,
88
+ window_type: str = "kaiser",
89
+ ) -> None:
90
+ super().__init__()
91
+ self.ratio = ratio
92
+ self.stride = ratio
93
+
94
+ if window_type == "hann":
95
+ # Hann-windowed sinc filter equivalent to torchaudio.functional.resample
96
+ rolloff = 0.99
97
+ lowpass_filter_width = 6
98
+ width = math.ceil(lowpass_filter_width / rolloff)
99
+ self.kernel_size = 2 * width * ratio + 1
100
+ self.pad = width
101
+ self.pad_left = 2 * width * ratio
102
+ self.pad_right = self.kernel_size - ratio
103
+ time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
104
+ time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
105
+ window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
106
+ sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
107
+ else:
108
+ # Kaiser-windowed sinc filter (BigVGAN default).
109
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
110
+ self.pad = self.kernel_size // ratio - 1
111
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
112
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
113
+ sinc_filter = kaiser_sinc_filter1d(
114
+ cutoff=0.5 / ratio,
115
+ half_width=0.6 / ratio,
116
+ kernel_size=self.kernel_size,
117
+ )
118
+
119
+ self.register_buffer("filter", sinc_filter, persistent=persistent)
120
+
121
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
122
+ _, n_channels, _ = x.shape
123
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
124
+ filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
125
+ x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
126
+ return x[..., self.pad_left : -self.pad_right]
127
+
128
+
129
+ class DownSample1d(nn.Module):
130
+ def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
131
+ super().__init__()
132
+ self.ratio = ratio
133
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
134
+ self.lowpass = LowPassFilter1d(
135
+ cutoff=0.5 / ratio,
136
+ half_width=0.6 / ratio,
137
+ stride=ratio,
138
+ kernel_size=self.kernel_size,
139
+ )
140
+
141
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
142
+ return self.lowpass(x)
143
+
144
+
145
+ class Activation1d(nn.Module):
146
+ def __init__(
147
+ self,
148
+ activation: nn.Module,
149
+ up_ratio: int = 2,
150
+ down_ratio: int = 2,
151
+ up_kernel_size: int = 12,
152
+ down_kernel_size: int = 12,
153
+ ) -> None:
154
+ super().__init__()
155
+ self.act = activation
156
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
157
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
158
+
159
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
160
+ x = self.upsample(x)
161
+ x = self.act(x)
162
+ return self.downsample(x)
163
+
164
+
165
+ class Snake(nn.Module):
166
+ def __init__(
167
+ self,
168
+ in_features: int,
169
+ alpha: float = 1.0,
170
+ alpha_trainable: bool = True,
171
+ alpha_logscale: bool = True,
172
+ ) -> None:
173
+ super().__init__()
174
+ self.alpha_logscale = alpha_logscale
175
+ self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
176
+ self.alpha.requires_grad = alpha_trainable
177
+ self.eps = 1e-9
178
+
179
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
180
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
181
+ if self.alpha_logscale:
182
+ alpha = torch.exp(alpha)
183
+ return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2)
184
+
185
+
186
+ class SnakeBeta(nn.Module):
187
+ def __init__(
188
+ self,
189
+ in_features: int,
190
+ alpha: float = 1.0,
191
+ alpha_trainable: bool = True,
192
+ alpha_logscale: bool = True,
193
+ ) -> None:
194
+ super().__init__()
195
+ self.alpha_logscale = alpha_logscale
196
+ self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
197
+ self.alpha.requires_grad = alpha_trainable
198
+ self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha)
199
+ self.beta.requires_grad = alpha_trainable
200
+ self.eps = 1e-9
201
+
202
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
203
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
204
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
205
+ if self.alpha_logscale:
206
+ alpha = torch.exp(alpha)
207
+ beta = torch.exp(beta)
208
+ return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2)
209
+
210
+
211
+ class AMPBlock1(nn.Module):
212
+ def __init__(
213
+ self,
214
+ channels: int,
215
+ kernel_size: int = 3,
216
+ dilation: tuple[int, int, int] = (1, 3, 5),
217
+ activation: str = "snake",
218
+ ) -> None:
219
+ super().__init__()
220
+ act_cls = SnakeBeta if activation == "snakebeta" else Snake
221
+ self.convs1 = nn.ModuleList(
222
+ [
223
+ nn.Conv1d(
224
+ channels,
225
+ channels,
226
+ kernel_size,
227
+ 1,
228
+ dilation=dilation[0],
229
+ padding=get_padding(kernel_size, dilation[0]),
230
+ ),
231
+ nn.Conv1d(
232
+ channels,
233
+ channels,
234
+ kernel_size,
235
+ 1,
236
+ dilation=dilation[1],
237
+ padding=get_padding(kernel_size, dilation[1]),
238
+ ),
239
+ nn.Conv1d(
240
+ channels,
241
+ channels,
242
+ kernel_size,
243
+ 1,
244
+ dilation=dilation[2],
245
+ padding=get_padding(kernel_size, dilation[2]),
246
+ ),
247
+ ]
248
+ )
249
+
250
+ self.convs2 = nn.ModuleList(
251
+ [
252
+ nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
253
+ nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
254
+ nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)),
255
+ ]
256
+ )
257
+
258
+ self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))])
259
+ self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))])
260
+
261
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
262
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
263
+ xt = a1(x)
264
+ xt = c1(xt)
265
+ xt = a2(xt)
266
+ xt = c2(xt)
267
+ x = x + xt
268
+ return x
269
+
270
+
271
+ class Vocoder(torch.nn.Module):
272
+ """
273
+ Vocoder model for synthesizing audio from Mel spectrograms.
274
+ Args:
275
+ resblock_kernel_sizes: List of kernel sizes for the residual blocks.
276
+ This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
277
+ upsample_rates: List of upsampling rates.
278
+ This value is read from the checkpoint at `config.vocoder.upsample_rates`.
279
+ upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
280
+ This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
281
+ resblock_dilation_sizes: List of dilation sizes for the residual blocks.
282
+ This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
283
+ upsample_initial_channel: Initial number of channels for the upsampling layers.
284
+ This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
285
+ resblock: Type of residual block to use ("1", "2", or "AMP1").
286
+ This value is read from the checkpoint at `config.vocoder.resblock`.
287
+ output_sampling_rate: Waveform sample rate.
288
+ This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
289
+ activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
290
+ use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
291
+ apply_final_activation: Whether to apply the final tanh/clamp activation.
292
+ use_bias_at_final: Whether to use bias in the final conv layer.
293
+ """
294
+
295
+ def __init__( # noqa: PLR0913
296
+ self,
297
+ resblock_kernel_sizes: List[int] | None = None,
298
+ upsample_rates: List[int] | None = None,
299
+ upsample_kernel_sizes: List[int] | None = None,
300
+ resblock_dilation_sizes: List[List[int]] | None = None,
301
+ upsample_initial_channel: int = 1024,
302
+ resblock: str = "1",
303
+ output_sampling_rate: int = 24000,
304
+ activation: str = "snake",
305
+ use_tanh_at_final: bool = True,
306
+ apply_final_activation: bool = True,
307
+ use_bias_at_final: bool = True,
308
+ ) -> None:
309
+ super().__init__()
310
+
311
+ # Mutable default values are not supported as default arguments.
312
+ if resblock_kernel_sizes is None:
313
+ resblock_kernel_sizes = [3, 7, 11]
314
+ if upsample_rates is None:
315
+ upsample_rates = [6, 5, 2, 2, 2]
316
+ if upsample_kernel_sizes is None:
317
+ upsample_kernel_sizes = [16, 15, 8, 4, 4]
318
+ if resblock_dilation_sizes is None:
319
+ resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
320
+
321
+ self.output_sampling_rate = output_sampling_rate
322
+ self.num_kernels = len(resblock_kernel_sizes)
323
+ self.num_upsamples = len(upsample_rates)
324
+ self.use_tanh_at_final = use_tanh_at_final
325
+ self.apply_final_activation = apply_final_activation
326
+ self.is_amp = resblock == "AMP1"
327
+
328
+ # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
329
+ # bins each), 2 output channels.
330
+ self.conv_pre = nn.Conv1d(
331
+ in_channels=128,
332
+ out_channels=upsample_initial_channel,
333
+ kernel_size=7,
334
+ stride=1,
335
+ padding=3,
336
+ )
337
+ resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1
338
+
339
+ self.ups = nn.ModuleList(
340
+ nn.ConvTranspose1d(
341
+ upsample_initial_channel // (2**i),
342
+ upsample_initial_channel // (2 ** (i + 1)),
343
+ kernel_size,
344
+ stride,
345
+ padding=(kernel_size - stride) // 2,
346
+ )
347
+ for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
348
+ )
349
+
350
+ final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
351
+ self.resblocks = nn.ModuleList()
352
+
353
+ for i in range(len(upsample_rates)):
354
+ ch = upsample_initial_channel // (2 ** (i + 1))
355
+ for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
356
+ if self.is_amp:
357
+ self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
358
+ else:
359
+ self.resblocks.append(resblock_cls(ch, kernel_size, dilations))
360
+
361
+ if self.is_amp:
362
+ self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
363
+ else:
364
+ self.act_post = nn.LeakyReLU()
365
+
366
+ # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
367
+ self.conv_post = nn.Conv1d(
368
+ in_channels=final_channels,
369
+ out_channels=2,
370
+ kernel_size=7,
371
+ stride=1,
372
+ padding=3,
373
+ bias=use_bias_at_final,
374
+ )
375
+
376
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
377
+ """
378
+ Forward pass of the vocoder.
379
+ Args:
380
+ x: Input Mel spectrogram tensor. Can be either:
381
+ - 3D: (batch_size, time, mel_bins) for mono
382
+ - 4D: (batch_size, 2, time, mel_bins) for stereo
383
+ Returns:
384
+ Audio waveform tensor of shape (batch_size, out_channels, audio_length)
385
+ """
386
+ x = x.transpose(2, 3) # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time)
387
+
388
+ if x.dim() == 4: # stereo
389
+ assert x.shape[1] == 2, "Input must have 2 channels for stereo"
390
+ x = einops.rearrange(x, "b s c t -> b (s c) t")
391
+
392
+ x = self.conv_pre(x)
393
+
394
+ for i in range(self.num_upsamples):
395
+ if not self.is_amp:
396
+ x = F.leaky_relu(x, LRELU_SLOPE)
397
+ x = self.ups[i](x)
398
+ start = i * self.num_kernels
399
+ end = start + self.num_kernels
400
+
401
+ # Evaluate all resblocks with the same input tensor so they can run
402
+ # independently (and thus in parallel on accelerator hardware) before
403
+ # aggregating their outputs via mean.
404
+ block_outputs = torch.stack(
405
+ [self.resblocks[idx](x) for idx in range(start, end)],
406
+ dim=0,
407
+ )
408
+ x = block_outputs.mean(dim=0)
409
+
410
+ x = self.act_post(x)
411
+ x = self.conv_post(x)
412
+
413
+ if self.apply_final_activation:
414
+ x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)
415
+
416
+ return x
417
+
418
+
419
+ class _STFTFn(nn.Module):
420
+ """Implements STFT as a convolution with precomputed DFT x Hann-window bases.
421
+ The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
422
+ Hann window are stored as buffers and loaded from the checkpoint. Using the exact
423
+ bfloat16 bases from training ensures the mel values fed to the BWE generator are
424
+ bit-identical to what it was trained on.
425
+ """
426
+
427
+ def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
428
+ super().__init__()
429
+ self.hop_length = hop_length
430
+ self.win_length = win_length
431
+ n_freqs = filter_length // 2 + 1
432
+ self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
433
+ self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
434
+
435
+ def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
436
+ """Compute magnitude and phase spectrogram from a batch of waveforms.
437
+ Applies causal (left-only) padding of win_length - hop_length samples so that
438
+ each output frame depends only on past and present input — no lookahead.
439
+ Args:
440
+ y: Waveform tensor of shape (B, T).
441
+ Returns:
442
+ magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
443
+ phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
444
+ """
445
+ if y.dim() == 2:
446
+ y = y.unsqueeze(1) # (B, 1, T)
447
+ left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
448
+ y = F.pad(y, (left_pad, 0))
449
+ spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
450
+ n_freqs = spec.shape[1] // 2
451
+ real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
452
+ magnitude = torch.sqrt(real**2 + imag**2)
453
+ phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
454
+ return magnitude, phase
455
+
456
+
457
+ class MelSTFT(nn.Module):
458
+ """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
459
+ Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
460
+ waveform and projecting the linear magnitude spectrum onto the mel filterbank.
461
+ The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
462
+ (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
463
+ """
464
+
465
+ def __init__(
466
+ self,
467
+ filter_length: int,
468
+ hop_length: int,
469
+ win_length: int,
470
+ n_mel_channels: int,
471
+ ) -> None:
472
+ super().__init__()
473
+ self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
474
+
475
+ # Initialized to zeros; load_state_dict overwrites with the checkpoint's
476
+ # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
477
+ n_freqs = filter_length // 2 + 1
478
+ self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
479
+
480
+ def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
481
+ """Compute log-mel spectrogram and auxiliary spectral quantities.
482
+ Args:
483
+ y: Waveform tensor of shape (B, T).
484
+ Returns:
485
+ log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
486
+ magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
487
+ phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
488
+ energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
489
+ """
490
+ magnitude, phase = self.stft_fn(y)
491
+ energy = torch.norm(magnitude, dim=1)
492
+ mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
493
+ log_mel = torch.log(torch.clamp(mel, min=1e-5))
494
+ return log_mel, magnitude, phase, energy
495
+
496
+
497
+ class VocoderWithBWE(nn.Module):
498
+ """Vocoder with bandwidth extension (BWE) upsampling.
499
+ Chains a mel-to-wav vocoder with a BWE module that upsamples the output
500
+ to a higher sample rate. The BWE computes a mel spectrogram from the
501
+ vocoder output, runs it through a second generator to predict a residual,
502
+ and adds it to a sinc-resampled skip connection.
503
+ """
504
+
505
+ def __init__(
506
+ self,
507
+ vocoder: Vocoder,
508
+ bwe_generator: Vocoder,
509
+ mel_stft: MelSTFT,
510
+ input_sampling_rate: int,
511
+ output_sampling_rate: int,
512
+ hop_length: int,
513
+ ) -> None:
514
+ super().__init__()
515
+ self.vocoder = vocoder
516
+ self.bwe_generator = bwe_generator
517
+ self.mel_stft = mel_stft
518
+ self.input_sampling_rate = input_sampling_rate
519
+ self.output_sampling_rate = output_sampling_rate
520
+ self.hop_length = hop_length
521
+ # Compute the resampler on CPU so the sinc filter is materialized even when
522
+ # the model is constructed on meta device (SingleGPUModelBuilder pattern).
523
+ # The filter is not stored in the checkpoint (persistent=False).
524
+ with torch.device("cpu"):
525
+ self.resampler = UpSample1d(
526
+ ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
527
+ )
528
+
529
+ @property
530
+ def conv_pre(self) -> nn.Conv1d:
531
+ return self.vocoder.conv_pre
532
+
533
+ @property
534
+ def conv_post(self) -> nn.Conv1d:
535
+ return self.vocoder.conv_post
536
+
537
+ def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
538
+ """Compute log-mel spectrogram from waveform using causal STFT bases.
539
+ Args:
540
+ audio: Waveform tensor of shape (B, C, T).
541
+ Returns:
542
+ mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
543
+ """
544
+ batch, n_channels, _ = audio.shape
545
+ flat = audio.reshape(batch * n_channels, -1) # (B*C, T)
546
+ mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
547
+ return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
548
+
549
+ def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
550
+ """Run the full vocoder + BWE forward pass.
551
+ Args:
552
+ mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
553
+ or (B, T, mel_bins) for mono. Same format as Vocoder.forward.
554
+ Returns:
555
+ Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
556
+ """
557
+ x = self.vocoder(mel_spec)
558
+ _, _, length_low_rate = x.shape
559
+ output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
560
+
561
+ # Pad to multiple of hop_length for exact mel frame count
562
+ remainder = length_low_rate % self.hop_length
563
+ if remainder != 0:
564
+ x = F.pad(x, (0, self.hop_length - remainder))
565
+
566
+ # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)
567
+ mel = self._compute_mel(x)
568
+
569
+ # Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator
570
+ mel_for_bwe = mel.transpose(2, 3) # (B, C, T_frames, mel_bins)
571
+ residual = self.bwe_generator(mel_for_bwe)
572
+ skip = self.resampler(x)
573
+ assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
574
+
575
+ return torch.clamp(residual + skip, -1, 1)[..., :output_length]
packages/ltx-core/src/ltx_core/model/model_protocol.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Protocol, TypeVar
2
+
3
+ ModelType = TypeVar("ModelType")
4
+
5
+
6
+ class ModelConfigurator(Protocol[ModelType]):
7
+ """Protocol for model loader classes that instantiates models from a configuration dictionary."""
8
+
9
+ @classmethod
10
+ def from_config(cls, config: dict) -> ModelType: ...
packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.model.transformer.gelu_approx import GELUApprox
4
+
5
+
6
+ class FeedForward(torch.nn.Module):
7
+ def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None:
8
+ super().__init__()
9
+ inner_dim = int(dim * mult)
10
+ project_in = GELUApprox(dim, inner_dim)
11
+
12
+ self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out))
13
+
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ return self.net(x)
packages/ltx-core/src/ltx_core/model/upsampler/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """Latent upsampler model components."""
2
+
3
+ from ltx_core.model.upsampler.model import LatentUpsampler, upsample_video
4
+ from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator
5
+
6
+ __all__ = [
7
+ "LatentUpsampler",
8
+ "LatentUpsamplerConfigurator",
9
+ "upsample_video",
10
+ ]