Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- packages/ltx-core/src/ltx_core/__init__.py +0 -0
- packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/__pycache__/types.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/components/__init__.py +10 -0
- packages/ltx-core/src/ltx_core/components/diffusion_steps.py +95 -0
- packages/ltx-core/src/ltx_core/components/guiders.py +364 -0
- packages/ltx-core/src/ltx_core/components/noisers.py +35 -0
- packages/ltx-core/src/ltx_core/components/patchifiers.py +348 -0
- packages/ltx-core/src/ltx_core/components/protocols.py +101 -0
- packages/ltx-core/src/ltx_core/components/schedulers.py +130 -0
- packages/ltx-core/src/ltx_core/conditioning/__init__.py +19 -0
- packages/ltx-core/src/ltx_core/conditioning/exceptions.py +4 -0
- packages/ltx-core/src/ltx_core/conditioning/item.py +20 -0
- packages/ltx-core/src/ltx_core/conditioning/mask_utils.py +210 -0
- packages/ltx-core/src/ltx_core/guidance/__init__.py +15 -0
- packages/ltx-core/src/ltx_core/guidance/perturbations.py +79 -0
- packages/ltx-core/src/ltx_core/loader/__init__.py +48 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/loader/fuse_loras.py +153 -0
- packages/ltx-core/src/ltx_core/loader/kernels.py +72 -0
- packages/ltx-core/src/ltx_core/loader/module_ops.py +14 -0
- packages/ltx-core/src/ltx_core/loader/primitives.py +109 -0
- packages/ltx-core/src/ltx_core/loader/registry.py +84 -0
- packages/ltx-core/src/ltx_core/loader/sd_ops.py +127 -0
- packages/ltx-core/src/ltx_core/loader/sft_loader.py +66 -0
- packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +116 -0
- packages/ltx-core/src/ltx_core/model/__init__.py +8 -0
- packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py +29 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/attention.py +71 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py +508 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py +110 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py +110 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py +200 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/ops.py +73 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py +176 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py +106 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py +575 -0
- packages/ltx-core/src/ltx_core/model/model_protocol.py +10 -0
- packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py +15 -0
- packages/ltx-core/src/ltx_core/model/upsampler/__init__.py +10 -0
packages/ltx-core/src/ltx_core/__init__.py
ADDED
|
File without changes
|
packages/ltx-core/src/ltx_core/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (170 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/__pycache__/types.cpython-312.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (3.68 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/components/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Diffusion pipeline components.
|
| 3 |
+
Submodules:
|
| 4 |
+
diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep)
|
| 5 |
+
guiders - Guidance strategies (CFGGuider, STGGuider, APG variants)
|
| 6 |
+
noisers - Noise samplers (GaussianNoiser)
|
| 7 |
+
patchifiers - Latent patchification (VideoLatentPatchifier, AudioPatchifier)
|
| 8 |
+
protocols - Protocol definitions (Patchifier, etc.)
|
| 9 |
+
schedulers - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler)
|
| 10 |
+
"""
|
packages/ltx-core/src/ltx_core/components/diffusion_steps.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.components.protocols import DiffusionStepProtocol
|
| 4 |
+
from ltx_core.utils import to_velocity
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EulerDiffusionStep(DiffusionStepProtocol):
    """
    First-order Euler sampler step.

    Moves the sample from sigmas[step_index] to sigmas[step_index + 1] by
    converting the denoised prediction into a velocity and integrating it
    over dt = sigma_next - sigma.
    """

    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **_kwargs
    ) -> torch.Tensor:
        sigma_current = sigmas[step_index]
        sigma_following = sigmas[step_index + 1]
        delta_t = sigma_following - sigma_current
        velocity = to_velocity(sample, sigma_current, denoised_sample)

        # Integrate in float32 for numerical stability, then cast back to the
        # sample's original dtype.
        updated = sample.to(torch.float32) + velocity.to(torch.float32) * delta_t
        return updated.to(sample.dtype)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Res2sDiffusionStep(DiffusionStepProtocol):
    """
    Second-order diffusion step for res_2s sampling with SDE noise injection.

    Used by the res_2s denoising loop. Advances the sample from the current
    sigma to the next by mixing a deterministic update (from the denoised
    prediction) with injected noise via ``get_sde_coeff``, producing
    variance-preserving transitions.
    """

    @staticmethod
    def get_sde_coeff(
        sigma_next: torch.Tensor,
        sigma_up: torch.Tensor | None = None,
        sigma_down: torch.Tensor | None = None,
        sigma_max: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute SDE coefficients (alpha_ratio, sigma_down, sigma_up) for the step.

        Given either ``sigma_down`` or ``sigma_up``, returns the mixing
        coefficients used for variance-preserving noise injection. If
        ``sigma_up`` is provided, ``sigma_down`` and ``alpha_ratio`` are
        derived; if ``sigma_down`` is provided, ``sigma_up`` and
        ``alpha_ratio`` are derived. Input tensors are never modified
        (previously ``sigma_up`` was clamped in place and ``sigma_down``
        written through indexing, mutating tensors the caller may still own).
        """
        if sigma_down is not None:
            alpha_ratio = (1 - sigma_next) / (1 - sigma_down)
            sigma_up = (sigma_next**2 - sigma_down**2 * alpha_ratio**2).clamp(min=0) ** 0.5
        elif sigma_up is not None:
            # Cap sigma_up just below sigma_next so the sqrt below never sees a
            # negative operand. Out-of-place clamp: do not mutate the caller's
            # tensor.
            sigma_up = sigma_up.clamp(max=sigma_next * 0.9999)
            sigmax = sigma_max if sigma_max is not None else torch.ones_like(sigma_next)
            sigma_signal = sigmax - sigma_next
            sigma_residual = (sigma_next**2 - sigma_up**2).clamp(min=0) ** 0.5
            alpha_ratio = sigma_signal + sigma_residual
            sigma_down = sigma_residual / alpha_ratio
        else:
            alpha_ratio = torch.ones_like(sigma_next)
            sigma_down = sigma_next
            sigma_up = torch.zeros_like(sigma_next)

        # All three are guaranteed non-None here; sanitize NaNs without
        # writing into possibly shared tensors.
        sigma_up = torch.nan_to_num(sigma_up, 0.0)
        sigma_down = torch.where(torch.isnan(sigma_down), sigma_next.to(sigma_down.dtype), sigma_down)
        alpha_ratio = torch.nan_to_num(alpha_ratio, 1.0)

        return alpha_ratio, sigma_down, sigma_up

    def step(
        self,
        sample: torch.Tensor,
        denoised_sample: torch.Tensor,
        sigmas: torch.Tensor,
        step_index: int,
        noise: torch.Tensor,
    ) -> torch.Tensor:
        """Advance one step with SDE noise injection via get_sde_coeff."""
        sigma = sigmas[step_index]
        sigma_next = sigmas[step_index + 1]
        alpha_ratio, sigma_down, sigma_up = self.get_sde_coeff(sigma_next, sigma_up=sigma_next * 0.5)
        output_dtype = denoised_sample.dtype
        # No noise to inject (or final sigma reached): return the denoised
        # prediction unchanged.
        if torch.any(sigma_up == 0) or torch.any(sigma_next == 0):
            return denoised_sample

        # Extract the epsilon prediction implied by the current/next sigmas.
        # NOTE(review): divides by (sigma - sigma_next) rather than sigma;
        # presumed intentional for the res_2s formulation -- confirm against
        # the sampler that calls this.
        eps_next = (sample - denoised_sample) / (sigma - sigma_next)
        denoised_next = sample - sigma * eps_next

        # Mix deterministic and stochastic components.
        x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise
        return x_noised.to(output_dtype)
|
packages/ltx-core/src/ltx_core/components/guiders.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from collections.abc import Mapping, Sequence
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ltx_core.components.protocols import GuiderProtocol
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass(frozen=True)
class CFGGuider(GuiderProtocol):
    """
    Classifier-free guidance (CFG) guider.

    The guidance delta is (scale - 1) * (cond - uncond); adding it to the
    conditioned prediction steers denoising toward the conditioning.

    Attributes:
        scale: Guidance strength. 1.0 means no guidance; larger values
            strengthen adherence to the conditioning.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        difference = cond - uncond
        return difference * (self.scale - 1)

    def enabled(self) -> bool:
        # Guidance is a no-op exactly at scale == 1.0.
        return self.scale != 1.0
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass(frozen=True)
class CFGStarRescalingGuider(GuiderProtocol):
    """
    CFG guider with a norm-rescaled negative branch.

    Before taking the CFG difference, the unconditioned sample is rescaled by
    its projection coefficient onto the conditioned sample, so the update
    moves mostly along the conditioning axis within the distribution rather
    than offsetting the denoising direction.

    Attributes:
        scale (float):
            Global guidance strength. 1.0 corresponds to no extra guidance
            beyond the base model prediction; values > 1.0 increase the
            influence of the conditioned sample relative to the
            unconditioned one.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        # Rescale the negative branch to match the conditioned sample's norm
        # along its direction.
        coefficient = projection_coef(cond, uncond)
        rescaled_uncond = coefficient * uncond
        return (cond - rescaled_uncond) * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass(frozen=True)
class STGGuider(GuiderProtocol):
    """
    Spatio-temporal guidance (STG) guider.

    Computes the delta between the conditioned denoised sample and a
    perturbed denoised sample (e.g. produced with selected attention layers
    acting as passthrough for certain layers and modalities).

    Attributes:
        scale (float):
            Global strength of the STG guidance. 0.0 disables it; larger
            values increase the correction applied in the direction of
            (pos_denoised - perturbed_denoised).
    """

    scale: float

    def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor:
        correction = pos_denoised - perturbed_denoised
        return correction * self.scale

    def enabled(self) -> bool:
        # STG is additive, so 0.0 (not 1.0) is the neutral scale.
        return self.scale != 0.0
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass(frozen=True)
class LtxAPGGuider(GuiderProtocol):
    """
    Adaptive projected guidance (APG) guider.

    The (cond - uncond) delta is decomposed into components parallel and
    orthogonal to the conditioned sample, so the update moves mostly along
    the conditioning axis within the distribution. ``eta`` weights the
    parallel component; ``scale`` applies to the combined result. An optional
    norm threshold caps the guidance magnitude before projection.

    Attributes:
        scale (float):
            Strength applied to the guidance correction; 1.0 disables it.
        eta (float):
            Weight of the component parallel to the conditioned sample.
            1.0 keeps the full parallel component; values in [0, 1]
            attenuate it, values > 1.0 amplify it.
        norm_threshold (float):
            Per-sample L2 cap applied to the guidance delta before
            projection; <= 0 disables the cap.
    """

    scale: float
    eta: float = 1.0
    norm_threshold: float = 0.0

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond

        if self.norm_threshold > 0:
            # Cap the per-sample guidance magnitude at norm_threshold.
            guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            limiter = torch.minimum(torch.ones_like(guidance), self.norm_threshold / guidance_norm)
            guidance = guidance * limiter

        # Decompose relative to the conditioned sample.
        parallel = projection_coef(guidance, cond) * cond
        orthogonal = guidance - parallel
        combined = parallel * self.eta + orthogonal

        return combined * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@dataclass(frozen=False)
class LegacyStatefulAPGGuider(GuiderProtocol):
    """
    Stateful APG (adaptive projected guidance) guider with momentum.

    The (cond - uncond) delta is decomposed into components parallel and
    orthogonal to the conditioned sample; ``eta`` weights the parallel part
    and ``scale`` the combined correction. A norm threshold caps the guidance
    magnitude, and ``momentum`` accumulates guidance across calls in
    ``running_avg``.

    Attributes:
        scale (float):
            Strength applied to the APG correction; 0.0 disables it.
        eta (float):
            Weight of the component parallel to the conditioned sample.
            1.0 keeps the full parallel component; values in [0, 1]
            attenuate it, values > 1.0 amplify it.
        norm_threshold (float):
            Per-sample L2 cap applied to the guidance before projection;
            <= 0 disables the cap.
        momentum (float):
            Exponential moving-average coefficient for accumulating guidance
            over time: running_avg = momentum * running_avg + guidance.
    """

    scale: float
    eta: float
    norm_threshold: float = 5.0
    momentum: float = 0.0
    # It is the user's responsibility not to reuse the same APGGuider across
    # several denoisings or different modalities, so the accumulated average
    # is not shared between them.
    running_avg: torch.Tensor | None = None

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        guidance = cond - uncond

        if self.momentum != 0:
            # Accumulate the EMA state on the instance.
            if self.running_avg is None:
                self.running_avg = guidance.clone()
            else:
                self.running_avg = guidance + self.momentum * self.running_avg
            guidance = self.running_avg

        if self.norm_threshold > 0:
            guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            limiter = torch.minimum(torch.ones_like(guidance), self.norm_threshold / guidance_norm)
            guidance = guidance * limiter

        parallel = projection_coef(guidance, cond) * cond
        orthogonal = guidance - parallel
        # Unlike LtxAPGGuider, the legacy form multiplies by scale directly.
        return (parallel * self.eta + orthogonal) * self.scale

    def enabled(self) -> bool:
        return self.scale != 0.0
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@dataclass(frozen=True)
class MultiModalGuiderParams:
    """
    Parameters for the multi-modal guider.

    Immutable bundle of guidance scales consumed by MultiModalGuider; one
    instance applies either globally or to one sigma bin (see
    MultiModalGuiderFactory).
    """

    cfg_scale: float = 1.0
    "CFG (Classifier-free guidance) scale controlling how strongly the model adheres to the prompt; 1.0 disables CFG."
    stg_scale: float = 0.0
    "STG (Spatio-Temporal Guidance) scale controls how strongly the model reacts to the perturbation of the modality; 0.0 disables STG."
    # NOTE(review): annotation allows None but default_factory=list makes the
    # default an empty list, never None -- confirm whether None is actually
    # passed anywhere.
    stg_blocks: list[int] | None = field(default_factory=list)
    "Which transformer blocks to perturb for STG."
    rescale_scale: float = 0.0
    "Rescale scale controlling how strongly the model rescales the modality after applying other guidance; 0.0 disables rescaling."
    modality_scale: float = 1.0
    "Modality scale controlling how strongly the model reacts to the perturbation of the modality; 1.0 disables modality guidance."
    skip_step: int = 0
    "Skip step controlling how often the model skips the step; 0 means no steps are skipped."
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _params_for_sigma_from_sorted_dict(
    sigma: float, params_by_sigma: Sequence[tuple[float, MultiModalGuiderParams]]
) -> MultiModalGuiderParams:
    """
    Return params for the given sigma from a sorted (sigma_upper_bound -> params) structure.

    Keys are sorted descending (bin upper bounds). Bin i is (key_{i+1}, key_i].
    Get all keys >= sigma; use last in list (smallest such key = upper bound of
    bin containing sigma), or the FIRST entry in the sequence (largest key) if
    that list is empty, i.e. sigma lies above the maximum key.

    Raises:
        ValueError: If ``params_by_sigma`` is empty.
    """
    if not params_by_sigma:
        raise ValueError("params_by_sigma must be non-empty")
    sigma = float(sigma)
    keys_desc = [k for k, _ in params_by_sigma]
    keys_ge_sigma = [k for k in keys_desc if k >= sigma]
    # sigma above all keys: use first bin (max key)
    key = keys_ge_sigma[-1] if keys_ge_sigma else keys_desc[0]
    return next(p for k, p in params_by_sigma if k == key)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@dataclass(frozen=True)
class MultiModalGuider:
    """
    Multi-modal guider with constant params per instance.

    For sigma-dependent params, use MultiModalGuiderFactory.build_from_sigma(sigma)
    to obtain a guider for each step.
    """

    params: MultiModalGuiderParams
    negative_context: torch.Tensor | None = None

    def calculate(
        self,
        cond: torch.Tensor,
        uncond_text: torch.Tensor | float,
        uncond_perturbed: torch.Tensor | float,
        uncond_modality: torch.Tensor | float,
    ) -> torch.Tensor:
        """
        Combine all guidance terms into a single prediction.

        CFG and modality guidance each contribute (scale - 1) * (cond - uncond);
        STG contributes scale * (cond - uncond_perturbed). With a non-zero
        rescale_scale the result is blended toward the conditioned sample's
        standard deviation.
        """
        p = self.params
        cfg_term = (p.cfg_scale - 1) * (cond - uncond_text)
        stg_term = p.stg_scale * (cond - uncond_perturbed)
        modality_term = (p.modality_scale - 1) * (cond - uncond_modality)
        pred = cond + cfg_term + stg_term + modality_term

        if p.rescale_scale != 0:
            # Blend between the raw prediction (factor 1) and one rescaled to
            # match cond's standard deviation.
            std_ratio = cond.std() / pred.std()
            blend = p.rescale_scale * std_ratio + (1 - p.rescale_scale)
            pred = pred * blend

        return pred

    def do_unconditional_generation(self) -> bool:
        """Returns True if the guider is doing unconditional generation."""
        return not math.isclose(self.params.cfg_scale, 1.0)

    def do_perturbed_generation(self) -> bool:
        """Returns True if the guider is doing perturbed generation."""
        return not math.isclose(self.params.stg_scale, 0.0)

    def do_isolated_modality_generation(self) -> bool:
        """Returns True if the guider is doing isolated modality generation."""
        return not math.isclose(self.params.modality_scale, 1.0)

    def should_skip_step(self, step: int) -> bool:
        """Returns True if the guider should skip the step."""
        if self.params.skip_step == 0:
            return False
        # Keep every (skip_step + 1)-th step; skip the rest.
        return step % (self.params.skip_step + 1) != 0
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@dataclass(frozen=True)
class MultiModalGuiderFactory:
    """
    Factory that creates a MultiModalGuider for a given sigma.

    Single source of truth: _params_by_sigma (schedule). Use constant() for
    one params for all sigma, from_dict() for sigma-binned params.
    """

    negative_context: torch.Tensor | None = None
    _params_by_sigma: tuple[tuple[float, MultiModalGuiderParams], ...] = ()

    @classmethod
    def constant(
        cls,
        params: MultiModalGuiderParams,
        negative_context: torch.Tensor | None = None,
    ) -> "MultiModalGuiderFactory":
        """Build a factory with constant params (same guider for all sigma)."""
        # A single bin with an infinite upper bound covers every sigma.
        schedule = ((float("inf"), params),)
        return cls(negative_context=negative_context, _params_by_sigma=schedule)

    @classmethod
    def from_dict(
        cls,
        sigma_to_params: Mapping[float, MultiModalGuiderParams],
        negative_context: torch.Tensor | None = None,
    ) -> "MultiModalGuiderFactory":
        """
        Build a factory from a dict of sigma_value -> MultiModalGuiderParams.

        Keys are sorted descending and used for bin lookup in params(sigma).
        Raises ValueError when the mapping is empty.
        """
        if not sigma_to_params:
            raise ValueError("sigma_to_params must be non-empty")
        schedule = tuple(sorted(sigma_to_params.items(), key=lambda item: item[0], reverse=True))
        return cls(negative_context=negative_context, _params_by_sigma=schedule)

    def params(self, sigma: float | torch.Tensor) -> MultiModalGuiderParams:
        """Return params effective for the given sigma (getter; single source of truth)."""
        if isinstance(sigma, torch.Tensor):
            sigma = sigma.item()
        return _params_for_sigma_from_sorted_dict(float(sigma), self._params_by_sigma)

    def build_from_sigma(self, sigma: float | torch.Tensor) -> MultiModalGuider:
        """Return a MultiModalGuider with params effective for the given sigma."""
        return MultiModalGuider(params=self.params(sigma), negative_context=self.negative_context)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def create_multimodal_guider_factory(
    params: MultiModalGuiderParams | MultiModalGuiderFactory,
    negative_context: torch.Tensor | None = None,
) -> MultiModalGuiderFactory:
    """
    Create or return a MultiModalGuiderFactory.

    Constant params produce a single-params factory (via
    MultiModalGuiderFactory.constant). An existing factory is returned as-is
    unless a different negative_context is provided, in which case its
    schedule is rebuilt with the new context. For sigma-dependent params use
    MultiModalGuiderFactory.from_dict(...) and pass that as params.
    """
    if not isinstance(params, MultiModalGuiderFactory):
        return MultiModalGuiderFactory.constant(params, negative_context=negative_context)

    keeps_context = negative_context is None or params.negative_context is negative_context
    if keeps_context:
        return params
    # Rebuild with the same schedule but the newly supplied context.
    return MultiModalGuiderFactory.from_dict(
        dict(params._params_by_sigma), negative_context=negative_context
    )
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor:
    """
    Per-batch-element scalar projection coefficient of one tensor onto another.

    Both tensors are flattened per batch element; the result is
    <to_project, project_onto> / (||project_onto||^2 + 1e-8) with shape
    (batch, 1). The epsilon guards against division by zero.
    """
    batch = to_project.shape[0]
    flat_source = to_project.reshape(batch, -1)
    flat_target = project_onto.reshape(batch, -1)
    numerator = (flat_source * flat_target).sum(dim=1, keepdim=True)
    denominator = flat_target.pow(2).sum(dim=1, keepdim=True) + 1e-8
    return numerator / denominator
|
packages/ltx-core/src/ltx_core/components/noisers.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import replace
|
| 2 |
+
from typing import Protocol
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.types import LatentState
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Noiser(Protocol):
    """Protocol for adding noise to a latent state during diffusion."""

    def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState:
        """Return a LatentState derived from ``latent_state`` with noise applied at ``noise_scale``."""
        ...
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GaussianNoiser(Noiser):
    """Adds Gaussian noise to a latent state, scaled by the denoise mask."""

    def __init__(self, generator: torch.Generator):
        super().__init__()

        # Dedicated generator keeps the sampled noise reproducible.
        self.generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        source = latent_state.latent
        noise = torch.randn(
            *source.shape,
            device=source.device,
            dtype=source.dtype,
            generator=self.generator,
        )
        # Per-element blend: mask 1 -> pure noise, mask 0 -> untouched latent.
        mask = latent_state.denoise_mask * noise_scale
        blended = noise * mask + source * (1 - mask)
        return replace(latent_state, latent=blended.to(source.dtype))
|
packages/ltx-core/src/ltx_core/components/patchifiers.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import einops
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ltx_core.components.protocols import Patchifier
|
| 8 |
+
from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class VideoLatentPatchifier(Patchifier):
    """
    Patchifier for video latents.

    Folds (temporal, height, width) patches into the channel axis and flattens
    the remaining patch grid into a single token sequence.
    """

    def __init__(self, patch_size: int):
        # Patch sizes for video latents. Temporal patch size is fixed to 1;
        # height and width share the same spatial patch size.
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        # (temporal, height, width) patch dimensions.
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        # Token count = spatio-temporal volume divided by the per-patch volume.
        # Assumes the latent torch shape is (batch, channels, frames, height, width),
        # so [2:] selects (frames, height, width).
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert a dense video latent `(b, c, f, h, w)` into flattened patch tokens
        `(b, tokens, features)` by folding each patch into the feature axis.
        """
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )

        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """
        Restore the dense video latent from flattened patch tokens.

        Note: the rearrange pattern below omits the temporal patch axis, which is
        only valid while the temporal patch size is 1 (enforced by the assert).
        """
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        # Size of the patch grid along each axis.
        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]

        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds [inclusive start, exclusive end) for every
        patch produced by `patchify`. The bounds are expressed in the original
        video grid coordinates: frame/time, height, and width.
        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
        - axis 1 (size 3) enumerates (frame/time, height, width) dimensions
        - axis 3 (size 2) stores `[start, end)` indices within each dimension
        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device of the latent tensor.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate inputs to ensure positive dimensions
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Generate grid coordinates for each dimension (frame, height, width)
        # We use torch.arange to create the starting coordinates for each patch.
        # indexing='ij' ensures the dimensions are in the order (frame, height, width).
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )

        # Stack the grid coordinates to create the start coordinates tensor.
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_starts = torch.stack(grid_coords, dim=0)

        # Create a tensor containing the size of a single patch:
        # (frame_patch_size, height_patch_size, width_patch_size).
        # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)

        # Calculate end coordinates: start + patch_size
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_ends = patch_starts + patch_size_delta

        # Stack start and end coordinates together along the last dimension
        # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
        # Final Shape: (batch_size, 3, num_patches, 2)
        latent_coords = einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )

        return latent_coords
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Map latent-space `[start, end)` coordinates to pixel space by scaling each
    axis (frame/time, height, width) with the corresponding VAE downsampling
    factor, optionally compensating for causal encoding of the first frame.

    Args:
        latent_coords: Latent bounds shaped `(batch, 3, num_patches, 2)`.
        scale_factors: `(temporal, height, width)` integer scale factors applied
            per axis.
        causal_fix: When True, rewrites the temporal axis of the first frame so
            causal VAEs that treat frame zero differently still yield
            non-negative timestamps.

    Returns:
        Pixel-space bounds with the same shape as `latent_coords`.
    """
    # Shape the factors as (1, 3, 1, ..., 1) so they broadcast over the
    # (batch, axis, patch, bound) layout of `latent_coords`.
    view_shape = [1] * latent_coords.ndim
    view_shape[1] = -1
    per_axis_scale = torch.tensor(scale_factors, device=latent_coords.device).view(*view_shape)

    # Per-axis scaling converts latent bounds into pixel-space coordinates.
    pixel_coords = latent_coords * per_axis_scale

    if causal_fix:
        # The causal VAE uses a temporal stride of 1 for the very first frame
        # instead of `scale_factors[0]`; shift and clamp so first-frame
        # timestamps stay causal and non-negative.
        pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)

    return pixel_coords
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class AudioPatchifier(Patchifier):
    """
    Patchifier for spectrogram/audio latents.

    Time is the token axis; channels and mel bins are folded into the feature
    axis. Timing helpers convert latent indices back to seconds.
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Patchifier tailored for spectrogram/audio latents.
        Args:
            patch_size: Number of mel bins combined into a single patch. This
                controls the resolution along the frequency axis.
            sample_rate: Original waveform sampling rate. Used to map latent
                indices back to seconds so downstream consumers can align audio
                and video cues.
            hop_length: Window hop length used for the spectrogram. Determines
                how many real-time samples separate two consecutive latent frames.
            audio_latent_downsample_factor: Ratio between spectrogram frames and
                latent frames; compensates for additional downsampling inside the
                VAE encoder.
            is_causal: When True, timing is shifted to account for causal
                receptive fields so timestamps do not peek into the future.
            shift: Integer offset applied to the latent indices. Enables
                constructing overlapping windows from the same latent sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        # NOTE(review): patch_size is duplicated into the last two slots to mirror
        # the video patchifier's (temporal, height, width) layout — confirm intended.
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        # (temporal, height, width)-style patch dimensions, see __init__ note.
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        # One token per latent time frame.
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Converts latent indices into real-time seconds while honoring causal
        offsets and the configured hop length.
        Args:
            start_latent: Inclusive start index inside the latent sequence. This
                sets the first timestamp returned.
            end_latent: Exclusive end index. Determines how many timestamps get
                generated.
            dtype: Floating-point dtype used for the returned tensor, allowing
                callers to control precision.
            device: Target device for the timestamp tensor. When omitted the
                computation occurs on CPU to avoid surprising GPU allocations.
        """
        if device is None:
            device = torch.device("cpu")

        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)

        # Latent frames -> spectrogram (mel) frames.
        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor

        if self.is_causal:
            # Frame offset for causal alignment.
            # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
            causal_offset = 1
            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        # Mel frames -> seconds via the spectrogram hop size.
        return audio_mel_frame * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.
        This helper method underpins `get_patch_grid_bounds` for the audio patchifier.
        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames (time steps) to convert into timestamps.
            device: Device on which the resulting tensor should reside.
        """
        resolved_device = device
        if resolved_device is None:
            resolved_device = torch.device("cpu")

        # Start times for latent indices [shift, shift + num_steps).
        start_timings = self._get_audio_latent_time_in_sec(
            self.shift,
            num_steps + self.shift,
            torch.float32,
            resolved_device,
        )
        start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        # End times are the start times of the next latent index.
        end_timings = self._get_audio_latent_time_in_sec(
            self.shift + 1,
            num_steps + self.shift + 1,
            torch.float32,
            resolved_device,
        )
        end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        # Pair [start, end) along the last axis.
        return torch.stack([start_timings, end_timings], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`
        to derive timestamps for each latent frame based on the configured hop
        length and downsampling.
        Args:
            audio_latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the
            corresponding timing metadata when needed.
        """
        # (batch, channels, time, freq) -> (batch, time, channels * freq)
        audio_latents = einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

        return audio_latents

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
        Use `get_patch_grid_bounds` to recompute the timestamps that describe each
        frame's position in real time.
        Args:
            audio_latents: Latent tensor to unpatchify.
            output_shape: Shape of the unpatched output tensor.
        Returns:
            Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
            metadata associated with the restored latents.
        """
        # audio_latents shape: (batch, time, freq * channels)
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

        return audio_latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal bounds `[inclusive start, exclusive end)` for every
        patch emitted by `patchify`. For audio this corresponds to timestamps in
        seconds aligned with the original spectrogram grid.
        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
        - axis 1 (size 1) represents the temporal dimension
        - axis 3 (size 2) stores the `[start, end)` timestamps per patch
        Args:
            output_shape: Audio grid specification describing the number of time steps.
            device: Target device for the returned tensor.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")

        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
|
packages/ltx-core/src/ltx_core/components/protocols.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Protocol, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.types import AudioLatentShape, VideoLatentShape
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        # Fix: the docstring previously appeared AFTER the `...` body, making it a
        # dead string expression; placed first so it populates `__doc__`.
        """
        Convert latent tensors into flattened patch tokens.
        Args:
            latents: Latent tensor to patchify.
        Returns:
            Flattened patch tokens tensor.
        """
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Converts latent tensors between spatio-temporal formats and flattened sequence representations.
        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
                VideoLatentShape.
        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.
        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.
        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SchedulerProtocol(Protocol):
    """
    Protocol for schedulers that provide a sigmas schedule tensor for a
    given number of steps. The returned tensor lives on the CPU.
    """

    # Returns the sigma schedule for `steps` steps; extra kwargs are
    # implementation-specific tuning parameters.
    def execute(self, steps: int, **kwargs) -> torch.FloatTensor: ...
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class GuiderProtocol(Protocol):
    """
    Protocol for guiders that compute a delta tensor given conditioning inputs.
    The returned delta should be added to the conditional output (cond), enabling
    multiple guiders to be chained together by accumulating their deltas.
    """

    # Guidance strength; exact interpretation is guider-specific.
    scale: float

    # Computes the guidance delta from conditional and unconditional outputs.
    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: ...

    def enabled(self) -> bool:
        """
        Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale
        is 1.0.
        """
        ...
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class DiffusionStepProtocol(Protocol):
    """
    Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor,
    current denoised sample tensor, and sigmas tensor.
    """

    # Advances `sample` one step using the denoised prediction, the sigma
    # schedule, and the current `step_index`; kwargs are sampler-specific.
    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **kwargs
    ) -> torch.Tensor: ...
|
packages/ltx-core/src/ltx_core/components/schedulers.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
|
| 4 |
+
import numpy
|
| 5 |
+
import scipy
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from ltx_core.components.protocols import SchedulerProtocol
|
| 9 |
+
|
| 10 |
+
# Token-count anchors between which the sigma shift is linearly interpolated.
BASE_SHIFT_ANCHOR = 1024
MAX_SHIFT_ANCHOR = 4096


class LTX2Scheduler(SchedulerProtocol):
    """
    Default scheduler for LTX-2 diffusion sampling.

    Produces a sigma schedule whose shift depends on the latent token count,
    optionally stretched so the last non-zero sigma equals ``terminal``.
    """

    def execute(
        self,
        steps: int,
        latent: torch.Tensor | None = None,
        max_shift: float = 2.05,
        base_shift: float = 0.95,
        stretch: bool = True,
        terminal: float = 0.1,
        default_number_of_tokens: int = MAX_SHIFT_ANCHOR,
        **_kwargs,
    ) -> torch.FloatTensor:
        # Token count drives how strongly the schedule is shifted.
        token_count = default_number_of_tokens if latent is None else math.prod(latent.shape[2:])
        schedule = torch.linspace(1.0, 0.0, steps + 1)

        # Linearly interpolate the shift between the two token-count anchors.
        slope = (max_shift - base_shift) / (MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR)
        intercept = base_shift - slope * BASE_SHIFT_ANCHOR
        sigma_shift = token_count * slope + intercept

        power = 1
        exp_shift = math.exp(sigma_shift)
        # Sigmoid-like remap of each sigma; entries equal to 0 keep the value 0
        # (the true branch would divide by zero there, so it is masked out).
        schedule = torch.where(
            schedule != 0,
            exp_shift / (exp_shift + (1 / schedule - 1) ** power),
            0,
        )

        # Stretch the non-zero sigmas so the final non-zero value matches `terminal`.
        if stretch:
            keep = schedule != 0
            one_minus = 1.0 - schedule[keep]
            ratio = one_minus[-1] / (1.0 - terminal)
            schedule[keep] = 1.0 - one_minus / ratio

        return schedule.to(torch.float32)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class LinearQuadraticScheduler(SchedulerProtocol):
    """
    Scheduler with linear steps followed by quadratic steps.

    The noise level first ramps linearly up to ``threshold_noise`` and then
    follows a quadratic curve; the result is inverted into a sigma schedule.
    """

    def execute(
        self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs
    ) -> torch.FloatTensor:
        # Degenerate single-step case: full denoise in one jump.
        if steps == 1:
            return torch.FloatTensor([1.0, 0.0])

        n_linear = steps // 2 if linear_steps is None else linear_steps

        # Linear ramp from 0 up to (but excluding) threshold_noise.
        ramp = [i * threshold_noise / n_linear for i in range(n_linear)]

        gap = n_linear - threshold_noise * steps
        n_quadratic = steps - n_linear
        tail = []
        if n_quadratic > 0:
            # Quadratic coefficients chosen so the curve meets the linear
            # segment at i == n_linear and reaches 1.0 at i == steps.
            quad_coef = gap / (n_linear * n_quadratic**2)
            lin_coef = threshold_noise / n_linear - 2 * gap / (n_quadratic**2)
            offset = quad_coef * (n_linear**2)
            tail = [quad_coef * (i**2) + lin_coef * i + offset for i in range(n_linear, steps)]

        noise_levels = ramp + tail + [1.0]
        # Invert noise levels into sigmas (1 -> fully noisy, 0 -> clean).
        return torch.FloatTensor([1.0 - level for level in noise_levels])
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class BetaScheduler(SchedulerProtocol):
    """
    Scheduler using a beta distribution to sample timesteps.
    Based on: https://arxiv.org/abs/2407.12173
    """

    shift = 2.37
    timesteps_length = 10000

    def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6) -> torch.FloatTensor:
        """
        Execute the beta scheduler.

        Args:
            steps: The number of steps to execute the scheduler for.
            alpha: The alpha parameter for the beta distribution.
            beta: The beta parameter for the beta distribution.

        Warnings:
            The number of steps within `sigmas` theoretically might be less than `steps+1`,
            because of the deduplication of the identical timesteps.

        Returns:
            A tensor of sigmas ending in 0.0.
        """
        sigma_table = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length)
        max_index = len(sigma_table) - 1

        # Evenly spaced quantiles in (0, 1], pushed through the beta inverse CDF
        # and rounded to indices into the precomputed sigma table.
        quantiles = 1 - numpy.linspace(0, 1, steps, endpoint=False)
        indices = numpy.rint(scipy.stats.beta.ppf(quantiles, alpha, beta) * max_index).tolist()
        # Drop duplicate indices while preserving order.
        indices = list(dict.fromkeys(indices))

        return torch.FloatTensor([float(sigma_table[int(i)]) for i in indices] + [0.0])
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@lru_cache(maxsize=5)
|
| 124 |
+
def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor:
|
| 125 |
+
timesteps = torch.arange(1, timesteps_length + 1, 1) / timesteps_length
|
| 126 |
+
return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in timesteps])
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def flux_time_shift(mu: float, sigma: float, t: float) -> float:
|
| 130 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
packages/ltx-core/src/ltx_core/conditioning/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Conditioning utilities: latent state, tools, and conditioning types."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.conditioning.exceptions import ConditioningError
|
| 4 |
+
from ltx_core.conditioning.item import ConditioningItem
|
| 5 |
+
from ltx_core.conditioning.types import (
|
| 6 |
+
ConditioningItemAttentionStrengthWrapper,
|
| 7 |
+
VideoConditionByKeyframeIndex,
|
| 8 |
+
VideoConditionByLatentIndex,
|
| 9 |
+
VideoConditionByReferenceLatent,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"ConditioningError",
|
| 14 |
+
"ConditioningItem",
|
| 15 |
+
"ConditioningItemAttentionStrengthWrapper",
|
| 16 |
+
"VideoConditionByKeyframeIndex",
|
| 17 |
+
"VideoConditionByLatentIndex",
|
| 18 |
+
"VideoConditionByReferenceLatent",
|
| 19 |
+
]
|
packages/ltx-core/src/ltx_core/conditioning/exceptions.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ConditioningError(Exception):
    """Base exception for conditioning-related failures."""
|
packages/ltx-core/src/ltx_core/conditioning/item.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Protocol
|
| 2 |
+
|
| 3 |
+
from ltx_core.tools import LatentTools
|
| 4 |
+
from ltx_core.types import LatentState
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ConditioningItem(Protocol):
    """Protocol for conditioning items that modify latent state during diffusion."""

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """
        Apply the conditioning to the latent state.
        Args:
            latent_state: The latent state to apply the conditioning to. This state is
                always patchified.
            latent_tools: Latent tooling handed to the item; see `LatentTools`.
        Returns:
            The latent state after the conditioning has been applied.
        IMPORTANT: If the conditioning needs to add extra tokens to the latent, it should add them to the end of the
        latent.
        """
        ...
|
packages/ltx-core/src/ltx_core/conditioning/mask_utils.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for building 2D self-attention masks for conditioning items."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import TYPE_CHECKING
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from ltx_core.types import LatentState
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def resolve_cross_mask(
    attention_mask: float | int | torch.Tensor,
    num_new_tokens: int,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Normalize *attention_mask* into a per-token cross-mask of shape (B, M).

    Args:
        attention_mask: A scalar applied uniformly, a 1-D tensor of shape (M,)
            broadcast over the batch, or a 2-D tensor of shape (B, M) (a batch
            dimension of 1 is expanded to B).
        num_new_tokens: Number of new conditioning tokens M.
        batch_size: Batch size B.
        device: Device for the output tensor.
        dtype: Data type for the output tensor.

    Returns:
        Cross-mask tensor of shape (B, M).

    Raises:
        ValueError: If a tensor mask has an incompatible shape or rank.
    """

    def _uniform(value: float) -> torch.Tensor:
        # Broadcast one scalar weight across the whole (B, M) grid.
        return torch.full(
            (batch_size, num_new_tokens),
            fill_value=value,
            device=device,
            dtype=dtype,
        )

    if isinstance(attention_mask, (int, float)):
        return _uniform(float(attention_mask))

    mask = attention_mask.to(device=device, dtype=dtype)
    ndim = mask.dim()

    # A 0-D tensor is treated exactly like a Python scalar.
    if ndim == 0:
        return _uniform(float(mask.item()))

    if ndim == 1:
        if mask.shape[0] != num_new_tokens:
            raise ValueError(
                f"1-D attention_mask length must equal num_new_tokens ({num_new_tokens}), got shape {tuple(mask.shape)}"
            )
        return mask.unsqueeze(0).expand(batch_size, -1)

    if ndim == 2:
        b, m = mask.shape
        if m != num_new_tokens:
            raise ValueError(
                f"2-D attention_mask second dimension must equal num_new_tokens ({num_new_tokens}), "
                f"got shape {tuple(mask.shape)}"
            )
        if b not in (batch_size, 1):
            raise ValueError(
                f"2-D attention_mask batch dimension must equal batch_size ({batch_size}) or 1, "
                f"got shape {tuple(mask.shape)}"
            )
        # Expand a singleton batch dimension when needed; otherwise keep as-is.
        return mask.expand(batch_size, -1) if (b == 1 and batch_size > 1) else mask

    raise ValueError(
        f"attention_mask tensor must be 0-D, 1-D, or 2-D, got {mask.dim()}-D with shape {tuple(mask.shape)}"
    )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def update_attention_mask(
    latent_state: LatentState,
    attention_mask: float | torch.Tensor | None,
    num_noisy_tokens: int,
    num_new_tokens: int,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor | None:
    """Build or update the self-attention mask for newly appended conditioning tokens.

    Behavior:
        * ``attention_mask is None`` and no mask exists yet -> returns ``None``.
        * ``attention_mask is None`` but a mask exists -> the new tokens get full
          attention (1s) so the mask keeps up with the growing latent sequence.
        * Otherwise ``attention_mask`` is resolved to a per-token (B, M) cross-mask
          and the full 2-D mask is expanded via :func:`build_attention_mask`.

    Args:
        latent_state: Current latent state (existing mask + current token count).
        attention_mask: Per-token attention weight: scalar, 1-D ``(M,)``,
            2-D ``(B, M)`` tensor, or ``None``.
        num_noisy_tokens: Number of original noisy tokens.
        num_new_tokens: Number of new conditioning tokens being appended.
        batch_size: Batch size.
        device: Device for the output tensor.
        dtype: Data type for the output tensor.

    Returns:
        Updated attention mask of shape ``(B, N+M, N+M)``, or ``None`` if no
        masking is needed.
    """
    if attention_mask is None:
        if latent_state.attention_mask is None:
            return None
        # No per-token weights requested, but a mask already exists: the new
        # tokens default to full attention (1s) toward the noisy tokens.
        cross_mask = torch.ones(batch_size, num_new_tokens, device=device, dtype=dtype)
    else:
        cross_mask = resolve_cross_mask(attention_mask, num_new_tokens, batch_size, device, dtype)

    # Both branches expand the mask the same way; only cross_mask differs.
    return build_attention_mask(
        existing_mask=latent_state.attention_mask,
        num_noisy_tokens=num_noisy_tokens,
        num_new_tokens=num_new_tokens,
        num_existing_tokens=latent_state.latent.shape[1],
        cross_mask=cross_mask,
        device=device,
        dtype=dtype,
    )
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def build_attention_mask(
    existing_mask: torch.Tensor | None,
    num_noisy_tokens: int,
    num_new_tokens: int,
    num_existing_tokens: int,
    cross_mask: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Expand the self-attention mask to cover newly appended conditioning tokens.

    The sequence grows from N (= *num_existing_tokens*) to N+M tokens. The result
    is a (B, N+M, N+M) mask composed of four kinds of regions:

    * existing x existing: copied from *existing_mask*, or 1.0 when there is no
      prior mask (full attention among all existing tokens).
    * new x new: 1.0 — new reference tokens fully attend among themselves.
    * noisy x new and new x noisy: weighted per new token by *cross_mask* (B, M).
    * prior-reference x new (rows/cols between num_noisy_tokens and N): left at
      0.0 — different reference groups do not attend to each other.

    Args:
        existing_mask: Current (B, N, N) mask, or None if none exists yet.
        num_noisy_tokens: Number of original noisy tokens (positions [0:num_noisy_tokens]).
        num_new_tokens: Number of new conditioning tokens M being appended.
        num_existing_tokens: Total current token count N (noisy + prior conditioning).
        cross_mask: (B, M) attention weights between new tokens and noisy tokens, in [0, 1].
        device: Device for the output tensor.
        dtype: Data type for the output tensor.

    Returns:
        Attention mask of shape (B, N+M, N+M) with values in [0, 1].
    """
    batch = cross_mask.shape[0]
    total = num_existing_tokens + num_new_tokens

    existing = slice(0, num_existing_tokens)
    new = slice(num_existing_tokens, total)
    noisy = slice(0, num_noisy_tokens)

    # Everything starts masked out; blocks are filled in below. The regions never
    # written (prior-reference <-> new-reference) stay 0 by construction.
    mask = torch.zeros((batch, total, total), device=device, dtype=dtype)

    # Existing-token block: preserve the old mask, or full attention if first time.
    mask[:, existing, existing] = existing_mask if existing_mask is not None else 1.0

    # New reference tokens fully attend among themselves.
    mask[:, new, new] = 1.0

    # Noisy -> new: column j of the block carries cross_mask[:, j].
    mask[:, noisy, new] = cross_mask.unsqueeze(1)
    # New -> noisy: row i of the block carries cross_mask[:, i].
    mask[:, new, noisy] = cross_mask.unsqueeze(2)

    return mask
|
packages/ltx-core/src/ltx_core/guidance/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Guidance and perturbation utilities for attention manipulation."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.guidance.perturbations import (
|
| 4 |
+
BatchedPerturbationConfig,
|
| 5 |
+
Perturbation,
|
| 6 |
+
PerturbationConfig,
|
| 7 |
+
PerturbationType,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"BatchedPerturbationConfig",
|
| 12 |
+
"Perturbation",
|
| 13 |
+
"PerturbationConfig",
|
| 14 |
+
"PerturbationType",
|
| 15 |
+
]
|
packages/ltx-core/src/ltx_core/guidance/perturbations.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from enum import Enum
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._prims_common import DeviceLikeType
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PerturbationType(Enum):
    """Types of attention perturbations for STG (Spatio-Temporal Guidance).

    Each member names an attention sub-module that a perturbation can skip
    inside a transformer block. NOTE(review): a2v/v2a presumably denote
    audio-to-video / video-to-audio cross-attention — confirm against the
    attention implementation.
    """

    SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
    SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
    SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
    SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass(frozen=True)
class Perturbation:
    """One perturbation: a single attention type to skip, limited to a set of blocks."""

    type: PerturbationType
    blocks: list[int] | None  # None means the perturbation applies to every block

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """Return True when *perturbation_type* should be skipped in *block*."""
        type_matches = self.type == perturbation_type
        return type_matches and (self.blocks is None or block in self.blocks)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass(frozen=True)
class PerturbationConfig:
    """Per-sample configuration: the perturbations active for one batch element."""

    perturbations: list[Perturbation] | None  # None behaves like an empty list

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when any configured perturbation skips *perturbation_type* in *block*."""
        active = self.perturbations or []
        return any(item.is_perturbed(perturbation_type, block) for item in active)

    @staticmethod
    def empty() -> "PerturbationConfig":
        """A configuration with no active perturbations."""
        return PerturbationConfig(perturbations=[])
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass(frozen=True)
class BatchedPerturbationConfig:
    """Per-batch perturbation configurations, with helpers for attention-mask generation."""

    perturbations: list[PerturbationConfig]  # one config per batch element

    def mask(
        self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return a (B,) tensor: 0 where the sample is perturbed, 1 elsewhere."""
        flags = [
            0.0 if config.is_perturbed(perturbation_type, block) else 1.0
            for config in self.perturbations
        ]
        return torch.tensor(flags, device=device, dtype=dtype)

    def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
        """Return the batch mask reshaped to (B, 1, ..., 1) for broadcasting against *values*."""
        flat = self.mask(perturbation_type, block, values.device, values.dtype)
        trailing = [1] * (values.dim() - 1)
        return flat.view(flat.numel(), *trailing)

    def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when at least one sample in the batch is perturbed for this type/block."""
        return any(config.is_perturbed(perturbation_type, block) for config in self.perturbations)

    def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when every sample in the batch is perturbed for this type/block."""
        return all(config.is_perturbed(perturbation_type, block) for config in self.perturbations)

    @staticmethod
    def empty(batch_size: int) -> "BatchedPerturbationConfig":
        """A batch of *batch_size* empty (no-op) perturbation configurations."""
        configs = [PerturbationConfig.empty() for _ in range(batch_size)]
        return BatchedPerturbationConfig(configs)
|
packages/ltx-core/src/ltx_core/loader/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Loader utilities for model weights, LoRAs, and safetensor operations."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.loader.fuse_loras import apply_loras
|
| 4 |
+
from ltx_core.loader.module_ops import ModuleOps
|
| 5 |
+
from ltx_core.loader.primitives import (
|
| 6 |
+
LoRAAdaptableProtocol,
|
| 7 |
+
LoraPathStrengthAndSDOps,
|
| 8 |
+
LoraStateDictWithStrength,
|
| 9 |
+
ModelBuilderProtocol,
|
| 10 |
+
StateDict,
|
| 11 |
+
StateDictLoader,
|
| 12 |
+
)
|
| 13 |
+
from ltx_core.loader.registry import DummyRegistry, Registry, StateDictRegistry
|
| 14 |
+
from ltx_core.loader.sd_ops import (
|
| 15 |
+
LTXV_LORA_COMFY_RENAMING_MAP,
|
| 16 |
+
ContentMatching,
|
| 17 |
+
ContentReplacement,
|
| 18 |
+
KeyValueOperation,
|
| 19 |
+
KeyValueOperationResult,
|
| 20 |
+
SDKeyValueOperation,
|
| 21 |
+
SDOps,
|
| 22 |
+
)
|
| 23 |
+
from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader, SafetensorsStateDictLoader
|
| 24 |
+
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"LTXV_LORA_COMFY_RENAMING_MAP",
|
| 28 |
+
"ContentMatching",
|
| 29 |
+
"ContentReplacement",
|
| 30 |
+
"DummyRegistry",
|
| 31 |
+
"KeyValueOperation",
|
| 32 |
+
"KeyValueOperationResult",
|
| 33 |
+
"LoRAAdaptableProtocol",
|
| 34 |
+
"LoraPathStrengthAndSDOps",
|
| 35 |
+
"LoraStateDictWithStrength",
|
| 36 |
+
"ModelBuilderProtocol",
|
| 37 |
+
"ModuleOps",
|
| 38 |
+
"Registry",
|
| 39 |
+
"SDKeyValueOperation",
|
| 40 |
+
"SDOps",
|
| 41 |
+
"SafetensorsModelStateDictLoader",
|
| 42 |
+
"SafetensorsStateDictLoader",
|
| 43 |
+
"SingleGPUModelBuilder",
|
| 44 |
+
"StateDict",
|
| 45 |
+
"StateDictLoader",
|
| 46 |
+
"StateDictRegistry",
|
| 47 |
+
"apply_loras",
|
| 48 |
+
]
|
packages/ltx-core/src/ltx_core/loader/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.33 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/fuse_loras.cpython-312.pyc
ADDED
|
Binary file (7.41 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/module_ops.cpython-312.pyc
ADDED
|
Binary file (955 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/primitives.cpython-312.pyc
ADDED
|
Binary file (5.37 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/registry.cpython-312.pyc
ADDED
|
Binary file (5.68 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/sd_ops.cpython-312.pyc
ADDED
|
Binary file (6.81 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/sft_loader.cpython-312.pyc
ADDED
|
Binary file (4.36 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/__pycache__/single_gpu_model_builder.cpython-312.pyc
ADDED
|
Binary file (8.84 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/loader/fuse_loras.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
|
| 4 |
+
from ltx_core.quantization.fp8_cast import calculate_weight_float8
|
| 5 |
+
from ltx_core.quantization.fp8_scaled_mm import quantize_weight_to_fp8_per_tensor
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def apply_loras(
    model_sd: StateDict,
    lora_sd_and_strengths: list[LoraStateDictWithStrength],
    dtype: torch.dtype | None = None,
    destination_sd: StateDict | None = None,
) -> StateDict:
    """Fuse LoRA deltas into a model state dict.

    Args:
        model_sd: Source model state dict. Scaled-FP8 weights carry a companion
            ``<prefix>.weight_scale`` entry next to ``<prefix>.weight``.
        lora_sd_and_strengths: LoRA state dicts paired with fusion strengths.
        dtype: Target dtype for fused weights; ``None`` keeps each weight's own dtype.
        destination_sd: Optional state dict whose inner dict is updated in place.
            When given, it is also the return value (its size/dtype metadata is
            not recomputed).

    Returns:
        *destination_sd* when provided, otherwise a new ``StateDict`` with the
        fused weights, accumulated byte size, and the set of dtypes encountered.
    """
    sd = destination_sd.sd if destination_sd is not None else {}
    size = 0
    device = torch.device("meta")
    inner_dtypes: set[torch.dtype] = set()
    for key, weight in model_sd.sd.items():
        if weight is None:
            continue
        # Skip scale keys - they are handled together with their weight keys.
        if key.endswith(".weight_scale"):
            continue
        device = weight.device
        target_dtype = dtype if dtype is not None else weight.dtype
        # Accumulate deltas in bf16 when the target is FP8 (FP8 math is too lossy).
        deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16

        # Derive the scale key from the ".weight" SUFFIX only. A plain
        # str.replace(".weight", ".weight_scale") would also rewrite a
        # mid-key occurrence (e.g. "a.weight_proj.weight") and corrupt the key.
        scale_key = f"{key[: -len('.weight')]}.weight_scale" if key.endswith(".weight") else None
        is_scaled_fp8 = scale_key is not None and scale_key in model_sd.sd

        deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
        fused = _fuse_deltas(deltas, weight, key, sd, target_dtype, device, is_scaled_fp8, scale_key, model_sd)

        sd.update(fused)
        for tensor in fused.values():
            inner_dtypes.add(tensor.dtype)
            size += tensor.nbytes

    if destination_sd is not None:
        return destination_sd
    return StateDict(sd, device, size, inner_dtypes)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _prepare_deltas(
|
| 47 |
+
lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device
|
| 48 |
+
) -> torch.Tensor | None:
|
| 49 |
+
deltas = []
|
| 50 |
+
prefix = key[: -len(".weight")]
|
| 51 |
+
key_a = f"{prefix}.lora_A.weight"
|
| 52 |
+
key_b = f"{prefix}.lora_B.weight"
|
| 53 |
+
for lsd, coef in lora_sd_and_strengths:
|
| 54 |
+
if key_a not in lsd.sd or key_b not in lsd.sd:
|
| 55 |
+
continue
|
| 56 |
+
a = lsd.sd[key_a].to(device=device)
|
| 57 |
+
b = lsd.sd[key_b].to(device=device)
|
| 58 |
+
product = torch.matmul(b * coef, a)
|
| 59 |
+
del a, b
|
| 60 |
+
deltas.append(product.to(dtype=dtype))
|
| 61 |
+
if len(deltas) == 0:
|
| 62 |
+
return None
|
| 63 |
+
elif len(deltas) == 1:
|
| 64 |
+
return deltas[0]
|
| 65 |
+
return torch.sum(torch.stack(deltas, dim=0), dim=0)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _fuse_deltas(
|
| 69 |
+
deltas: torch.Tensor | None,
|
| 70 |
+
weight: torch.Tensor,
|
| 71 |
+
key: str,
|
| 72 |
+
sd: dict[str, torch.Tensor],
|
| 73 |
+
target_dtype: torch.dtype,
|
| 74 |
+
device: torch.device,
|
| 75 |
+
is_scaled_fp8: bool,
|
| 76 |
+
scale_key: str | None,
|
| 77 |
+
model_sd: StateDict,
|
| 78 |
+
) -> dict[str, torch.Tensor]:
|
| 79 |
+
if deltas is None:
|
| 80 |
+
if key in sd:
|
| 81 |
+
return {}
|
| 82 |
+
fused = _copy_weight_without_lora(weight, key, target_dtype, device, is_scaled_fp8, scale_key, model_sd)
|
| 83 |
+
elif weight.dtype == torch.float8_e4m3fn:
|
| 84 |
+
if is_scaled_fp8:
|
| 85 |
+
fused = _fuse_delta_with_scaled_fp8(deltas, weight, key, scale_key, model_sd)
|
| 86 |
+
else:
|
| 87 |
+
fused = _fuse_delta_with_cast_fp8(deltas, weight, key, target_dtype, device)
|
| 88 |
+
elif weight.dtype == torch.bfloat16:
|
| 89 |
+
fused = _fuse_delta_with_bfloat16(deltas, weight, key, target_dtype)
|
| 90 |
+
else:
|
| 91 |
+
raise ValueError(f"Unsupported dtype: {weight.dtype}")
|
| 92 |
+
|
| 93 |
+
return fused
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _copy_weight_without_lora(
|
| 97 |
+
weight: torch.Tensor,
|
| 98 |
+
key: str,
|
| 99 |
+
target_dtype: torch.dtype,
|
| 100 |
+
device: torch.device,
|
| 101 |
+
is_scaled_fp8: bool,
|
| 102 |
+
scale_key: str | None,
|
| 103 |
+
model_sd: StateDict,
|
| 104 |
+
) -> dict[str, torch.Tensor]:
|
| 105 |
+
"""Copy original weight (and scale if applicable) when no LoRA affects this key."""
|
| 106 |
+
result = {key: weight.clone().to(dtype=target_dtype, device=device)}
|
| 107 |
+
if is_scaled_fp8:
|
| 108 |
+
result[scale_key] = model_sd.sd[scale_key].clone()
|
| 109 |
+
return result
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _fuse_delta_with_scaled_fp8(
    deltas: torch.Tensor,
    weight: torch.Tensor,
    key: str,
    scale_key: str,
    model_sd: StateDict,
) -> dict[str, torch.Tensor]:
    """Dequantize a scaled FP8 weight, fold in the LoRA delta, and re-quantize.

    NOTE(review): the stored FP8 weight appears to be kept transposed (hence
    ``weight.t()``) — confirm against the scaled-mm linear layout.
    """
    scale = model_sd.sd[scale_key]
    # Work in fp32: dequantize, add the delta, then quantize the sum back.
    dequantized = weight.t().to(torch.float32) * scale
    fused = dequantized + deltas.to(torch.float32)
    requantized, new_scale = quantize_weight_to_fp8_per_tensor(fused)
    return {key: requantized, scale_key: new_scale}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _fuse_delta_with_cast_fp8(
|
| 131 |
+
deltas: torch.Tensor,
|
| 132 |
+
weight: torch.Tensor,
|
| 133 |
+
key: str,
|
| 134 |
+
target_dtype: torch.dtype,
|
| 135 |
+
device: torch.device,
|
| 136 |
+
) -> dict[str, torch.Tensor]:
|
| 137 |
+
"""Fuse LoRA delta with cast-only FP8 weight (no scale factor)."""
|
| 138 |
+
if str(device).startswith("cuda"):
|
| 139 |
+
deltas = calculate_weight_float8(deltas, weight)
|
| 140 |
+
else:
|
| 141 |
+
deltas.add_(weight.to(dtype=deltas.dtype, device=device))
|
| 142 |
+
return {key: deltas.to(dtype=target_dtype)}
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _fuse_delta_with_bfloat16(
|
| 146 |
+
deltas: torch.Tensor,
|
| 147 |
+
weight: torch.Tensor,
|
| 148 |
+
key: str,
|
| 149 |
+
target_dtype: torch.dtype,
|
| 150 |
+
) -> dict[str, torch.Tensor]:
|
| 151 |
+
"""Fuse LoRA delta with bfloat16 weight."""
|
| 152 |
+
deltas.add_(weight)
|
| 153 |
+
return {key: deltas.to(dtype=target_dtype)}
|
packages/ltx-core/src/ltx_core/loader/kernels.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ruff: noqa: ANN001, ANN201, ERA001, N803, N806
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@triton.jit
def fused_add_round_kernel(
    x_ptr,
    output_ptr,  # contents will be added to the output
    seed,
    n_elements,
    EXPONENT_BIAS,
    MANTISSA_BITS,
    BLOCK_SIZE: tl.constexpr,
):
    """
    A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding
    and add them to bfloat16 output weights. Might be used to upcast original model weights
    and to further add them to precalculated deltas coming from LoRAs.

    Args:
        x_ptr: Pointer to the 8-bit float input weights.
        output_ptr: Pointer to the addend buffer; read as input and overwritten
            with the stochastically-rounded sum (bfloat16).
        seed: RNG seed for the rounding noise.
        n_elements: Total number of elements to process.
        EXPONENT_BIAS: Exponent bias of the 8-bit float format being rounded to.
        MANTISSA_BITS: Mantissa width of that 8-bit float format.
        BLOCK_SIZE: Elements per program instance (compile-time constant).

    NOTE(review): despite the docstring's "upcast to bfloat16", the ULP used for
    rounding is derived from EXPONENT_BIAS/MANTISSA_BITS (the 8-bit format), i.e.
    noise is scaled to the low-precision grid — confirm intent with callers.
    """
    # Get program ID and compute offsets
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    # Uniform noise in [-0.5, 0.5) for stochastic rounding.
    rand_vals = tl.rand(seed, offsets) - 0.5

    # Accumulate in fp16: x plus the current contents of the output buffer.
    x = tl.cast(x, tl.float16)
    delta = tl.load(output_ptr + offsets, mask=mask)
    delta = tl.cast(delta, tl.float16)
    x = x + delta

    x_bits = tl.cast(x, tl.int16, bitcast=True)

    # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
    # normal numbers and -14 for subnormals.
    fp16_exponent_bits = (x_bits & 0x7C00) >> 10
    fp16_normals = fp16_exponent_bits > 0
    fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)

    # Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
    exponent = fp16_exponent + EXPONENT_BIAS
    MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
    exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
    exponent = tl.where(exponent < 0, 0, exponent)

    # Normal ULP exponent, expressed as an fp16 exponent field:
    # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
    # Simplifies to: fp16_exponent - MANTISSA_BITS + 15
    # See https://en.wikipedia.org/wiki/Unit_in_the_last_place
    eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))

    # Calculate epsilon in the target dtype
    eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)

    # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
    # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
    # 16 - EXPONENT_BIAS - MANTISSA_BITS
    eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
    eps = tl.where(exponent > 0, eps_normal, eps_subnormal)

    # Apply zero mask to epsilon
    eps = tl.where(x == 0, 0.0, eps)

    # Apply stochastic rounding
    output = tl.cast(x + rand_vals * eps, tl.bfloat16)

    # Store the result
    tl.store(output_ptr + offsets, output, mask=mask)
|
packages/ltx-core/src/ltx_core/loader/module_ops.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, NamedTuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ModuleOps(NamedTuple):
    """
    Defines a named operation for matching and mutating PyTorch modules.

    Used to selectively transform modules in a model (e.g., replacing layers with quantized versions).
    """

    # Identifier for the operation — presumably used for lookup/diagnostics; confirm with callers.
    name: str
    # Predicate deciding whether a given module should be transformed.
    matcher: Callable[[torch.nn.Module], bool]
    # Transformation applied to matched modules; returns the replacement module.
    mutator: Callable[[torch.nn.Module], torch.nn.Module]
|
packages/ltx-core/src/ltx_core/loader/primitives.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import NamedTuple, Protocol
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.loader.module_ops import ModuleOps
|
| 7 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 8 |
+
from ltx_core.model.model_protocol import ModelType
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass(frozen=True)
class StateDict:
    """
    Immutable container for a PyTorch state dictionary.

    Attributes:
        sd: Mapping of parameter names to tensors (weights, buffers, etc.).
        device: Device where the tensors are stored.
        size: Total memory footprint in bytes.
        dtype: Set of tensor dtypes present in the dict.
    """

    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        """Return the (size-in-bytes, device) pair describing this state dict."""
        return (self.size, self.device)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class StateDictLoader(Protocol):
    """
    Protocol for loading state dictionaries from various sources.

    Implementations must provide:
    - metadata: Extract model metadata from a single path
    - load: Load state dict from path(s) and apply SDOps transformations
    """

    def metadata(self, path: str) -> dict:
        """
        Load metadata from *path*.

        Returns:
            Metadata dictionary — presumably the file header (e.g. safetensors
            metadata); confirm with implementations.
        """

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load a state dict from *path* (or a list of paths for sharded model storage) and apply *sd_ops*.

        Args:
            path: Single file path, or a list of shard paths.
            sd_ops: Optional key/value transformations applied to the loaded dict.
            device: Optional device to place the loaded tensors on.

        Returns:
            The loaded (and transformed) StateDict.
        """
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Protocol for building PyTorch models from configuration dictionaries.

    Implementations must provide:
    - meta_model: Create a model from configuration dictionary and apply module operations
    - build: Create and initialize a model from state dictionary and apply dtype transformations
    """

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Create a model on the meta device from a configuration dictionary.

        This decouples model creation from weight loading, allowing the model
        architecture to be instantiated without allocating memory for parameters.

        Args:
            config: Model configuration dictionary.
            module_ops: Optional list of module operations to apply (e.g., quantization).

        Returns:
            Model instance on meta device (no actual memory allocated for parameters).
        """
        ...

    def build(self, dtype: torch.dtype | None = None) -> ModelType:
        """
        Build the model with real weights.

        Args:
            dtype: Target dtype for the model; if None, uses the dtype of the
                model at the builder's configured model_path.

        Returns:
            Model instance.
        """
        ...
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class LoRAAdaptableProtocol(Protocol):
    """
    Protocol for models that can be adapted with LoRAs.

    Implementations must provide:
    - lora: Add a LoRA to the model
    """

    def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
        # Returns an adaptable object with the LoRA registered; implementations
        # (e.g. SingleGPUModelBuilder) return a new instance rather than mutating.
        pass
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class LoraPathStrengthAndSDOps(NamedTuple):
    """
    Tuple containing a LoRA path, strength, and SDOps for applying to the LoRA state dict.
    """

    # Filesystem path to the LoRA checkpoint.
    path: str
    # Fusion strength multiplier (0 disables the LoRA's effect).
    strength: float
    # Key/value transformations applied while loading the LoRA state dict.
    sd_ops: SDOps
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class LoraStateDictWithStrength(NamedTuple):
    """
    Tuple containing a LoRA state dict and strength for applying to the model.
    """

    # Already-loaded LoRA weights.
    state_dict: StateDict
    # Fusion strength multiplier applied when merging into the base weights.
    strength: float
|
packages/ltx-core/src/ltx_core/loader/registry.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import threading
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Protocol
|
| 6 |
+
|
| 7 |
+
from ltx_core.loader.primitives import StateDict
|
| 8 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Registry(Protocol):
    """
    Protocol for managing state dictionaries in a registry.

    It is used to store state dictionaries and reuse them later without loading them again.
    Entries are keyed by (paths, sd_ops), so the same files loaded with a
    different transform are distinct entries.

    Implementations must provide:
    - add: Add a state dictionary to the registry
    - pop: Remove a state dictionary from the registry
    - get: Retrieve a state dictionary from the registry
    - clear: Clear all state dictionaries from the registry
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ...

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ...

    def clear(self) -> None: ...
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DummyRegistry(Registry):
    """
    Dummy registry that does not store state dictionaries.

    Used to disable caching: ``add`` discards the entry, and ``pop``/``get``
    always miss (implicitly return None), so every load re-reads from disk.
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        # Intentionally discards the state dict: nothing is cached.
        pass

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        # Always a cache miss (implicitly returns None).
        pass

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        # Always a cache miss (implicitly returns None).
        pass

    def clear(self) -> None:
        # Nothing to clear.
        pass
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
class StateDictRegistry(Registry):
    """
    Thread-safe in-memory registry that stores state dictionaries in a dictionary.

    Entries are keyed by a SHA-256 digest of the resolved checkpoint paths plus
    the name of the applied ``SDOps`` (if any), so identical files loaded with
    different transforms are cached independently.
    """

    # Cache id (hex digest) -> cached StateDict.
    _state_dicts: dict[str, StateDict] = field(default_factory=dict)
    # Guards all reads/writes of `_state_dicts`.
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def _generate_id(self, paths: list[str], sd_ops: SDOps | None) -> str:
        """Return a stable cache key for (paths, sd_ops).

        Note: annotation fixed to ``SDOps | None`` — all callers (add/pop/get)
        may pass None and the body already handles it. Paths are resolved to
        absolute form so relative/absolute spellings hash identically.
        """
        m = hashlib.sha256()
        parts = [str(Path(p).resolve()) for p in paths]
        if sd_ops is not None:
            parts.append(sd_ops.name)
        # NUL separator prevents ambiguous concatenations of path strings.
        m.update("\0".join(parts).encode("utf-8"))
        return m.hexdigest()

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
        """Cache ``state_dict`` under the key derived from ``paths``/``sd_ops``.

        Returns:
            The generated cache id.

        Raises:
            ValueError: If an entry for the same key already exists; callers
                should check with :meth:`get` first.
        """
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            if sd_id in self._state_dicts:
                raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
            self._state_dicts[sd_id] = state_dict
        return sd_id

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Remove and return the cached entry, or None on a miss."""
        with self._lock:
            return self._state_dicts.pop(self._generate_id(paths, sd_ops), None)

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Return the cached entry without removing it, or None on a miss."""
        with self._lock:
            return self._state_dicts.get(self._generate_id(paths, sd_ops), None)

    def clear(self) -> None:
        """Drop every cached state dict."""
        with self._lock:
            self._state_dicts.clear()
|
packages/ltx-core/src/ltx_core/loader/sd_ops.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, replace
|
| 2 |
+
from typing import NamedTuple, Protocol
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True, slots=True)
class ContentReplacement:
    """
    Represents a content replacement operation.

    Used to replace a specific content with a replacement in a state dict key.
    """

    # Substring to search for within a state dict key.
    content: str
    # Text substituted for every occurrence of `content`.
    replacement: str
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True, slots=True)
class ContentMatching:
    """
    Represents a content matching operation.

    Used to match a specific prefix and suffix in a state dict key.
    Empty defaults match every key (every string starts/ends with "").
    """

    # Required key prefix; "" matches any key.
    prefix: str = ""
    # Required key suffix; "" matches any key.
    suffix: str = ""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class KeyValueOperationResult(NamedTuple):
    """
    Represents the result of a key-value operation.

    Contains the new key and value after the operation has been applied.
    """

    # Resulting state dict key.
    new_key: str
    # Resulting tensor associated with `new_key`.
    new_value: torch.Tensor
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class KeyValueOperation(Protocol):
    """
    Protocol for key-value operations.

    Used to apply operations to a specific key and value in a state dict.
    A single input pair may expand into multiple output pairs (e.g. splitting
    a fused tensor), hence the list return type.
    """

    def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass(frozen=True, slots=True)
class SDKeyValueOperation:
    """
    Represents a key-value operation.

    Used to apply operations to a specific key and value in a state dict;
    `kv_operation` runs only on keys accepted by `key_matcher`.
    """

    # Prefix/suffix filter selecting which keys this operation applies to.
    key_matcher: ContentMatching
    # Callable that transforms the matched (key, tensor) pair.
    kv_operation: KeyValueOperation
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass(frozen=True, slots=True)
class SDOps:
    """Immutable class representing state dict key operations.

    `mapping` is an ordered pipeline mixing three kinds of entries:
    - ContentMatching: allow-list filters; a key must satisfy at least one.
    - ContentReplacement: substring rewrites applied to accepted keys, in order.
    - SDKeyValueOperation: tensor-level transforms applied per (key, value).

    Note: with no ContentMatching entry at all, `apply_to_key` rejects every
    key (`any()` over an empty sequence is False) — which is why the
    predefined instances below start with `.with_matching()`.
    """

    # Identifier used for cache keying (see StateDictRegistry._generate_id).
    name: str
    # Immutable, ordered tuple of matching/replacement/kv operations.
    mapping: tuple[
        ContentReplacement | ContentMatching | SDKeyValueOperation, ...
    ] = ()

    def with_replacement(self, content: str, replacement: str) -> "SDOps":
        """Create a new SDOps instance with the specified replacement added to the mapping."""

        new_mapping = (*self.mapping, ContentReplacement(content, replacement))
        return replace(self, mapping=new_mapping)

    def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps":
        """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping."""

        new_mapping = (*self.mapping, ContentMatching(prefix, suffix))
        return replace(self, mapping=new_mapping)

    def with_kv_operation(
        self,
        operation: KeyValueOperation,
        key_prefix: str = "",
        key_suffix: str = "",
    ) -> "SDOps":
        """Create a new SDOps instance with the specified value operation added to the mapping."""
        key_matcher = ContentMatching(key_prefix, key_suffix)
        sd_kv_operation = SDKeyValueOperation(key_matcher, operation)
        new_mapping = (*self.mapping, sd_kv_operation)
        return replace(self, mapping=new_mapping)

    def apply_to_key(self, key: str) -> str | None:
        """Apply the mapping to the given name.

        Returns the rewritten key, or None if no ContentMatching entry
        accepts the key (including the case of zero matchers).
        """
        matchers = [content for content in self.mapping if isinstance(content, ContentMatching)]
        # A key passes if ANY matcher accepts it; no matchers => nothing passes.
        valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers)
        if not valid:
            return None

        # Apply every replacement, in mapping order, to the accepted key.
        for replacement in self.mapping:
            if not isinstance(replacement, ContentReplacement):
                continue
            if replacement.content in key:
                key = key.replace(replacement.content, replacement.replacement)
        return key

    def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
        """Apply the value operation to the given name and associated value.

        Only the FIRST SDKeyValueOperation whose matcher accepts the key is
        applied; if none match, the pair is returned unchanged.
        """
        for operation in self.mapping:
            if not isinstance(operation, SDKeyValueOperation):
                continue
            if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix):
                return operation.kv_operation(key, value)
        return [KeyValueOperationResult(key, value)]
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Predefined SDOps instances.
# The leading .with_matching() (empty prefix/suffix) accepts every key —
# without it, SDOps.apply_to_key would reject all keys.
# NOTE(review): the SDOps name "LTXV_LORA_COMFY_PREFIX_MAP" differs from the
# variable name below — possibly stale; confirm before relying on it for
# registry cache keys.
LTXV_LORA_COMFY_RENAMING_MAP = (
    SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
)

# Strips the ComfyUI "diffusion_model." prefix and maps LoRA A/B weight names
# onto the base parameter names they target.
LTXV_LORA_COMFY_TARGET_MAP = (
    SDOps("LTXV_LORA_COMFY_TARGET_MAP")
    .with_matching()
    .with_replacement("diffusion_model.", "")
    .with_replacement(".lora_A.weight", ".weight")
    .with_replacement(".lora_B.weight", ".weight")
)
|
packages/ltx-core/src/ltx_core/loader/sft_loader.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
import safetensors
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.loader.primitives import StateDict, StateDictLoader
|
| 7 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SafetensorsStateDictLoader(StateDictLoader):
    """
    Loads weights from safetensors files without metadata support.

    Use this for loading raw weight files. For model files that include
    configuration metadata, use SafetensorsModelStateDictLoader instead.
    """

    def metadata(self, path: str) -> dict:
        # Raw weight files carry no config metadata; use
        # SafetensorsModelStateDictLoader when metadata is needed.
        raise NotImplementedError("Not implemented")

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load state dict from path or paths (for sharded model storage) and apply sd_ops.

        Fix: `sd_ops` is now `SDOps | None = None`, matching the StateDictLoader
        protocol — the body already handled None, but the old signature
        required the argument and its annotation forbade None.

        Args:
            path: A single file path or a list of shard paths.
            sd_ops: Optional key filtering/renaming and tensor transforms.
            device: Target device for loaded tensors; defaults to CPU.

        Returns:
            StateDict with the merged tensors, their total byte size, and the
            set of dtypes encountered.
        """
        sd: dict = {}
        size = 0
        dtype: set = set()
        device = device or torch.device("cpu")
        model_paths = path if isinstance(path, list) else [path]
        for shard_path in model_paths:
            with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
                for name in f.keys():
                    # None means the key was filtered out by sd_ops matching rules.
                    expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                    if expected_name is None:
                        continue
                    value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
                    key_value_pairs = ((expected_name, value),)
                    if sd_ops is not None:
                        # May expand one tensor into several (key, value) pairs.
                        key_value_pairs = sd_ops.apply_to_key_value(expected_name, value)
                    for key, value in key_value_pairs:
                        size += value.nbytes
                        dtype.add(value.dtype)
                        sd[key] = value

        return StateDict(sd=sd, device=device, size=size, dtype=dtype)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class SafetensorsModelStateDictLoader(StateDictLoader):
    """
    Loads weights and configuration metadata from safetensors model files.

    Unlike SafetensorsStateDictLoader, this loader can read model configuration
    from the safetensors file metadata via the metadata() method; weight
    loading itself is delegated to a wrapped SafetensorsStateDictLoader.
    """

    def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
        # Fall back to a default weight loader when none is supplied.
        if weight_loader is None:
            weight_loader = SafetensorsStateDictLoader()
        self.weight_loader = weight_loader

    def metadata(self, path: str) -> dict:
        """Return the JSON "config" entry from the file header, or {} if absent."""
        with safetensors.safe_open(path, framework="pt") as f:
            meta = f.metadata()
        if meta is None or "config" not in meta:
            return {}
        return json.loads(meta["config"])

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """Delegate weight loading to the wrapped loader."""
        return self.weight_loader.load(path, sd_ops, device)
|
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import dataclass, field, replace
|
| 3 |
+
from typing import Generic
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ltx_core.loader.fuse_loras import apply_loras
|
| 8 |
+
from ltx_core.loader.module_ops import ModuleOps
|
| 9 |
+
from ltx_core.loader.primitives import (
|
| 10 |
+
LoRAAdaptableProtocol,
|
| 11 |
+
LoraPathStrengthAndSDOps,
|
| 12 |
+
LoraStateDictWithStrength,
|
| 13 |
+
ModelBuilderProtocol,
|
| 14 |
+
StateDict,
|
| 15 |
+
StateDictLoader,
|
| 16 |
+
)
|
| 17 |
+
from ltx_core.loader.registry import DummyRegistry, Registry
|
| 18 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 19 |
+
from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
|
| 20 |
+
from ltx_core.model.model_protocol import ModelConfigurator, ModelType
|
| 21 |
+
|
| 22 |
+
logger: logging.Logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass(frozen=True)
class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
    """
    Builder for PyTorch models residing on a single GPU.

    Attributes:
        model_class_configurator: Class responsible for constructing the model from a config dict.
        model_path: Path (or tuple of shard paths) to the model's `.safetensors` checkpoint(s).
        model_sd_ops: Optional state-dict operations applied when loading the model weights.
        module_ops: Sequence of module-level mutations applied to the meta model before weight loading.
        loras: Sequence of LoRA adapters (path, strength, optional sd_ops) to fuse into the model.
        model_loader: Strategy for loading state dicts from disk. Defaults to
            :class:`SafetensorsModelStateDictLoader`.
        registry: Cache for already-loaded state dicts. Defaults to :class:`DummyRegistry` (no caching).
        lora_load_device: Device used when loading LoRA weight tensors from disk. Defaults to
            ``torch.device("cpu")``, which keeps LoRA weights in CPU memory and transfers them to
            the target GPU sequentially during fusion, reducing peak GPU memory usage compared to
            loading all LoRA weights directly onto the GPU at once.
    """

    model_class_configurator: type[ModelConfigurator[ModelType]]
    model_path: str | tuple[str, ...]
    model_sd_ops: SDOps | None = None
    module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple)
    loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple)
    model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader)
    registry: Registry = field(default_factory=DummyRegistry)
    lora_load_device: torch.device = field(default_factory=lambda: torch.device("cpu"))

    def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder":
        # Immutable builder pattern: returns a NEW builder with the LoRA appended.
        return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops)))

    def model_config(self) -> dict:
        """Read the model config from the first shard's safetensors metadata."""
        first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path
        return self.model_loader.metadata(first_shard_path)

    def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType:
        """Instantiate the architecture on the meta device (no parameter memory)
        and apply matching module mutations (e.g. quantization wrappers).

        NOTE(review): signature takes a required tuple while ModelBuilderProtocol
        declares `module_ops: list[ModuleOps] | None = None` — confirm intended.
        """
        with torch.device("meta"):
            model = self.model_class_configurator.from_config(config)
            for module_op in module_ops:
                # Mutators only run on modules their matcher accepts.
                if module_op.matcher(model):
                    model = module_op.mutator(model)
        return model

    def load_sd(
        self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None
    ) -> StateDict:
        """Load a state dict, consulting the registry cache first and populating
        it on a miss (a DummyRegistry makes both calls no-ops)."""
        state_dict = registry.get(paths, sd_ops)
        if state_dict is None:
            state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device)
            registry.add(paths, sd_ops=sd_ops, state_dict=state_dict)
        return state_dict

    def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType:
        """Move the weight-populated model to `device`, unless some parameters
        or buffers were never materialized off the meta device.

        NOTE(review): when uninitialized params/buffers remain, this logs a
        warning and returns the model WITHOUT moving it to `device` — confirm
        this partial-model fallback is intended rather than an error.
        """
        uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"]
        uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"]
        if uninitialized_params or uninitialized_buffers:
            logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}")
            return meta_model
        retval = meta_model.to(device)
        return retval

    def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType:
        """Build the model: meta-instantiate, load weights, optionally fuse LoRAs,
        cast dtype, and move to `device` (defaults to CUDA).

        NOTE(review): adds a `device` parameter not present on
        ModelBuilderProtocol.build — positional callers of build(dtype) would
        pass it as `device`; confirm call sites.
        """
        device = torch.device("cuda") if device is None else device
        config = self.model_config()
        meta_model = self.meta_model(config, self.module_ops)
        model_paths = list(self.model_path) if isinstance(self.model_path, tuple) else [self.model_path]
        model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)

        lora_strengths = [lora.strength for lora in self.loras]
        # Fast path: no LoRAs, or every LoRA has strength 0 (no effect).
        if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
            sd = model_state_dict.sd
            if dtype is not None:
                sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()}
            # assign=True replaces meta parameters with the loaded tensors directly.
            meta_model.load_state_dict(sd, strict=False, assign=True)
            return self._return_model(meta_model, device)

        # LoRA weights are loaded on lora_load_device (CPU by default) to cap
        # peak GPU memory; apply_loras transfers them during fusion.
        lora_state_dicts = [
            self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=self.lora_load_device)
            for lora in self.loras
        ]
        lora_sd_and_strengths = [
            LoraStateDictWithStrength(sd, strength)
            for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
        ]
        final_sd = apply_loras(
            model_sd=model_state_dict,
            lora_sd_and_strengths=lora_sd_and_strengths,
            dtype=dtype,
            # In-place fusion into the loaded dict is only safe when nothing is
            # cached (DummyRegistry); otherwise the cached copy must stay pristine.
            destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
        )
        meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
        return self._return_model(meta_model, device)
|
packages/ltx-core/src/ltx_core/model/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model definitions for LTX-2."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.model_protocol import ModelConfigurator, ModelType
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"ModelConfigurator",
|
| 7 |
+
"ModelType",
|
| 8 |
+
]
|
packages/ltx-core/src/ltx_core/model/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (358 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/model/__pycache__/model_protocol.cpython-312.pyc
ADDED
|
Binary file (807 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio VAE model components."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio, encode_audio
|
| 4 |
+
from ltx_core.model.audio_vae.model_configurator import (
|
| 5 |
+
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
|
| 6 |
+
AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
|
| 7 |
+
VOCODER_COMFY_KEYS_FILTER,
|
| 8 |
+
AudioDecoderConfigurator,
|
| 9 |
+
AudioEncoderConfigurator,
|
| 10 |
+
VocoderConfigurator,
|
| 11 |
+
)
|
| 12 |
+
from ltx_core.model.audio_vae.ops import AudioProcessor
|
| 13 |
+
from ltx_core.model.audio_vae.vocoder import Vocoder, VocoderWithBWE
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"AUDIO_VAE_DECODER_COMFY_KEYS_FILTER",
|
| 17 |
+
"AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER",
|
| 18 |
+
"VOCODER_COMFY_KEYS_FILTER",
|
| 19 |
+
"AudioDecoder",
|
| 20 |
+
"AudioDecoderConfigurator",
|
| 21 |
+
"AudioEncoder",
|
| 22 |
+
"AudioEncoderConfigurator",
|
| 23 |
+
"AudioProcessor",
|
| 24 |
+
"Vocoder",
|
| 25 |
+
"VocoderConfigurator",
|
| 26 |
+
"VocoderWithBWE",
|
| 27 |
+
"decode_audio",
|
| 28 |
+
"encode_audio",
|
| 29 |
+
]
|
packages/ltx-core/src/ltx_core/model/audio_vae/attention.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.common.normalization import NormType, build_normalization_layer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AttentionType(Enum):
    """Enum for specifying the attention mechanism type."""

    # Full softmax self-attention (AttnBlock).
    VANILLA = "vanilla"
    # Linear attention — declared but not implemented (make_attn raises).
    LINEAR = "linear"
    # No attention; make_attn returns an Identity module.
    NONE = "none"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AttnBlock(torch.nn.Module):
    """Single-head self-attention over 2D feature maps with a residual connection.

    Queries, keys, values, and the output projection are all 1x1 convolutions,
    so the block preserves the (B, C, H, W) shape of its input.
    """

    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels

        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return x plus the projected attention output (same shape as x)."""
        normed = self.norm(x)
        queries = self.q(normed)
        keys = self.k(normed)
        values = self.v(normed)

        batch, channels, height, width = queries.shape
        tokens = height * width

        # Scores: (B, HW, C) x (B, C, HW) -> (B, HW, HW), scaled by 1/sqrt(C).
        q_flat = queries.reshape(batch, channels, tokens).contiguous().permute(0, 2, 1).contiguous()
        k_flat = keys.reshape(batch, channels, tokens).contiguous()
        scores = torch.bmm(q_flat, k_flat).contiguous()
        scores = scores * (int(channels) ** (-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # Attend to values: out[b, c, j] = sum_i v[b, c, i] * w[b, i, j].
        v_flat = values.reshape(batch, channels, tokens).contiguous()
        attended = torch.bmm(v_flat, weights.permute(0, 2, 1).contiguous()).contiguous()
        attended = attended.reshape(batch, channels, height, width).contiguous()

        projected = self.proj_out(attended)

        return x + projected
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> torch.nn.Module:
    """Build an attention module for `in_channels` channels.

    VANILLA yields an AttnBlock, NONE yields an Identity pass-through, and
    LINEAR is declared but not yet implemented.
    """
    if attn_type is AttentionType.VANILLA:
        return AttnBlock(in_channels, norm_type=norm_type)
    if attn_type is AttentionType.NONE:
        return torch.nn.Identity()
    if attn_type is AttentionType.LINEAR:
        raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
    raise ValueError(f"Unknown attention type: {attn_type}")
|
packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Set, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from ltx_core.components.patchifiers import AudioPatchifier
|
| 7 |
+
from ltx_core.model.audio_vae.attention import AttentionType, make_attn
|
| 8 |
+
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
|
| 9 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 10 |
+
from ltx_core.model.audio_vae.downsample import build_downsampling_path
|
| 11 |
+
from ltx_core.model.audio_vae.ops import AudioProcessor, PerChannelStatistics
|
| 12 |
+
from ltx_core.model.audio_vae.resnet import ResnetBlock
|
| 13 |
+
from ltx_core.model.audio_vae.upsample import build_upsampling_path
|
| 14 |
+
from ltx_core.model.audio_vae.vocoder import Vocoder
|
| 15 |
+
from ltx_core.model.common.normalization import NormType, build_normalization_layer
|
| 16 |
+
from ltx_core.types import Audio, AudioLatentShape
|
| 17 |
+
|
| 18 |
+
LATENT_DOWNSAMPLE_FACTOR = 4
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> torch.nn.Module:
    """Assemble the bottleneck stage: ResNet block -> (optional) attention -> ResNet block."""

    def _make_resnet_block() -> ResnetBlock:
        # Both residual blocks of the mid stage share the exact same configuration.
        return ResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=temb_channels,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
        )

    mid_block = torch.nn.Module()
    mid_block.block_1 = _make_resnet_block()
    if add_attention:
        mid_block.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type)
    else:
        # Identity keeps the forward path uniform when attention is disabled.
        mid_block.attn_1 = torch.nn.Identity()
    mid_block.block_2 = _make_resnet_block()
    return mid_block
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
    """Apply the middle (bottleneck) stage: block_1 -> attn_1 -> block_2."""
    out = mid.block_1(features, temb=None)
    out = mid.attn_1(out)
    out = mid.block_2(out, temb=None)
    return out
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class AudioEncoder(torch.nn.Module):
    """
    Encoder that compresses audio spectrograms into latent representations.

    The encoder uses a series of downsampling blocks with residual connections,
    attention mechanisms, and configurable causal convolutions. The output of
    ``forward`` is a normalized latent tensor (see ``_normalize_latents``).
    """

    def __init__( # noqa: PLR0913
        self,
        *,
        ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Set[int],
        dropout: float = 0.0,
        resamp_with_conv: bool = True,
        in_channels: int,
        resolution: int,
        z_channels: int,
        double_z: bool = True,
        attn_type: AttentionType = AttentionType.VANILLA,
        mid_block_add_attention: bool = True,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        n_fft: int = 1024,
        is_causal: bool = True,
        mel_bins: int = 64,
        **_ignore_kwargs,
    ) -> None:
        """
        Initialize the Encoder.

        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            ch: Base number of feature channels used in the first convolution layer.
            ch_mult: Multiplicative factors for the number of channels at each resolution level.
            num_res_blocks: Number of residual blocks to use at each resolution level.
            attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
            resolution: Input spatial resolution of the spectrogram (height, width).
            z_channels: Number of channels in the latent representation.
            norm_type: Normalization layer type to use within the network (e.g., group, batch).
            causality_axis: Axis along which convolutions should be causal (e.g., time axis).
            sample_rate: Audio sample rate in Hz for the input signals.
            mel_hop_length: Hop length used when computing the mel spectrogram.
            n_fft: FFT size used to compute the spectrogram.
            mel_bins: Number of mel-frequency bins in the input spectrogram.
            in_channels: Number of channels in the input spectrogram tensor.
            double_z: If True, predict both mean and log-variance (doubling latent channels).
            is_causal: If True, use causal convolutions suitable for streaming setups.
            dropout: Dropout probability used in residual and mid blocks.
            attn_type: Type of attention mechanism to use in attention blocks.
            resamp_with_conv: If True, perform resolution changes using strided convolutions.
            mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
            **_ignore_kwargs: Extra config keys are accepted and silently ignored.
        """
        super().__init__()

        # Per-channel normalization statistics applied to patchified latents in
        # `_normalize_latents`.
        # NOTE(review): sized from `ch` (base channel count) — confirm this matches
        # the actual latent channel count (z_channels) seen by `_normalize_latents`.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.n_fft = n_fft
        self.is_causal = is_causal
        self.mel_bins = mel_bins

        # Patchifier used only to reshape latents for per-channel normalization;
        # patch_size=1 means no spatial merging.
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0  # no timestep embedding is used by this VAE
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.double_z = double_z
        self.norm_type = norm_type
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # downsampling
        self.conv_in = make_conv2d(
            in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

        self.non_linearity = torch.nn.SiLU()

        # `block_in` is the channel count produced by the deepest downsampling level;
        # it feeds the mid block and the output head below.
        self.down, block_in = build_downsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
        )

        self.mid = build_mid_block(
            channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )

        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
        # With double_z the head emits 2*z_channels (mean + second component, e.g.
        # log-variance — see `_normalize_latents` docstring).
        self.conv_out = make_conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Encode audio spectrogram into latent representations.

        Args:
            spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
        Returns:
            Encoded latent representation of shape (batch, channels, frames, mel_bins)
        """
        h = self.conv_in(spectrogram)
        h = self._run_downsampling_path(h)
        h = run_mid_block(self.mid, h)
        h = self._finalize_output(h)

        return self._normalize_latents(h)

    def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        # Walk the resolution levels from shallowest to deepest; each level applies
        # its ResNet blocks (with optional per-block attention) and then downsamples,
        # except the last level which keeps its resolution.
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx in range(self.num_res_blocks):
                h = stage.block[block_idx](h, temb=None)
                # `stage.attn` is empty unless this level's resolution was listed in
                # attn_resolutions at build time.
                if stage.attn:
                    h = stage.attn[block_idx](h)

            if level != self.num_resolutions - 1:
                h = stage.downsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        # Standard VAE output head: normalize -> activate -> project to z channels.
        h = self.norm_out(h)
        h = self.non_linearity(h)
        return self.conv_out(h)

    def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
        """
        Normalize encoder latents using per-channel statistics.

        When the encoder is configured with ``double_z=True``, the final
        convolution produces twice the number of latent channels, typically
        interpreted as two concatenated tensors along the channel dimension
        (e.g., mean and variance or other auxiliary parameters).

        This method intentionally uses only the first half of the channels
        (the "mean" component) as input to the patchifier and normalization
        logic. The remaining channels are left unchanged by this method and
        are expected to be consumed elsewhere in the VAE pipeline.

        If ``double_z=False``, the encoder output already contains only the
        mean latents and the chunking operation simply returns that tensor.
        """
        means = torch.chunk(latent_output, 2, dim=1)[0]
        # Record the (batch, channels, frames, mel_bins) shape so the tensor can be
        # restored after the patchify/normalize round-trip.
        latent_shape = AudioLatentShape(
            batch=means.shape[0],
            channels=means.shape[1],
            frames=means.shape[2],
            mel_bins=means.shape[3],
        )
        latent_patched = self.patchifier.patchify(means)
        latent_normalized = self.per_channel_statistics.normalize(latent_patched)
        return self.patchifier.unpatchify(latent_normalized, latent_shape)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def encode_audio(
    audio: Audio,
    audio_encoder: AudioEncoder,
    audio_processor: AudioProcessor | None = None,
) -> torch.Tensor:
    """Encode an audio waveform into its latent representation.

    Args:
        audio: Audio container with waveform tensor of shape (batch, channels, samples) and sampling rate.
        audio_encoder: Audio encoder model
        audio_processor: Audio processor model (optional, if not provided, it will be created from the audio encoder)
    Returns:
        Latent tensor produced by the encoder.
    """
    # Probe one parameter to learn which device/dtype the encoder lives on.
    reference_param = next(audio_encoder.parameters())
    device = reference_param.device
    dtype = reference_param.dtype

    if audio_processor is None:
        # Build a processor matching the encoder's spectrogram configuration.
        audio_processor = AudioProcessor(
            target_sample_rate=audio_encoder.sample_rate,
            mel_bins=audio_encoder.mel_bins,
            mel_hop_length=audio_encoder.mel_hop_length,
            n_fft=audio_encoder.n_fft,
        ).to(device=device)

    mel_spectrogram = audio_processor.waveform_to_mel(audio.to(device=device))
    return audio_encoder(mel_spectrogram.to(dtype=dtype))
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class AudioDecoder(torch.nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.

    The decoder mirrors the encoder structure with configurable channel multipliers,
    attention resolutions, and causal convolutions. Input latents are denormalized
    with per-channel statistics before decoding (see ``_denormalize_latents``).
    """

    def __init__( # noqa: PLR0913
        self,
        *,
        ch: int,
        out_ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Set[int],
        resolution: int,
        z_channels: int,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
        dropout: float = 0.0,
        mid_block_add_attention: bool = True,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = None,
    ) -> None:
        """
        Initialize the Decoder.

        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
            - resolution, z_channels
            - norm_type, causality_axis
        """
        super().__init__()

        # Internal behavioural defaults that are not driven by the checkpoint.
        resamp_with_conv = True
        attn_type = AttentionType.VANILLA

        # Per-channel statistics for denormalizing latents
        # NOTE(review): sized from `ch` (base channel count) — confirm this matches
        # the channel count of the latents handed to `_denormalize_latents`.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        # Patchifier used only to reshape latents for per-channel denormalization;
        # patch_size=1 means no spatial merging.
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0  # no timestep embedding is used by this VAE
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.out_ch = out_ch
        self.give_pre_end = False  # if True, `_finalize_output` returns features pre norm/activation/conv
        self.tanh_out = False  # if True, squash the final output through tanh
        self.norm_type = norm_type
        self.z_channels = z_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # The decoder starts at the deepest level: widest channels, smallest resolution.
        base_block_channels = ch * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, z_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )
        self.non_linearity = torch.nn.SiLU()
        self.mid = build_mid_block(
            channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )
        # `final_block_channels` is the channel count after the shallowest level;
        # it feeds the output head below.
        self.up, final_block_channels = build_upsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
            initial_block_channels=base_block_channels,
        )

        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features back to audio spectrograms.

        Args:
            sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
        Returns:
            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
        """
        sample, target_shape = self._denormalize_latents(sample)

        h = self.conv_in(sample)
        h = run_mid_block(self.mid, h)
        h = self._run_upsampling_path(h)
        h = self._finalize_output(h)

        return self._adjust_output_shape(h, target_shape)

    def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
        """Undo per-channel normalization and compute the expected output shape."""
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )

        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        # Each latent frame expands to LATENT_DOWNSAMPLE_FACTOR output frames.
        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        if self.causality_axis != CausalityAxis.NONE:
            # Causal mode: drop the (factor - 1) frames contributed by causal
            # padding, keeping at least one frame.
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )

        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: torch.Tensor,
        target_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Adjust output shape to match target dimensions for variable-length audio.

        This function handles the common case where decoded audio spectrograms need to be
        resized to match a specific target shape.

        Args:
            decoded_output: Tensor of shape (batch, channels, time, frequency)
            target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
        Returns:
            Tensor adjusted to match target_shape exactly
        """
        # Current output shape: (batch, channels, time, frequency)
        _, _, current_time, current_freq = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        # Step 1: Crop first to avoid exceeding target dimensions
        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        # Step 3: Apply padding if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
            # For audio: pad_left/right = frequency, pad_top/bottom = time
            padding = (
                0,
                max(freq_padding_needed, 0),  # frequency padding (left, right)
                0,
                max(time_padding_needed, 0),  # time padding (top, bottom)
            )
            decoded_output = F.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

    def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        # Walk levels from deepest to shallowest (reverse of the encoder order);
        # each level applies its ResNet blocks (with optional per-block attention)
        # and upsamples, except the shallowest level which keeps its resolution.
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                h = block(h, temb=None)
                # `stage.attn` is empty unless this level's resolution was listed in
                # attn_resolutions at build time.
                if stage.attn:
                    h = stage.attn[block_idx](h)

            if level != 0 and hasattr(stage, "upsample"):
                h = stage.upsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        # Optionally expose the raw features before the output head.
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = self.non_linearity(h)
        h = self.conv_out(h)
        return torch.tanh(h) if self.tanh_out else h
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> Audio:
    """
    Decode an audio latent representation using the provided audio decoder and vocoder.

    Args:
        latent: Input audio latent tensor.
        audio_decoder: Model to decode the latent to waveform features.
        vocoder: Model to convert decoded features to audio waveform.
    Returns:
        Decoded audio with waveform and sampling rate.
    """
    # Latent -> spectrogram-like features -> waveform.
    features = audio_decoder(latent)
    waveform = vocoder(features)
    # Drop the batch dimension and promote to float32 for downstream consumers.
    waveform = waveform.squeeze(0).float()
    return Audio(waveform=waveform, sampling_rate=vocoder.output_sampling_rate)
|
packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CausalConv2d(torch.nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
A causal 2D convolution.
|
| 10 |
+
This layer ensures that the output at time `t` only depends on inputs
|
| 11 |
+
at time `t` and earlier. It achieves this by applying asymmetric padding
|
| 12 |
+
to the time dimension (width) before the convolution.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
in_channels: int,
|
| 18 |
+
out_channels: int,
|
| 19 |
+
kernel_size: int | tuple[int, int],
|
| 20 |
+
stride: int = 1,
|
| 21 |
+
dilation: int | tuple[int, int] = 1,
|
| 22 |
+
groups: int = 1,
|
| 23 |
+
bias: bool = True,
|
| 24 |
+
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
+
self.causality_axis = causality_axis
|
| 29 |
+
|
| 30 |
+
# Ensure kernel_size and dilation are tuples
|
| 31 |
+
kernel_size = torch.nn.modules.utils._pair(kernel_size)
|
| 32 |
+
dilation = torch.nn.modules.utils._pair(dilation)
|
| 33 |
+
|
| 34 |
+
# Calculate padding dimensions
|
| 35 |
+
pad_h = (kernel_size[0] - 1) * dilation[0]
|
| 36 |
+
pad_w = (kernel_size[1] - 1) * dilation[1]
|
| 37 |
+
|
| 38 |
+
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
|
| 39 |
+
match self.causality_axis:
|
| 40 |
+
case CausalityAxis.NONE:
|
| 41 |
+
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
| 42 |
+
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
|
| 43 |
+
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
| 44 |
+
case CausalityAxis.HEIGHT:
|
| 45 |
+
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
| 46 |
+
case _:
|
| 47 |
+
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
| 48 |
+
|
| 49 |
+
# The internal convolution layer uses no padding, as we handle it manually
|
| 50 |
+
self.conv = torch.nn.Conv2d(
|
| 51 |
+
in_channels,
|
| 52 |
+
out_channels,
|
| 53 |
+
kernel_size,
|
| 54 |
+
stride=stride,
|
| 55 |
+
padding=0,
|
| 56 |
+
dilation=dilation,
|
| 57 |
+
groups=groups,
|
| 58 |
+
bias=bias,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 62 |
+
# Apply causal padding before convolution
|
| 63 |
+
x = F.pad(x, self.padding)
|
| 64 |
+
return self.conv(x)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int = 1,
    padding: tuple[int, int, int, int] | None = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: CausalityAxis | None = None,
) -> torch.nn.Module:
    """
    Create a 2D convolution layer that can be either causal or non-causal.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        kernel_size: Size of the convolution kernel
        stride: Convolution stride
        padding: Padding (if None, will be calculated based on causal flag)
        dilation: Dilation rate
        groups: Number of groups for grouped convolution
        bias: Whether to use bias
        causality_axis: Dimension along which to apply causality.
    Returns:
        Either a regular Conv2d or CausalConv2d layer
    """
    if causality_axis is not None:
        # Causal path: CausalConv2d computes its own asymmetric padding, so an
        # explicit `padding` argument is not used here.
        return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)

    # Non-causal path: fall back to symmetric "same"-style padding when the
    # caller did not provide one.
    if padding is None:
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

    return torch.nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        bias,
    )
|
packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Set, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.audio_vae.attention import AttentionType, make_attn
|
| 6 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 7 |
+
from ltx_core.model.audio_vae.resnet import ResnetBlock
|
| 8 |
+
from ltx_core.model.common.normalization import NormType
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Downsample(torch.nn.Module):
    """
    A downsampling layer that can use either a strided convolution
    or average pooling. Supports standard and causal padding for the
    convolutional mode.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        # Causal padding is only implemented for the convolutional variant.
        if not self.with_conv and self.causality_axis != CausalityAxis.NONE:
            raise ValueError("causality is only supported when `with_conv=True`.")

        if self.with_conv:
            # Do time downsampling here
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.with_conv:
            # This branch is only taken if with_conv=False, which implies causality_axis is NONE.
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # Padding tuple is in the order: (left, right, top, bottom).
        pad_by_axis = {
            CausalityAxis.NONE: (0, 1, 0, 1),
            CausalityAxis.WIDTH: (2, 0, 0, 1),
            CausalityAxis.HEIGHT: (0, 1, 2, 0),
            CausalityAxis.WIDTH_COMPATIBILITY: (1, 0, 0, 1),
        }
        pad = pad_by_axis.get(self.causality_axis)
        if pad is None:
            raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        padded = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        return self.conv(padded)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def build_downsampling_path( # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[torch.nn.ModuleList, int]:
    """Build the downsampling path with residual blocks, attention, and downsampling layers.

    Returns:
        A tuple of (ModuleList with one stage per resolution level, output
        channel count of the deepest stage). Each stage module exposes
        ``block`` (ResNet blocks), ``attn`` (attention blocks — empty unless
        the stage's resolution is in `attn_resolutions`), and, for every level
        except the last, ``downsample``.
    """
    down_modules = torch.nn.ModuleList()
    # Tracks the spatial resolution at each level so attention can be enabled
    # only at the requested resolutions.
    curr_res = resolution
    # Input channel multiplier for level i is in_ch_mult[i]; the leading 1
    # makes the first level consume `ch` channels.
    in_ch_mult = (1, *tuple(ch_mult))
    block_in = ch

    for i_level in range(num_resolutions):
        block = torch.nn.ModuleList()
        attn = torch.nn.ModuleList()
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]

        for _ in range(num_res_blocks):
            block.append(
                ResnetBlock(
                    in_channels=block_in,
                    out_channels=block_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            # Only the first ResNet block of a level changes the channel count;
            # subsequent blocks map block_out -> block_out.
            block_in = block_out
            if curr_res in attn_resolutions:
                # One attention block per ResNet block at this resolution.
                attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))

        down = torch.nn.Module()
        down.block = block
        down.attn = attn
        if i_level != num_resolutions - 1:
            # Every level except the deepest halves the spatial resolution.
            down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
            curr_res = curr_res // 2
        down_modules.append(down)

    return down_modules, block_in
|
packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps
|
| 4 |
+
from ltx_core.model.audio_vae.attention import AttentionType
|
| 5 |
+
from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder
|
| 6 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 7 |
+
from ltx_core.model.audio_vae.vocoder import MelSTFT, Vocoder, VocoderWithBWE
|
| 8 |
+
from ltx_core.model.common.normalization import NormType
|
| 9 |
+
from ltx_core.model.model_protocol import ModelConfigurator
|
| 10 |
+
from ltx_core.utils import check_config_value
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _vocoder_from_config(
    cfg: dict,
    apply_final_activation: bool = True,
    output_sampling_rate: int | None = None,
) -> Vocoder:
    """Instantiate a Vocoder from a flat config dict.

    Args:
        cfg: Vocoder config dict (keys match Vocoder constructor args).
        apply_final_activation: Whether to apply tanh/clamp at the output.
        output_sampling_rate: Explicit override for the output sample rate.
            When None, falls back to cfg["output_sampling_rate"] (default 24000).
    """
    if output_sampling_rate is None:
        output_sampling_rate = cfg.get("output_sampling_rate", 24000)
    return Vocoder(
        resblock_kernel_sizes=cfg.get("resblock_kernel_sizes", [3, 7, 11]),
        upsample_rates=cfg.get("upsample_rates", [6, 5, 2, 2, 2]),
        upsample_kernel_sizes=cfg.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]),
        resblock_dilation_sizes=cfg.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]),
        upsample_initial_channel=cfg.get("upsample_initial_channel", 1024),
        resblock=cfg.get("resblock", "1"),
        output_sampling_rate=output_sampling_rate,
        activation=cfg.get("activation", "snake"),
        use_tanh_at_final=cfg.get("use_tanh_at_final", True),
        apply_final_activation=apply_final_activation,
        use_bias_at_final=cfg.get("use_bias_at_final", True),
    )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class VocoderConfigurator(ModelConfigurator[Vocoder]):
    """Configurator that auto-detects the checkpoint format.

    Returns a plain Vocoder for pre-ltx-2.3 checkpoints (flat config) or a
    VocoderWithBWE for ltx-2.3+ checkpoints (nested "vocoder" + "bwe" config).
    """

    @classmethod
    def from_config(cls: type[Vocoder], config: dict) -> Vocoder | VocoderWithBWE:
        # NOTE(review): `cls` is annotated as type[Vocoder] but at runtime it is
        # this configurator class — consider type["VocoderConfigurator"].
        cfg = config.get("vocoder", {})

        # Legacy (pre-ltx-2.3) flat config: no "bwe" section present.
        if "bwe" not in cfg:
            check_config_value(cfg, "resblock", "1")
            check_config_value(cfg, "stereo", True)
            return _vocoder_from_config(cfg)

        # ltx-2.3+ layout: a nested base "vocoder" config plus a bandwidth
        # extension ("bwe") stage config.
        vocoder_cfg = cfg.get("vocoder", {})
        bwe_cfg = cfg["bwe"]

        # Only these fixed architecture variants are supported; fail fast otherwise.
        check_config_value(vocoder_cfg, "resblock", "AMP1")
        check_config_value(vocoder_cfg, "stereo", True)
        check_config_value(vocoder_cfg, "activation", "snakebeta")
        check_config_value(bwe_cfg, "resblock", "AMP1")
        check_config_value(bwe_cfg, "stereo", True)
        check_config_value(bwe_cfg, "activation", "snakebeta")

        # Base vocoder renders at the BWE input rate; the BWE generator produces
        # the final output rate with its final activation disabled.
        vocoder = _vocoder_from_config(
            vocoder_cfg,
            output_sampling_rate=bwe_cfg["input_sampling_rate"],
        )
        bwe_generator = _vocoder_from_config(
            bwe_cfg,
            apply_final_activation=False,
            output_sampling_rate=bwe_cfg["output_sampling_rate"],
        )
        # Mel front-end for the BWE stage, configured from the BWE STFT settings.
        mel_stft = MelSTFT(
            filter_length=bwe_cfg["n_fft"],
            hop_length=bwe_cfg["hop_length"],
            win_length=bwe_cfg["n_fft"],
            n_mel_channels=bwe_cfg["num_mels"],
        )
        return VocoderWithBWE(
            vocoder=vocoder,
            bwe_generator=bwe_generator,
            mel_stft=mel_stft,
            input_sampling_rate=bwe_cfg["input_sampling_rate"],
            output_sampling_rate=bwe_cfg["output_sampling_rate"],
            hop_length=bwe_cfg["hop_length"],
        )
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _strip_vocoder_prefix(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
    """Strip the leading 'vocoder.' prefix exactly once.

    removeprefix (rather than str.replace) guarantees a single, anchored strip:
    BWE keys like 'vocoder.vocoder.conv_pre' become 'vocoder.conv_pre' (not
    'conv_pre'), while legacy keys like 'vocoder.conv_pre' become 'conv_pre'.
    """
    stripped_key = key.removeprefix("vocoder.")
    return [KeyValueOperationResult(stripped_key, value)]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# SDOps pipeline that keeps only state-dict entries under the 'vocoder.' prefix
# and remaps their keys via _strip_vocoder_prefix (one prefix level stripped, so
# BWE-nested 'vocoder.vocoder.*' keys keep one 'vocoder.' level).
VOCODER_COMFY_KEYS_FILTER = (
    SDOps("VOCODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="vocoder.")
    .with_kv_operation(operation=_strip_vocoder_prefix, key_prefix="vocoder.")
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class AudioDecoderConfigurator(ModelConfigurator[AudioDecoder]):
    """Builds an AudioDecoder from the nested "audio_vae" section of a config dict."""

    @classmethod
    def from_config(cls: type[AudioDecoder], config: dict) -> AudioDecoder:
        vae_cfg = config.get("audio_vae", {})
        model_params = vae_cfg.get("model", {}).get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing = vae_cfg.get("preprocessing", {})
        stft = preprocessing.get("stft", {})
        mel = preprocessing.get("mel", {})
        variables = vae_cfg.get("variables", {})

        # mel_bins may live in several places depending on checkpoint vintage;
        # the first non-falsy hit wins.
        mel_bins = ddconfig.get("mel_bins") or mel.get("n_mel_channels") or variables.get("mel_bins")

        return AudioDecoder(
            ch=ddconfig.get("ch", 128),
            out_ch=ddconfig.get("out_ch", 2),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            dropout=ddconfig.get("dropout", 0.0),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            sample_rate=model_params.get("sampling_rate", 16000),
            mel_hop_length=stft.get("hop_length", 160),
            is_causal=stft.get("causal", True),
            mel_bins=mel_bins,
        )
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class AudioEncoderConfigurator(ModelConfigurator[AudioEncoder]):
    # Builds an AudioEncoder from the nested "audio_vae" section of a config dict,
    # falling back to defaults for any missing key.
    @classmethod
    def from_config(cls: type[AudioEncoder], config: dict) -> AudioEncoder:
        """Construct an AudioEncoder from the "audio_vae" config section."""
        audio_vae_cfg = config.get("audio_vae", {})
        model_cfg = audio_vae_cfg.get("model", {})
        model_params = model_cfg.get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
        stft_cfg = preprocessing_cfg.get("stft", {})
        mel_cfg = preprocessing_cfg.get("mel", {})
        variables_cfg = audio_vae_cfg.get("variables", {})

        sample_rate = model_params.get("sampling_rate", 16000)
        mel_hop_length = stft_cfg.get("hop_length", 160)
        n_fft = stft_cfg.get("filter_length", 1024)
        is_causal = stft_cfg.get("causal", True)
        # mel_bins may live in several config locations depending on checkpoint
        # vintage; the first non-falsy hit wins.
        mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")

        return AudioEncoder(
            ch=ddconfig.get("ch", 128),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            double_z=ddconfig.get("double_z", True),
            dropout=ddconfig.get("dropout", 0.0),
            resamp_with_conv=ddconfig.get("resamp_with_conv", True),
            in_channels=ddconfig.get("in_channels", 2),
            attn_type=AttentionType(ddconfig.get("attn_type", "vanilla")),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            sample_rate=sample_rate,
            mel_hop_length=mel_hop_length,
            n_fft=n_fft,
            is_causal=is_causal,
            mel_bins=mel_bins,
        )
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# Keep only decoder + per-channel-statistics weights and strip the nesting
# prefixes so the keys line up with AudioDecoder's own state_dict layout.
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_DECODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.decoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.decoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)


# Same key mapping for the encoder weights.
AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.encoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.encoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)
|
packages/ltx-core/src/ltx_core/model/audio_vae/ops.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchaudio
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
from ltx_core.types import Audio
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AudioProcessor(nn.Module):
    """Converts audio waveforms to log-mel spectrograms with optional resampling."""

    def __init__(
        self,
        target_sample_rate: int,
        mel_bins: int,
        mel_hop_length: int,
        n_fft: int,
    ) -> None:
        super().__init__()
        self.target_sample_rate = target_sample_rate
        # Magnitude (power=1.0) mel spectrogram with slaney scale/norm; f_max is Nyquist.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=target_sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=target_sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )

    def resample_audio(self, audio: Audio) -> Audio:
        """Resample audio to the processor's target sample rate if needed."""
        if audio.sampling_rate == self.target_sample_rate:
            return audio
        source = audio.waveform
        converted = torchaudio.functional.resample(source, audio.sampling_rate, self.target_sample_rate)
        # Keep the original device/dtype after resampling.
        converted = converted.to(device=source.device, dtype=source.dtype)
        return Audio(waveform=converted, sampling_rate=self.target_sample_rate)

    def waveform_to_mel(
        self,
        audio: Audio,
    ) -> torch.Tensor:
        """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
        waveform = self.resample_audio(audio).waveform

        spectrogram = self.mel_transform(waveform)
        # Log-compress with a floor to avoid log(0).
        log_mel = torch.log(torch.clamp(spectrogram, min=1e-5))

        log_mel = log_mel.to(device=waveform.device, dtype=waveform.dtype)
        # [batch, channels, n_mels, frames] -> [batch, channels, frames, n_mels]
        return log_mel.permute(0, 1, 3, 2).contiguous()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class PerChannelStatistics(nn.Module):
    """Per-channel statistics for normalizing and denormalizing the latent representation.

    The statistics are computed offline over the entire dataset and stored in the
    model checkpoint under the AudioVAE state_dict; buffer names intentionally use
    hyphens to match the checkpoint keys.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        super().__init__()
        # Registration order matters for state_dict layout: std first, then mean.
        for buffer_name in ("std-of-means", "mean-of-means"):
            self.register_buffer(buffer_name, torch.empty(latent_channels))

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map a normalized latent back to the original scale."""
        std = self.get_buffer("std-of-means").to(x)
        mean = self.get_buffer("mean-of-means").to(x)
        return x * std + mean

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Standardize a latent using the stored per-channel statistics."""
        std = self.get_buffer("std-of-means").to(x)
        mean = self.get_buffer("mean-of-means").to(x)
        return (x - mean) / std
|
packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
|
| 6 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 7 |
+
from ltx_core.model.common.normalization import NormType, build_normalization_layer
|
| 8 |
+
|
| 9 |
+
LRELU_SLOPE = 0.1
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ResBlock1(torch.nn.Module):
|
| 13 |
+
def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
|
| 14 |
+
super(ResBlock1, self).__init__()
|
| 15 |
+
self.convs1 = torch.nn.ModuleList(
|
| 16 |
+
[
|
| 17 |
+
torch.nn.Conv1d(
|
| 18 |
+
channels,
|
| 19 |
+
channels,
|
| 20 |
+
kernel_size,
|
| 21 |
+
1,
|
| 22 |
+
dilation=dilation[0],
|
| 23 |
+
padding="same",
|
| 24 |
+
),
|
| 25 |
+
torch.nn.Conv1d(
|
| 26 |
+
channels,
|
| 27 |
+
channels,
|
| 28 |
+
kernel_size,
|
| 29 |
+
1,
|
| 30 |
+
dilation=dilation[1],
|
| 31 |
+
padding="same",
|
| 32 |
+
),
|
| 33 |
+
torch.nn.Conv1d(
|
| 34 |
+
channels,
|
| 35 |
+
channels,
|
| 36 |
+
kernel_size,
|
| 37 |
+
1,
|
| 38 |
+
dilation=dilation[2],
|
| 39 |
+
padding="same",
|
| 40 |
+
),
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
self.convs2 = torch.nn.ModuleList(
|
| 45 |
+
[
|
| 46 |
+
torch.nn.Conv1d(
|
| 47 |
+
channels,
|
| 48 |
+
channels,
|
| 49 |
+
kernel_size,
|
| 50 |
+
1,
|
| 51 |
+
dilation=1,
|
| 52 |
+
padding="same",
|
| 53 |
+
),
|
| 54 |
+
torch.nn.Conv1d(
|
| 55 |
+
channels,
|
| 56 |
+
channels,
|
| 57 |
+
kernel_size,
|
| 58 |
+
1,
|
| 59 |
+
dilation=1,
|
| 60 |
+
padding="same",
|
| 61 |
+
),
|
| 62 |
+
torch.nn.Conv1d(
|
| 63 |
+
channels,
|
| 64 |
+
channels,
|
| 65 |
+
kernel_size,
|
| 66 |
+
1,
|
| 67 |
+
dilation=1,
|
| 68 |
+
padding="same",
|
| 69 |
+
),
|
| 70 |
+
]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 74 |
+
for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):
|
| 75 |
+
xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
|
| 76 |
+
xt = conv1(xt)
|
| 77 |
+
xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
|
| 78 |
+
xt = conv2(xt)
|
| 79 |
+
x = xt + x
|
| 80 |
+
return x
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ResBlock2(torch.nn.Module):
    """Residual block with two dilated convs, each preceded by a leaky-ReLU pre-activation."""

    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
        super().__init__()
        self.convs = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[i], padding="same")
                for i in range(2)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            residual = x
            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            x = conv(x)
            x = x + residual
        return x
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ResnetBlock(torch.nn.Module):
    """Pre-activation residual block: (norm -> SiLU -> conv) twice, with optional
    timestep-embedding injection and a learned shortcut when channel counts differ.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        # Causal configurations reject GroupNorm. NOTE(review): the parameter
        # defaults (GROUP + HEIGHT) trip this check, so callers must always pass
        # a compatible norm_type/causality_axis pair — confirm this is intended.
        if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        # Default to a channel-preserving block.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.non_linearity = torch.nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        # temb_proj exists only when temb_channels > 0; forward() must not be
        # given a temb otherwise (it would hit a missing attribute).
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        # When channels change, the skip path needs a projection: either a 3x3
        # conv (conv_shortcut=True) or a 1x1 "network-in-network" conv.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the residual block; `temb` (optional) is projected and added after conv1."""
        h = x
        h = self.norm1(h)
        h = self.non_linearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Broadcast the projected embedding over both trailing spatial dims.
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.non_linearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Project the input when channel counts differ so the skip connection matches.
        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return x + h
|
packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Set, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.audio_vae.attention import AttentionType, make_attn
|
| 6 |
+
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
|
| 7 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 8 |
+
from ltx_core.model.audio_vae.resnet import ResnetBlock
|
| 9 |
+
from ltx_core.model.common.normalization import NormType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Upsample(torch.nn.Module):
    """Nearest-neighbor 2x upsampling, optionally followed by a (causal) 3x3 conv."""

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        # Drop the FIRST element along the causal axis to undo the encoder's
        # padding while keeping the length 1 + 2 * n. Example: input [0, 1, 2]
        # interpolates to [0, 0, 1, 1, 2, 2]; the causal convolution pads it to
        # [-, -, 0, 0, 1, 1, 2, 2], so the output windows are
        #   0: [-,-,0]   1: [-,0,0]   2: [0,0,1]   3: [0,1,1]   4: [1,1,2]   5: [1,2,2]
        # The first two outputs depend only on the first input element, while every
        # later output depends on two input elements — so the first (not the last)
        # element is the redundant one to drop. No-op for non-causal convolutions.
        if self.causality_axis in (CausalityAxis.NONE, CausalityAxis.WIDTH_COMPATIBILITY):
            pass  # nothing to trim
        elif self.causality_axis == CausalityAxis.HEIGHT:
            x = x[:, :, 1:, :]
        elif self.causality_axis == CausalityAxis.WIDTH:
            x = x[:, :, :, 1:]
        else:
            raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        return x
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def build_upsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
    initial_block_channels: int,
) -> tuple[torch.nn.ModuleList, int]:
    """Build the upsampling path with residual blocks, attention, and upsampling layers.

    Returns the per-resolution stages (ordered shallow-to-deep, built deep-to-shallow)
    and the channel count produced by the final (shallowest) stage.
    """
    up_modules = torch.nn.ModuleList()
    block_in = initial_block_channels
    # Start from the deepest (smallest) resolution.
    curr_res = resolution // (2 ** (num_resolutions - 1))

    for level in reversed(range(num_resolutions)):
        stage = torch.nn.Module()
        stage.block = torch.nn.ModuleList()
        stage.attn = torch.nn.ModuleList()
        block_out = ch * ch_mult[level]

        # One extra residual block per level compared to the downsampling path.
        for _ in range(num_res_blocks + 1):
            stage.block.append(
                ResnetBlock(
                    in_channels=block_in,
                    out_channels=block_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            block_in = block_out
            if curr_res in attn_resolutions:
                stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))

        # Every level except the shallowest doubles the spatial resolution.
        if level != 0:
            stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
            curr_res *= 2

        # Prepend so the returned list is ordered from shallow to deep.
        up_modules.insert(0, stage)

    return up_modules, block_in
|
packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py
ADDED
|
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import einops
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from ltx_core.model.audio_vae.resnet import LRELU_SLOPE, ResBlock1
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return the 'same' padding for a dilated 1D conv: (kernel_size - 1) * dilation // 2."""
    return (kernel_size - 1) * dilation // 2
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
| 18 |
+
# Adopted from https://github.com/NVIDIA/BigVGAN
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _sinc(x: torch.Tensor) -> torch.Tensor:
|
| 23 |
+
return torch.where(
|
| 24 |
+
x == 0,
|
| 25 |
+
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
| 26 |
+
torch.sin(math.pi * x) / math.pi / x,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
|
| 31 |
+
even = kernel_size % 2 == 0
|
| 32 |
+
half_size = kernel_size // 2
|
| 33 |
+
delta_f = 4 * half_width
|
| 34 |
+
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
| 35 |
+
if amplitude > 50.0:
|
| 36 |
+
beta = 0.1102 * (amplitude - 8.7)
|
| 37 |
+
elif amplitude >= 21.0:
|
| 38 |
+
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
|
| 39 |
+
else:
|
| 40 |
+
beta = 0.0
|
| 41 |
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
| 42 |
+
time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
|
| 43 |
+
if cutoff == 0:
|
| 44 |
+
filter_ = torch.zeros_like(time)
|
| 45 |
+
else:
|
| 46 |
+
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
| 47 |
+
filter_ /= filter_.sum()
|
| 48 |
+
return filter_.view(1, 1, kernel_size)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LowPassFilter1d(nn.Module):
    """Anti-aliasing low-pass filter applied channel-wise via grouped 1D convolution.

    The kernel is a Kaiser-windowed sinc (see ``kaiser_sinc_filter1d``) stored
    as a buffer and shared across all channels.
    """

    def __init__(
        self,
        cutoff: float = 0.5,
        half_width: float = 0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ) -> None:
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        half = kernel_size // 2
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # For even kernels the left pad is one sample shorter so the output
        # stays aligned with the input.
        self.pad_left = half - int(self.even)
        self.pad_right = half
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter a (batch, channels, time) tensor with the shared low-pass kernel."""
        _, channels, _ = x.shape
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        kernel = self.filter.expand(channels, -1, -1)
        return F.conv1d(x, kernel, stride=self.stride, groups=channels)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class UpSample1d(nn.Module):
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
ratio: int = 2,
|
| 86 |
+
kernel_size: int | None = None,
|
| 87 |
+
persistent: bool = True,
|
| 88 |
+
window_type: str = "kaiser",
|
| 89 |
+
) -> None:
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.ratio = ratio
|
| 92 |
+
self.stride = ratio
|
| 93 |
+
|
| 94 |
+
if window_type == "hann":
|
| 95 |
+
# Hann-windowed sinc filter equivalent to torchaudio.functional.resample
|
| 96 |
+
rolloff = 0.99
|
| 97 |
+
lowpass_filter_width = 6
|
| 98 |
+
width = math.ceil(lowpass_filter_width / rolloff)
|
| 99 |
+
self.kernel_size = 2 * width * ratio + 1
|
| 100 |
+
self.pad = width
|
| 101 |
+
self.pad_left = 2 * width * ratio
|
| 102 |
+
self.pad_right = self.kernel_size - ratio
|
| 103 |
+
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
| 104 |
+
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
|
| 105 |
+
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
| 106 |
+
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
|
| 107 |
+
else:
|
| 108 |
+
# Kaiser-windowed sinc filter (BigVGAN default).
|
| 109 |
+
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
| 110 |
+
self.pad = self.kernel_size // ratio - 1
|
| 111 |
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
| 112 |
+
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
| 113 |
+
sinc_filter = kaiser_sinc_filter1d(
|
| 114 |
+
cutoff=0.5 / ratio,
|
| 115 |
+
half_width=0.6 / ratio,
|
| 116 |
+
kernel_size=self.kernel_size,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
self.register_buffer("filter", sinc_filter, persistent=persistent)
|
| 120 |
+
|
| 121 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
_, n_channels, _ = x.shape
|
| 123 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
| 124 |
+
filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
|
| 125 |
+
x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels)
|
| 126 |
+
return x[..., self.pad_left : -self.pad_right]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class DownSample1d(nn.Module):
    """Anti-aliased downsampler: low-pass filter then decimate by ``ratio``."""

    def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
        super().__init__()
        taps = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.ratio = ratio
        self.kernel_size = taps
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=taps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` band-limited and decimated by the configured ratio."""
        return self.lowpass(x)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Activation1d(nn.Module):
    """Anti-aliased activation (BigVGAN): upsample, apply the nonlinearity, downsample.

    Running the pointwise activation at a higher sample rate suppresses the
    aliasing its harmonics would otherwise introduce.
    """

    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ) -> None:
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped activation between the up/down resampling pair."""
        return self.downsample(self.act(self.upsample(x)))
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class Snake(nn.Module):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), with a per-channel alpha."""

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale
        # Log-scale mode stores log(alpha), so zeros correspond to alpha = 1.
        initial = torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha
        self.alpha = nn.Parameter(initial)
        self.alpha.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply snake per channel on a (batch, channels, time) tensor."""
        a = self.alpha.unsqueeze(0).unsqueeze(-1)  # broadcast to (1, C, 1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + torch.sin(x * a).pow(2) / (a + self.eps)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class SnakeBeta(nn.Module):
    """Snake variant with decoupled parameters: x + (1/beta) * sin^2(alpha * x)."""

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale

        def _initial() -> torch.Tensor:
            # Log-scale mode stores log(value), so zeros correspond to value = 1.
            return torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha

        self.alpha = nn.Parameter(_initial())
        self.alpha.requires_grad = alpha_trainable
        self.beta = nn.Parameter(_initial())
        self.beta.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply snake-beta per channel on a (batch, channels, time) tensor."""
        a = self.alpha.unsqueeze(0).unsqueeze(-1)  # (1, C, 1)
        b = self.beta.unsqueeze(0).unsqueeze(-1)  # (1, C, 1)
        if self.alpha_logscale:
            a = torch.exp(a)
            b = torch.exp(b)
        return x + torch.sin(x * a).pow(2) / (b + self.eps)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class AMPBlock1(nn.Module):
    """BigVGAN residual block: dilated-conv / plain-conv pairs with anti-aliased snake activations."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple[int, int, int] = (1, 3, 5),
        activation: str = "snake",
    ) -> None:
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake
        # First conv of each pair uses the configured dilation; the second is undilated.
        self.convs1 = nn.ModuleList(
            nn.Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=d,
                padding=get_padding(kernel_size, d),
            )
            for d in (dilation[0], dilation[1], dilation[2])
        )

        self.convs2 = nn.ModuleList(
            nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
            for _ in range(3)
        )

        # One anti-aliased activation ahead of every conv.
        self.acts1 = nn.ModuleList(Activation1d(act_cls(channels)) for _ in range(len(self.convs1)))
        self.acts2 = nn.ModuleList(Activation1d(act_cls(channels)) for _ in range(len(self.convs2)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run each (act, dilated conv, act, conv) pair and add its output residually."""
        for conv_dilated, conv_plain, act_a, act_b in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True):
            residual = conv_plain(act_b(conv_dilated(act_a(x))))
            x = x + residual
        return x
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class Vocoder(torch.nn.Module):
    """
    Vocoder model for synthesizing audio from Mel spectrograms.

    Args:
        resblock_kernel_sizes: List of kernel sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
        upsample_rates: List of upsampling rates.
            This value is read from the checkpoint at `config.vocoder.upsample_rates`.
        upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
        resblock_dilation_sizes: List of dilation sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
        upsample_initial_channel: Initial number of channels for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
        resblock: Type of residual block to use ("1", "2", or "AMP1").
            This value is read from the checkpoint at `config.vocoder.resblock`.
        output_sampling_rate: Waveform sample rate.
            This value is read from the checkpoint at `config.vocoder.output_sampling_rate`.
        activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1".
        use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True).
        apply_final_activation: Whether to apply the final tanh/clamp activation.
        use_bias_at_final: Whether to use bias in the final conv layer.
    """

    def __init__(  # noqa: PLR0913
        self,
        resblock_kernel_sizes: List[int] | None = None,
        upsample_rates: List[int] | None = None,
        upsample_kernel_sizes: List[int] | None = None,
        resblock_dilation_sizes: List[List[int]] | None = None,
        upsample_initial_channel: int = 1024,
        resblock: str = "1",
        output_sampling_rate: int = 24000,
        activation: str = "snake",
        use_tanh_at_final: bool = True,
        apply_final_activation: bool = True,
        use_bias_at_final: bool = True,
    ) -> None:
        super().__init__()

        # Mutable default values are not supported as default arguments.
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if upsample_rates is None:
            upsample_rates = [6, 5, 2, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 15, 8, 4, 4]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        self.output_sampling_rate = output_sampling_rate
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.use_tanh_at_final = use_tanh_at_final
        self.apply_final_activation = apply_final_activation
        self.is_amp = resblock == "AMP1"

        # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel
        # bins each), 2 output channels.
        self.conv_pre = nn.Conv1d(
            in_channels=128,
            out_channels=upsample_initial_channel,
            kernel_size=7,
            stride=1,
            padding=3,
        )
        resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1

        # Each upsampling stage halves the channel count while expanding time.
        self.ups = nn.ModuleList(
            nn.ConvTranspose1d(
                upsample_initial_channel // (2**i),
                upsample_initial_channel // (2 ** (i + 1)),
                kernel_size,
                stride,
                padding=(kernel_size - stride) // 2,
            )
            for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True))
        )

        final_channels = upsample_initial_channel // (2 ** len(upsample_rates))
        self.resblocks = nn.ModuleList()

        # num_kernels residual blocks per upsampling stage, stored flat; forward
        # indexes them as [i * num_kernels, (i + 1) * num_kernels).
        for i in range(len(upsample_rates)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
                if self.is_amp:
                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation))
                else:
                    self.resblocks.append(resblock_cls(ch, kernel_size, dilations))

        if self.is_amp:
            self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels))
        else:
            self.act_post = nn.LeakyReLU()

        # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo).
        self.conv_post = nn.Conv1d(
            in_channels=final_channels,
            out_channels=2,
            kernel_size=7,
            stride=1,
            padding=3,
            bias=use_bias_at_final,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the vocoder.

        Args:
            x: Input Mel spectrogram tensor. Can be either:
                - 3D: (batch_size, time, mel_bins) for mono
                - 4D: (batch_size, 2, time, mel_bins) for stereo

        Returns:
            Audio waveform tensor of shape (batch_size, out_channels, audio_length)
        """
        # BUGFIX: the previous unconditional `x.transpose(2, 3)` raised
        # "Dimension out of range" for the documented 3D mono layout; branch on
        # the rank first so both layouts end up as (batch, features, time).
        if x.dim() == 4:  # stereo
            x = x.transpose(2, 3)  # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time)
            assert x.shape[1] == 2, "Input must have 2 channels for stereo"
            x = einops.rearrange(x, "b s c t -> b (s c) t")
        else:
            x = x.transpose(1, 2)  # (batch, time, mel_bins) -> (batch, mel_bins, time)

        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            if not self.is_amp:
                # AMP blocks carry their own anti-aliased activations; the plain
                # ResBlock1 path applies leaky ReLU between stages instead.
                x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            start = i * self.num_kernels
            end = start + self.num_kernels

            # Evaluate all resblocks with the same input tensor so they can run
            # independently (and thus in parallel on accelerator hardware) before
            # aggregating their outputs via mean.
            block_outputs = torch.stack(
                [self.resblocks[idx](x) for idx in range(start, end)],
                dim=0,
            )
            x = block_outputs.mean(dim=0)

        x = self.act_post(x)
        x = self.conv_post(x)

        if self.apply_final_activation:
            x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1)

        return x
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class _STFTFn(nn.Module):
|
| 420 |
+
"""Implements STFT as a convolution with precomputed DFT x Hann-window bases.
|
| 421 |
+
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
|
| 422 |
+
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
|
| 423 |
+
bfloat16 bases from training ensures the mel values fed to the BWE generator are
|
| 424 |
+
bit-identical to what it was trained on.
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
|
| 428 |
+
super().__init__()
|
| 429 |
+
self.hop_length = hop_length
|
| 430 |
+
self.win_length = win_length
|
| 431 |
+
n_freqs = filter_length // 2 + 1
|
| 432 |
+
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
| 433 |
+
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
| 434 |
+
|
| 435 |
+
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 436 |
+
"""Compute magnitude and phase spectrogram from a batch of waveforms.
|
| 437 |
+
Applies causal (left-only) padding of win_length - hop_length samples so that
|
| 438 |
+
each output frame depends only on past and present input — no lookahead.
|
| 439 |
+
Args:
|
| 440 |
+
y: Waveform tensor of shape (B, T).
|
| 441 |
+
Returns:
|
| 442 |
+
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
| 443 |
+
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
| 444 |
+
"""
|
| 445 |
+
if y.dim() == 2:
|
| 446 |
+
y = y.unsqueeze(1) # (B, 1, T)
|
| 447 |
+
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
| 448 |
+
y = F.pad(y, (left_pad, 0))
|
| 449 |
+
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
| 450 |
+
n_freqs = spec.shape[1] // 2
|
| 451 |
+
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
| 452 |
+
magnitude = torch.sqrt(real**2 + imag**2)
|
| 453 |
+
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
|
| 454 |
+
return magnitude, phase
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
class MelSTFT(nn.Module):
    """Causal log-mel spectrogram module whose buffers come from the checkpoint.

    Runs the causal STFT (_STFTFn) on the waveform and projects the linear
    magnitude spectrum through the mel filterbank. The state-dict layout
    matches the 'mel_stft.*' keys stored in the checkpoint
    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
    ) -> None:
        super().__init__()
        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)

        # Zeros here; load_state_dict installs the checkpoint's exact bfloat16
        # filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]).
        freq_bins = filter_length // 2 + 1
        self.register_buffer("mel_basis", torch.zeros(n_mel_channels, freq_bins))

    def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (log_mel, magnitude, phase, energy) for waveforms ``y``.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
        """
        magnitude, phase = self.stft_fn(y)
        frame_energy = magnitude.norm(dim=1)
        mel = self.mel_basis.to(magnitude.dtype) @ magnitude
        # Floor at 1e-5 before the log to keep silent frames finite.
        log_mel = torch.log(mel.clamp(min=1e-5))
        return log_mel, magnitude, phase, frame_energy
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class VocoderWithBWE(nn.Module):
    """Vocoder with bandwidth extension (BWE) upsampling.

    Chains a mel-to-wav vocoder with a BWE module that upsamples the output
    to a higher sample rate. The BWE computes a mel spectrogram from the
    vocoder output, runs it through a second generator to predict a residual,
    and adds it to a sinc-resampled skip connection.
    """

    def __init__(
        self,
        vocoder: Vocoder,
        bwe_generator: Vocoder,
        mel_stft: MelSTFT,
        input_sampling_rate: int,
        output_sampling_rate: int,
        hop_length: int,
    ) -> None:
        """Wire together the low-rate vocoder, the BWE generator, and the resampler.

        Args:
            vocoder: Mel-to-waveform model producing audio at `input_sampling_rate`.
            bwe_generator: Second Vocoder that predicts the high-band residual.
            mel_stft: Causal mel-spectrogram module used to re-analyze the vocoder output.
            input_sampling_rate: Sample rate of the vocoder output.
            output_sampling_rate: Target sample rate after BWE; assumed to be an
                integer multiple of input_sampling_rate (integer-ratio resampler).
            hop_length: Hop size of `mel_stft`; used to pad the waveform to a
                whole number of analysis frames.
        """
        super().__init__()
        self.vocoder = vocoder
        self.bwe_generator = bwe_generator
        self.mel_stft = mel_stft
        self.input_sampling_rate = input_sampling_rate
        self.output_sampling_rate = output_sampling_rate
        self.hop_length = hop_length
        # Compute the resampler on CPU so the sinc filter is materialized even when
        # the model is constructed on meta device (SingleGPUModelBuilder pattern).
        # The filter is not stored in the checkpoint (persistent=False).
        with torch.device("cpu"):
            self.resampler = UpSample1d(
                ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann"
            )

    @property
    def conv_pre(self) -> nn.Conv1d:
        # Pass-through to the wrapped vocoder's input projection layer.
        return self.vocoder.conv_pre

    @property
    def conv_post(self) -> nn.Conv1d:
        # Pass-through to the wrapped vocoder's output projection layer.
        return self.vocoder.conv_post

    def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor:
        """Compute log-mel spectrogram from waveform using causal STFT bases.

        Args:
            audio: Waveform tensor of shape (B, C, T).

        Returns:
            mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames).
        """
        batch, n_channels, _ = audio.shape
        # Fold channels into the batch so MelSTFT (which expects (B, T)) can
        # process all channels in one call.
        flat = audio.reshape(batch * n_channels, -1)  # (B*C, T)
        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)
        return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        """Run the full vocoder + BWE forward pass.

        Args:
            mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo
                or (B, T, mel_bins) for mono. Same format as Vocoder.forward.

        Returns:
            Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1].
        """
        x = self.vocoder(mel_spec)
        _, _, length_low_rate = x.shape
        # Target length is derived from the pre-padding length so the
        # hop-alignment padding added below is trimmed from the final output.
        output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate

        # Pad to multiple of hop_length for exact mel frame count
        remainder = length_low_rate % self.hop_length
        if remainder != 0:
            x = F.pad(x, (0, self.hop_length - remainder))

        # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames)
        mel = self._compute_mel(x)

        # Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator
        mel_for_bwe = mel.transpose(2, 3)  # (B, C, T_frames, mel_bins)
        residual = self.bwe_generator(mel_for_bwe)
        skip = self.resampler(x)
        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"

        return torch.clamp(residual + skip, -1, 1)[..., :output_length]
|
packages/ltx-core/src/ltx_core/model/model_protocol.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Protocol, TypeVar
|
| 2 |
+
|
| 3 |
+
ModelType = TypeVar("ModelType")
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ModelConfigurator(Protocol[ModelType]):
    """Protocol for model loader classes that instantiate models from a configuration dictionary."""

    @classmethod
    def from_config(cls, config: dict) -> ModelType:
        """Build and return a ``ModelType`` instance from the given configuration mapping."""
        ...
|
packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.transformer.gelu_approx import GELUApprox
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FeedForward(torch.nn.Module):
    """Position-wise feed-forward block: project up by `mult`, then linear back down.

    The up-projection goes through ``GELUApprox`` — presumably a linear layer
    fused with an approximate-GELU activation, judging by its
    ``(dim, inner_dim)`` constructor signature; confirm against
    ``ltx_core.model.transformer.gelu_approx``.
    """

    def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None:
        """
        Args:
            dim: Input feature dimension.
            dim_out: Output feature dimension.
            mult: Hidden-layer expansion factor (inner_dim = dim * mult).
        """
        super().__init__()
        inner_dim = int(dim * mult)
        project_in = GELUApprox(dim, inner_dim)

        # NOTE(review): the Identity at index 1 looks like a placeholder (e.g. for
        # dropout). Do not remove it — it keeps the Linear registered as "net.2",
        # which checkpoint state-dict keys likely depend on; confirm before changing.
        self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block to ``x`` of shape (..., dim), returning (..., dim_out)."""
        return self.net(x)
|
packages/ltx-core/src/ltx_core/model/upsampler/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Latent upsampler model components."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.upsampler.model import LatentUpsampler, upsample_video
|
| 4 |
+
from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"LatentUpsampler",
|
| 8 |
+
"LatentUpsamplerConfigurator",
|
| 9 |
+
"upsample_video",
|
| 10 |
+
]
|