Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +8 -34
- .gitignore +41 -0
- LICENSE +381 -0
- README.md +120 -0
- packages/ltx-core/README.md +409 -0
- packages/ltx-core/src/ltx_core/conditioning/types/__init__.py +13 -0
- packages/ltx-core/src/ltx_core/conditioning/types/attention_strength_wrapper.py +71 -0
- packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py +70 -0
- packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py +44 -0
- packages/ltx-core/src/ltx_core/conditioning/types/reference_video_cond.py +91 -0
- packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py +10 -0
- packages/ltx-core/src/ltx_core/model/common/__init__.py +9 -0
- packages/ltx-core/src/ltx_core/model/common/__pycache__/__init__.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/common/__pycache__/normalization.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/common/normalization.py +59 -0
- packages/ltx-core/src/ltx_core/model/transformer/__init__.py +18 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/adaln.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer_args.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/transformer/adaln.py +45 -0
- packages/ltx-core/src/ltx_core/model/transformer/attention.py +249 -0
- packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py +10 -0
- packages/ltx-core/src/ltx_core/model/transformer/modality.py +40 -0
- packages/ltx-core/src/ltx_core/model/transformer/model.py +486 -0
- packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py +152 -0
- packages/ltx-core/src/ltx_core/model/transformer/rope.py +204 -0
- packages/ltx-core/src/ltx_core/model/transformer/text_projection.py +38 -0
- packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py +143 -0
- packages/ltx-core/src/ltx_core/model/transformer/transformer.py +398 -0
- packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py +297 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__init__.py +24 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/__init__.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/convolution.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/enums.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/model_configurator.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/ops.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/resnet.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/sampling.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/tiling.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/video_vae.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/model/video_vae/convolution.py +317 -0
- packages/ltx-core/src/ltx_core/model/video_vae/model_configurator.py +79 -0
- packages/ltx-core/src/ltx_core/model/video_vae/resnet.py +277 -0
- packages/ltx-core/src/ltx_core/model/video_vae/tiling.py +291 -0
- packages/ltx-core/src/ltx_core/model/video_vae/video_vae.py +1219 -0
- packages/ltx-core/src/ltx_core/quantization/__pycache__/__init__.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/quantization/__pycache__/fp8_cast.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/quantization/__pycache__/fp8_scaled_mm.cpython-312.pyc +0 -0
- packages/ltx-core/src/ltx_core/quantization/__pycache__/policy.cpython-312.pyc +0 -0
.gitattributes
CHANGED

@@ -1,35 +1,9 @@
-*.
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-(several more removed "*." LFS patterns, truncated in this view)
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.sft filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED

@@ -0,0 +1,41 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
checkpoints/
*.egg-info

# Virtual environments
.venv
.python-version

# IDE settings
.idea/
.vscode/

# Other files
.DS_Store
tmp
.wandb

# Model checkpoints
*.ckpt
*.pt
*.safetensors
*.sft

# Media files
*.gif
*.heic
*.heif
*.jpg
*.jpeg
*.json
*.m4a
*.mov
*.mp4
*.png
*.wav
*.webp

LICENSE
ADDED

@@ -0,0 +1,381 @@
LTX-2 Community License Agreement
License date: January 5, 2026

By using or distributing any portion or element of LTX-2, you agree to be bound by this Agreement.

1. Definitions.

"Agreement" means the terms and conditions for the license, use, reproduction, and distribution of LTX-2 and the Complementary Materials, as specified in this document.

"Control" means the direct or indirect ownership of more than fifty percent (50%) of the voting securities or other ownership interests, or the power to direct the management and policies of such Entity through voting rights, contract, or otherwise.

"Data" means a collection of information and/or content extracted from the dataset used with LTX-2, including to train, pretrain, or otherwise evaluate LTX-2. The Data is not licensed under this Agreement.

"Derivatives of LTX-2" means all modifications to LTX-2, works based on LTX-2, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of LTX-2, to the other model, in order to cause the other model to perform similarly to LTX-2, including – but not limited to – distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by LTX-2 for training the other model. For clarity, Derivatives of LTX-2 include: (i) any fine-tuned or adapted weights, parameters, or checkpoints derived from LTX-2; (ii) derivative model architectures that incorporate or are based upon LTX-2's architecture; and (iii) any modified or extended versions of the Complementary Materials. All intellectual property rights in Derivatives of LTX-2 shall be subject to the terms of this Agreement, and you may not claim exclusive ownership rights in any Derivatives of LTX-2 that would restrict the rights granted herein.

"Entity" means any individual, corporation, partnership, limited liability company, or other legal entity. For purposes of this Agreement, an Entity shall be deemed to include, on an aggregative basis, all subsidiaries, affiliates, and other companies under common Control with such Entity. When determining whether an Entity meets any threshold under this Agreement (including revenue thresholds), all subsidiaries, affiliates, and companies under common Control shall be considered collectively.

"Harm" includes but is not limited to physical, mental, psychological, financial and reputational damage, pain, or loss.

"Licensor" or "Lightricks" means the owner that is granting the license under this Agreement. For the purposes of this Agreement, the Licensor is Lightricks Ltd.

"LTX-2" means the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code, accompanying source code, scripts, documentation, tutorials, examples, and all other elements of the foregoing distributed and made publicly available by Lightricks (including, for example, at https://github.com/Lightricks/LTX-2) for the LTX-2 model released on January 5, 2026. This license is applicable to all LTX-2 versions released since January 5, 2026, and all future releases of LTX-2 under this license.

"Output" means the results of operating LTX-2 as embodied in informational content resulting therefrom.

"you" (or "your") means an individual or legal Entity licensing LTX-2 in accordance with this Agreement and/or making use of LTX-2 for whichever purpose and in any field of use, including usage of LTX-2 in an end-use application - e.g. chatbot, translator, image generator.

2. Grant of License. Subject to the terms and conditions of this Agreement, you are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Licensor's intellectual property or other rights owned by Licensor embodied in LTX-2 to use, reproduce, prepare, distribute, publicly display, publicly perform, sublicense, copy, create derivative works of, and make modifications to LTX-2, for any purpose, subject to the restrictions set forth in Attachment A; provided however, that Entities with annual revenues of at least $10,000,000 (the "Commercial Entities") are required to obtain a paid commercial use license in order to use LTX-2 and Derivatives of LTX-2, subject to the terms and provisions of a different license (the "Commercial Use Agreement"), as will be provided by the Licensor. Commercial Entities interested in such a commercial license are required to [contact Licensor](https://ltx.io/model/licensing). Any commercial use of LTX-2 or Derivatives of LTX-2 by the Commercial Entities not in accordance with this Agreement and/or the Commercial Use Agreement is strictly prohibited and shall be deemed a material breach of this Agreement. Such material breach will be subject, in addition to any license fees owed to Licensor for the period such Commercial Entity used LTX-2 (as will be determined by Licensor), to liquidated damages, which will be paid to Licensor immediately upon demand, in an amount equal to double the amount that would otherwise have been paid by you for the relevant period of time. Such amount reflects a reasonable estimation of the losses and administrative costs incurred due to such breach. You agree and understand that this remedy does not limit the Licensor's right to pursue other remedies available at law or equity.

3. Distribution and Redistribution. You may host for third parties remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of LTX-2 or Derivatives of LTX-2 thereof in any medium, with or without modifications, provided that you meet the following conditions:

(a) Use-based restrictions as referenced in paragraph 4 and all provisions of Attachment A MUST be included as an enforceable provision by you in any type of legal agreement (e.g. a license) governing the use and/or distribution of LTX-2 or Derivatives of LTX-2, and you shall give notice to subsequent users you distribute to, that LTX-2 or Derivatives of LTX-2 are subject to paragraph 4 and Attachment A in their entirety, including all use restrictions and acceptable use policies;

(b) You must provide any third party recipients of LTX-2 or Derivatives of LTX-2 a copy of this Agreement, including all attachments and use policies. Any Derivative of LTX-2 (as defined in Section 1, including but not limited to fine-tuned weights, modified training code, models trained on Outputs, or any other derivative) must be distributed exclusively under the terms of this Agreement with a complete copy of this license included;

(c) You must cause any modified files to carry prominent notices stating that you changed the files;

(d) You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of LTX-2, Derivatives of LTX-2.

You may add your own copyright statement to your modifications and may provide additional or different license terms and conditions - respecting paragraph 3(a) - for use, reproduction, or distribution of your modifications, or for any such Derivatives of LTX-2 as a whole, provided your use, reproduction, and distribution of LTX-2 otherwise complies with the conditions stated in this Agreement, and you provide a complete copy of this Agreement with any such use, reproduction and distribution of LTX-2 and any Derivatives thereof.

4. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore, you cannot use LTX-2 and the Derivatives of LTX-2 in violation of the specified restricted uses. You may use LTX-2 subject to this Agreement, including only for lawful purposes and in accordance with the Agreement. "Use" may include creating any content with, fine-tuning, updating, running, training, evaluating and/or re-parametrizing LTX-2. You shall require all of your users who use LTX-2 or a Derivative of LTX-2 to comply with the terms of this paragraph 4.

5. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output you generate using LTX-2. You are accountable for input you insert into LTX-2, the Output you generate and its subsequent uses. No use of the Output can contravene any provision as stated in the Agreement.

6. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of LTX-2 in violation of this Agreement, update LTX-2 through electronic means, or modify the Output of LTX-2 based on updates. You shall undertake reasonable efforts to use the latest version of LTX-2. Any use of the non-current version of LTX-2 is done solely at your risk.

7. Export Controls and Sanctions Compliance. You acknowledge that LTX-2, Derivatives of LTX-2 may be subject to export control laws and regulations, including but not limited to the U.S. Export Administration Regulations and sanctions programs administered by the Office of Foreign Assets Control (OFAC). You represent and warrant that you and any users of LTX-2 are not (i) located in, organized under the laws of, or ordinarily resident in any country or territory subject to comprehensive sanctions; (ii) identified on any U.S. government restricted party list, including the Specially Designated Nationals and Blocked Persons List; or (iii) otherwise prohibited from receiving LTX-2 under applicable law. You shall not export, re-export, or transfer LTX-2, directly or indirectly, in violation of any applicable export control or sanctions laws or regulations. You agree to comply with all applicable trade control laws and shall indemnify and hold Licensor harmless from any claims arising from your failure to comply with such laws.

8. Trademarks and related. Nothing in this Agreement permits you to make use of Licensor's trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensor.

9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides LTX-2 on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing LTX-2 and Derivatives of LTX-2 and assume any risks associated with your exercise of permissions under this Agreement.

10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall Licensor be liable to you for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this Agreement or out of the use or inability to use LTX-2 (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if Licensor has been advised of the possibility of such damages.

11. Accepting Warranty or Additional Liability. While redistributing LTX-2 and Derivatives of LTX-2, you may, provided you do not violate the terms of this Agreement, choose to offer and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations. However, in accepting such obligations, you may act only on your own behalf and on your sole responsibility, not on behalf of Licensor, and only if you agree to indemnify, defend, and hold Licensor harmless for any liability incurred by, or claims asserted against Licensor, by reason of your accepting any such warranty or additional liability.

12. Governing Law. This Agreement and all relations, disputes, claims and other matters arising hereunder (including non-contractual disputes or claims) will be governed exclusively by, and construed exclusively in accordance with, the laws of the State of New York. To the extent permitted by law, choice of laws rules and the United Nations Convention on Contracts for the International Sale of Goods will not apply. For the purposes of adjudicating any action or proceeding to enforce the terms of this Agreement, you hereby irrevocably consent to the exclusive jurisdiction of, and venue in, the federal and state courts located in the County of New York within the State of New York. The prevailing party in any claim or dispute between the parties under this Agreement will be entitled to reimbursement of its reasonable attorneys' fees and costs. You hereby waive the right to a trial by jury, to participate in a class or representative action (including in arbitration), or to combine individual proceedings in court or in arbitration without the consent of all parties.

13. Term and Termination. This Agreement is effective upon your acceptance and continues until terminated. Licensor may terminate this Agreement immediately upon written notice to you if you breach any provision of this Agreement, including but not limited to violations of the use restrictions in Attachment A or unauthorized commercial use. Upon termination: (a) all rights granted to you under this Agreement will immediately cease; (b) you must immediately cease all use of LTX-2 and Derivatives of LTX-2; (c) you must delete or destroy all copies of LTX-2 and Derivatives of LTX-2 in your possession or control; and (d) you must notify any third parties to whom you distributed LTX-2 or Derivatives of LTX-2 of the termination. Sections 8-13, and Section 15 shall survive termination of this Agreement. Termination does not relieve you of any obligations incurred prior to termination, including payment obligations under Section 2. In addition, if You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Licensor or any person or entity alleging that LTX-2 or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licenses granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed.

14. Disputes and Arbitration. All disputes arising in connection with this Agreement shall be finally settled by arbitration under the Rules of Arbitration of the International Chamber of Commerce ("ICC Rules"), by one (1) arbitrator appointed in accordance with the ICC Rules. The seat of arbitration shall be New York, NY, USA, and the proceedings shall be conducted in English. The arbitrator shall be empowered to grant any relief that a court could grant. Judgment on the arbitration award may be entered by any court having jurisdiction thereof. Each party waives its right to a trial by jury and to participate in any class or representative action.

15. If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.

END OF TERMS AND CONDITIONS

ATTACHMENT A: Use Restrictions

When using the Outputs, LTX-2 and any Derivatives thereof, you will comply with the Acceptable Use Policy. In addition, you agree not to use the Outputs, LTX-2 or its Derivatives in any of the following ways:

1. In any way that violates any applicable national, federal, state, local or international law or regulation;

2. For the purpose of exploiting, Harming or attempting to exploit or Harm minors in any way;

3. To generate or disseminate false information and/or content with the purpose of Harming others;

4. To generate or disseminate personal identifiable information that can be used to Harm an individual;

5. To generate or disseminate information and/or content (e.g. images, code, posts, articles), and place the information and/or content in any context (e.g. bot generating tweets) without expressly and intelligibly disclaiming that the information and/or content is machine generated;

6. To defame, disparage or otherwise harass others;

7. To impersonate or attempt to impersonate (e.g. deepfakes) others without their consent;

8. For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation;

9. For any use intended to or which has the effect of discriminating against or Harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;

10. To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological Harm;

11. For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;

12. To provide medical advice and medical results interpretation;

13. To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use);

14. To generate and/or disseminate malware (including – but not limited to – ransomware) or any other content to be used for the purpose of harming electronic systems;

15. To engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, or other essential goods and services;

16. To engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals;

17. For military, warfare, nuclear industries or applications, weapons development, or any use in connection with activities that may cause death, personal injury, or severe physical or environmental damage;

18. For commercial use only: To train, improve, or fine-tune any other machine learning model, artificial intelligence system, or competing model, except for Derivatives of LTX-2 as expressly permitted under this Agreement;

19. To circumvent, disable, or interfere with any technical limitations, safety features, content filters, or use restrictions implemented in LTX-2 by Licensor;

20. To use LTX-2 or Derivatives of LTX-2 in any product, service, or application that directly competes with Licensor's commercial products or services, or is designed to replace or substitute Licensor's offerings in the market, without obtaining a separate commercial license from Licensor.

README.md
ADDED

@@ -0,0 +1,120 @@
# LTX-2

[Website](https://ltx.io) · [Model on Hugging Face](https://huggingface.co/Lightricks/LTX-2.3) · [Playground](https://app.ltx.studio/ltx-2-playground/i2v) · [Paper](https://arxiv.org/abs/2601.03233) · [Discord](https://discord.gg/ltxplatform)

**LTX-2** is the first DiT-based audio-video foundation model that contains all core capabilities of modern video generation in one model: synchronized audio and video, high fidelity, multiple performance modes, production-ready outputs, API access, and open access.

<div align="center">
<video src="https://github.com/user-attachments/assets/4414adc0-086c-43de-b367-9362eeb20228" width="70%" poster=""> </video>
</div>

## 🚀 Quick Start

```bash
# Clone the repository
git clone https://github.com/Lightricks/LTX-2.git
cd LTX-2

# Set up the environment
uv sync --frozen
source .venv/bin/activate
```

### Required Models

Download the following models from the [LTX-2.3 HuggingFace repository](https://huggingface.co/Lightricks/LTX-2.3); a scripted download sketch follows the list.

**LTX-2.3 Model Checkpoint** (choose and download one of the following)
* [`ltx-2.3-22b-dev.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-22b-dev.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-22b-dev.safetensors)
* [`ltx-2.3-22b-distilled.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-22b-distilled.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-22b-distilled.safetensors)

**Spatial Upscaler** - Required for the current two-stage pipeline implementations in this repository
* [`ltx-2.3-spatial-upscaler-x2-1.0.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-spatial-upscaler-x2-1.0.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-spatial-upscaler-x2-1.0.safetensors)
* [`ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors)

**Temporal Upscaler** - Supported by the model and will be required for future pipeline implementations
* [`ltx-2.3-temporal-upscaler-x2-1.0.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-temporal-upscaler-x2-1.0.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-temporal-upscaler-x2-1.0.safetensors)

**Distilled LoRA** - Required for the current two-stage pipeline implementations in this repository (except DistilledPipeline and ICLoraPipeline)
* [`ltx-2.3-22b-distilled-lora-384.safetensors`](https://huggingface.co/Lightricks/LTX-2.3/blob/main/ltx-2.3-22b-distilled-lora-384.safetensors) - [Download](https://huggingface.co/Lightricks/LTX-2.3/resolve/main/ltx-2.3-22b-distilled-lora-384.safetensors)

**Gemma Text Encoder** (download all assets from the repository)
* [`Gemma 3`](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized/tree/main)

**LoRAs**
* [`LTX-2.3-22b-IC-LoRA-Union-Control`](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control) - [Download](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control/resolve/main/ltx-2.3-22b-ic-lora-union-control-ref0.5.safetensors)
* [`LTX-2.3-22b-IC-LoRA-Inpainting`](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Inpainting) - [Download](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Inpainting/resolve/main/ltx-2.3-22b-ic-lora-inpainting.safetensors)
* [`LTX-2.3-22b-IC-LoRA-Motion-Track-Control`](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control) - [Download](https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control/resolve/main/ltx-2.3-22b-ic-lora-motion-track-control-ref0.5.safetensors)
* [`LTX-2-19b-IC-LoRA-Detailer`](https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Detailer) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Detailer/resolve/main/ltx-2-19b-ic-lora-detailer.safetensors)
* [`LTX-2-19b-IC-LoRA-Pose-Control`](https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Pose-Control) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Pose-Control/resolve/main/ltx-2-19b-ic-lora-pose-control.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Dolly-In`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In/resolve/main/ltx-2-19b-lora-camera-control-dolly-in.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Dolly-Left`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Left/resolve/main/ltx-2-19b-lora-camera-control-dolly-left.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Dolly-Out`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out/resolve/main/ltx-2-19b-lora-camera-control-dolly-out.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Dolly-Right`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Right/resolve/main/ltx-2-19b-lora-camera-control-dolly-right.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Jib-Down`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Down/resolve/main/ltx-2-19b-lora-camera-control-jib-down.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Jib-Up`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Jib-Up/resolve/main/ltx-2-19b-lora-camera-control-jib-up.safetensors)
* [`LTX-2-19b-LoRA-Camera-Control-Static`](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static) - [Download](https://huggingface.co/Lightricks/LTX-2-19b-LoRA-Camera-Control-Static/resolve/main/ltx-2-19b-lora-camera-control-static.safetensors)
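
If you prefer a scripted download, the sketch below uses `huggingface_hub` to fetch a typical set of weights. The repository ids and filenames come from the links above; the `checkpoints/` target directory and the particular files chosen are assumptions for illustration, not requirements of the pipelines.

```python
# Sketch: download a typical set of LTX-2.3 weights with huggingface_hub.
# Filenames match the list above; the local "checkpoints" directory is an arbitrary choice.
from huggingface_hub import hf_hub_download, snapshot_download

ltx_files = [
    "ltx-2.3-22b-dev.safetensors",                    # or ltx-2.3-22b-distilled.safetensors
    "ltx-2.3-spatial-upscaler-x2-1.0.safetensors",    # spatial upscaler for two-stage pipelines
    "ltx-2.3-22b-distilled-lora-384.safetensors",     # distilled LoRA for two-stage pipelines
]

for filename in ltx_files:
    path = hf_hub_download(
        repo_id="Lightricks/LTX-2.3",
        filename=filename,
        local_dir="checkpoints",
    )
    print(f"fetched {filename} -> {path}")

# The Gemma text encoder needs all assets from its repository.
snapshot_download(
    repo_id="google/gemma-3-12b-it-qat-q4_0-unquantized",
    local_dir="checkpoints/gemma-3-12b-it",
)
```

Depending on the repositories' access settings (the Gemma repository is gated), you may need to authenticate first, e.g. with `huggingface_hub.login()`.
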
### Available Pipelines

* **[TI2VidTwoStagesPipeline](packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py)** - Production-quality text/image-to-video with 2x upsampling (recommended)
* **[TI2VidTwoStagesHQPipeline](packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages_hq.py)** - Same two-stage flow as above but uses the res_2s second-order sampler (fewer steps, better quality)
* **[TI2VidOneStagePipeline](packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py)** - Single-stage generation for quick prototyping
* **[DistilledPipeline](packages/ltx-pipelines/src/ltx_pipelines/distilled.py)** - Fastest inference with 8 predefined sigmas
* **[ICLoraPipeline](packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py)** - Video-to-video and image-to-video transformations (uses the distilled model)
* **[KeyframeInterpolationPipeline](packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py)** - Interpolate between keyframe images
* **[A2VidPipelineTwoStage](packages/ltx-pipelines/src/ltx_pipelines/a2vid_two_stage.py)** - Audio-to-video generation conditioned on an input audio file
* **[RetakePipeline](packages/ltx-pipelines/src/ltx_pipelines/retake.py)** - Regenerate a specific time region of an existing video

### ⚡ Optimization Tips

* **Use DistilledPipeline** - Fastest inference with only 8 predefined sigmas (8 steps in stage 1, 4 steps in stage 2)
* **Enable FP8 quantization** - Lowers the memory footprint: `--quantization fp8-cast` (CLI) or `quantization=QuantizationPolicy.fp8_cast()` (Python). For Hopper GPUs with TensorRT-LLM, use `--quantization fp8-scaled-mm` for FP8 scaled matrix multiplication (see the snippet after this list)
* **Install attention optimizations** - Use xFormers (`uv sync --extra xformers`) or [Flash Attention 3](https://github.com/Dao-AILab/flash-attention) for Hopper GPUs
* **Use gradient estimation** - Reduce inference steps from 40 to 20-30 while maintaining quality (see [pipeline documentation](packages/ltx-pipelines/README.md#denoising-loop-optimization))
* **Skip memory cleanup** - If you have sufficient VRAM, disable automatic memory cleanup between stages for faster processing
* **Choose single-stage pipeline** - Use `TI2VidOneStagePipeline` for faster generation when high resolution isn't required
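
The two quantization settings mentioned above correspond to `QuantizationPolicy` constructors in `ltx_core`; how the resulting policy is wired into a model builder is covered in the [LTX-Core README](packages/ltx-core/README.md):

```python
from ltx_core.quantization import QuantizationPolicy

# FP8 weight casting: simplest option, lower memory footprint
policy = QuantizationPolicy.fp8_cast()

# FP8 scaled matrix multiplication (Hopper GPUs, TensorRT-LLM backend)
policy = QuantizationPolicy.fp8_scaled_mm()
```
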
## ✍️ Prompting for LTX-2

When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details, all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words. For best results, build your prompts using this structure (an example follows the list):

- Start with the main action in a single sentence
- Add specific details about movements and gestures
- Describe character/object appearances precisely
- Include background and environment details
- Specify camera angles and movements
- Describe lighting and colors
- Note any changes or sudden events
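
For example, a prompt built with this structure might read:

```text
A man in a gray wool coat steps out of a bakery onto a narrow cobblestone street at dawn.
He tucks a paper bag under his left arm, pauses, and glances up as a flock of pigeons
bursts from a rooftop. He is in his sixties, with cropped white hair and wire-rimmed
glasses. Behind him, the bakery window glows warm yellow and steam drifts from a vent
above the door. The camera starts at street level and slowly cranes up while keeping him
centered. Soft blue morning light fills the street, with warm highlights on his face from
the window. As the pigeons scatter, he flinches and the bag slips from under his arm.
```
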
For additional guidance on writing a prompt, please refer to <https://ltx.video/blog/how-to-prompt-for-ltx-2>

### Automatic Prompt Enhancement

LTX-2 pipelines support automatic prompt enhancement via an `enhance_prompt` parameter.

## 🔌 ComfyUI Integration

To use our model with ComfyUI, please follow the instructions at <https://github.com/Lightricks/ComfyUI-LTXVideo/>.

## 📦 Packages

This repository is organized as a monorepo with three main packages:

* **[ltx-core](packages/ltx-core/)** - Core model implementation, inference stack, and utilities
* **[ltx-pipelines](packages/ltx-pipelines/)** - High-level pipeline implementations for text-to-video, image-to-video, and other generation modes
* **[ltx-trainer](packages/ltx-trainer/)** - Training and fine-tuning tools for LoRA, full fine-tuning, and IC-LoRA

Each package has its own README and documentation. See the [Documentation](#-documentation) section below.

## 📚 Documentation

Each package includes comprehensive documentation:

* **[LTX-Core README](packages/ltx-core/README.md)** - Core model implementation, inference stack, and utilities
* **[LTX-Pipelines README](packages/ltx-pipelines/README.md)** - High-level pipeline implementations and usage guides
* **[LTX-Trainer README](packages/ltx-trainer/README.md)** - Training and fine-tuning documentation with detailed guides

packages/ltx-core/README.md
ADDED

@@ -0,0 +1,409 @@
# LTX-Core

The foundational library for the LTX-2 Audio-Video generation model. This package contains the raw model definitions, component implementations, and loading logic used by `ltx-pipelines` and `ltx-trainer`.

## 📦 What's Inside?

- **`components/`**: Modular diffusion components (Schedulers, Guiders, Noisers, Patchifiers) following standard protocols
- **`conditioning/`**: Tools for preparing latent states and applying conditioning (image, video, keyframes)
- **`guidance/`**: Perturbation system for fine-grained control over attention mechanisms
- **`loader/`**: Utilities for loading weights from `.safetensors`, fusing LoRAs, and managing memory
- **`model/`**: PyTorch implementations of the LTX-2 Transformer, Video VAE, Audio VAE, Vocoder and Upscaler
- **`text_encoders/gemma`**: Gemma text encoder implementation with tokenizers, feature extractors, and separate encoders for audio-video and video-only generation
- **`quantization/`**: FP8 quantization backends (FP8-TensorRT-LLM scaled MM, FP8 cast) for reduced memory footprint

## 🚀 Quick Start

`ltx-core` provides the building blocks (models, components, and utilities) needed to construct inference flows. For ready-made inference pipelines, use [`ltx-pipelines`](../ltx-pipelines/); for training, use [`ltx-trainer`](../ltx-trainer/).

## 🔧 Installation

```bash
# From the repository root
uv sync --frozen

# Or install as a package
pip install -e packages/ltx-core
```

## Building Blocks Overview

`ltx-core` provides modular components that can be combined to build custom inference flows:

### Core Models

- **Transformer** ([`model/transformer/`](src/ltx_core/model/transformer/)): The asymmetric dual-stream LTX-2 transformer (14B-parameter video stream, 5B-parameter audio stream) with bidirectional cross-modal attention for joint audio-video processing. Expects inputs in [`Modality`](src/ltx_core/model/transformer/modality.py) format
- **Video VAE** ([`model/video_vae/`](src/ltx_core/model/video_vae/)): Encodes/decodes video pixels to/from latent space with temporal and spatial compression
- **Audio VAE** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Encodes/decodes audio spectrograms to/from latent space
- **Vocoder** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Neural vocoder that converts mel spectrograms to audio waveforms
- **Text Encoder** ([`text_encoders/`](src/ltx_core/text_encoders/)): Gemma 3-based multilingual encoder with multi-layer feature extraction and thinking tokens that produces separate embeddings for video and audio conditioning
- **Spatial Upscaler** ([`model/upsampler/`](src/ltx_core/model/upsampler/)): Upsamples latent representations for higher-resolution generation

### Diffusion Components

- **Schedulers** ([`components/schedulers.py`](src/ltx_core/components/schedulers.py)): Noise schedules (LTX2Scheduler, LinearQuadratic, Beta) that control the denoising process
- **Guiders** ([`components/guiders.py`](src/ltx_core/components/guiders.py)): Guidance strategies (CFG, STG, APG) for controlling generation quality and adherence to prompts
- **Noisers** ([`components/noisers.py`](src/ltx_core/components/noisers.py)): Add noise to latents according to the diffusion schedule
- **Patchifiers** ([`components/patchifiers.py`](src/ltx_core/components/patchifiers.py)): Convert between spatial latents `[B, C, F, H, W]` and sequence format `[B, seq_len, dim]` for transformer processing (see the sketch after this list)
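
The patchifier is essentially shape bookkeeping. The sketch below shows a generic version of that conversion in plain PyTorch; the patch size (1×2×2) and the use of `reshape`/`permute` are illustrative assumptions, not the implementation in `components/patchifiers.py`.

```python
import torch

def patchify(latents: torch.Tensor, pf: int = 1, ph: int = 2, pw: int = 2) -> torch.Tensor:
    """Illustrative only: [B, C, F, H, W] -> [B, seq_len, dim], dim = C * pf * ph * pw."""
    b, c, f, h, w = latents.shape
    x = latents.reshape(b, c, f // pf, pf, h // ph, ph, w // pw, pw)
    x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)  # [B, F', H', W', C, pf, ph, pw]
    return x.reshape(b, (f // pf) * (h // ph) * (w // pw), c * pf * ph * pw)

def unpatchify(tokens: torch.Tensor, c: int, f: int, h: int, w: int,
               pf: int = 1, ph: int = 2, pw: int = 2) -> torch.Tensor:
    """Inverse mapping: [B, seq_len, dim] -> [B, C, F, H, W]."""
    b = tokens.shape[0]
    x = tokens.reshape(b, f // pf, h // ph, w // pw, c, pf, ph, pw)
    x = x.permute(0, 4, 1, 5, 2, 6, 3, 7)  # [B, C, F', pf, H', ph, W', pw]
    return x.reshape(b, c, f, h, w)

latents = torch.randn(1, 16, 8, 32, 32)   # [B, C, F, H, W]
tokens = patchify(latents)                 # [1, 8 * 16 * 16, 64]
assert torch.equal(unpatchify(tokens, 16, 8, 32, 32), latents)
```
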
| 48 |
+
|
| 49 |
+
### Conditioning & Control
|
| 50 |
+
|
| 51 |
+
- **Conditioning** ([`conditioning/`](src/ltx_core/conditioning/)): Tools for preparing and applying various conditioning types (image, video, keyframes)
|
| 52 |
+
- **Guidance** ([`guidance/`](src/ltx_core/guidance/)): Perturbation system for fine-grained control over attention mechanisms (e.g., skipping specific attention layers)
|
| 53 |
+
|
| 54 |
+
### Utilities
|
| 55 |
+
|
| 56 |
+
- **Loader** ([`loader/`](src/ltx_core/loader/)): Model loading from `.safetensors`, LoRA fusion, weight remapping, and memory management
|
| 57 |
+
- **Quantization** ([`quantization/`](src/ltx_core/quantization/)): FP8 quantization backends for reduced memory footprint and faster inference
|
| 58 |
+
|
| 59 |
+
### Loader
|
| 60 |
+
|
| 61 |
+
The `loader/` module provides `SingleGPUModelBuilder`, a frozen dataclass that loads a PyTorch model from `.safetensors` checkpoints and optionally fuses one or more LoRA adapters.
|
| 62 |
+
|
| 63 |
+
#### Basic usage
|
| 64 |
+
|
| 65 |
+
```python
|
| 66 |
+
from ltx_core.loader import SingleGPUModelBuilder
|
| 67 |
+
|
| 68 |
+
builder = SingleGPUModelBuilder(
    model_class_configurator=MyModelConfigurator,
    model_path="/path/to/model.safetensors",
)
model = builder.build(device=torch.device("cuda"))
```

#### Loading LoRA adapters

Use the `.lora()` method to attach one or more LoRA adapters before calling `.build()`:

```python
builder = (
    SingleGPUModelBuilder(
        model_class_configurator=MyModelConfigurator,
        model_path="/path/to/model.safetensors",
    )
    .lora("/path/to/lora_a.safetensors", strength=0.8)
    .lora("/path/to/lora_b.safetensors", strength=0.5)
)
model = builder.build(device=torch.device("cuda"))
```

#### Memory-efficient LoRA loading (`lora_load_device`)

By default, LoRA weights are loaded onto the **CPU** (`lora_load_device=torch.device("cpu")`). This means each LoRA adapter is kept in CPU memory and transferred to the GPU sequentially during weight fusion, which keeps peak GPU memory low even when fusing large adapters.

If all adapters fit comfortably in GPU memory you can skip the CPU staging by setting `lora_load_device` to the target CUDA device:

```python
import torch

from ltx_core.loader import SingleGPUModelBuilder

# Load LoRA weights directly onto the GPU (faster, but uses more GPU memory)
builder = SingleGPUModelBuilder(
    model_class_configurator=MyModelConfigurator,
    model_path="/path/to/model.safetensors",
    lora_load_device=torch.device("cuda"),
).lora("/path/to/lora.safetensors", strength=1.0)

model = builder.build(device=torch.device("cuda"))
```

### Quantization

The `quantization/` module provides FP8 quantization support for the LTX-2 transformer, significantly reducing memory usage while maintaining quality. Two backends are available:

#### FP8 Scaled MM (TensorRT-LLM)

Uses NVIDIA TensorRT-LLM's `cublas_scaled_mm` for efficient FP8 matrix multiplication. Weights are stored in FP8 format with per-tensor scaling, and inputs are quantized dynamically (or statically with calibration data).

**Requirements**: `uv sync --frozen --extra fp8-trtllm`

**Usage with QuantizationPolicy:**

```python
from ltx_core.quantization import QuantizationPolicy

# Dynamic input quantization (no calibration needed)
policy = QuantizationPolicy.fp8_scaled_mm()

# Static input quantization with calibration file
policy = QuantizationPolicy.fp8_scaled_mm(calibration_amax_path="/path/to/amax.json")
```

The policy provides `sd_ops` and `module_ops` that can be passed to the model builder:

```python
from ltx_core.loader import SingleGPUModelBuilder

builder = SingleGPUModelBuilder(
    model=model,
    device=device,
    sd_ops=policy.sd_ops,
    module_ops=policy.module_ops,
)
builder.load(checkpoint_path)
```

**Calibration File Format** (for static input quantization):

```json
{
  "amax_values": {
    "transformer_blocks.0.attn.to_q.input_quantizer": 12.5,
    "transformer_blocks.0.attn.to_k.input_quantizer": 8.3,
    ...
  }
}
```

#### FP8 Cast

A simpler approach that casts weights to FP8 for storage and upcasts during inference:

```python
policy = QuantizationPolicy.fp8_cast()
```
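
To make the storage/compute trade-off concrete, here is a minimal, standalone sketch of the cast-and-upcast idea in plain PyTorch (an illustration only, not the `ltx-core` implementation; it assumes a PyTorch build that ships the `torch.float8_e4m3fn` dtype):

```python
import torch

# Store a weight in FP8 (e4m3) and upcast just-in-time for the matmul.
weight_bf16 = torch.randn(4096, 4096, dtype=torch.bfloat16)
weight_fp8 = weight_bf16.to(torch.float8_e4m3fn)  # 1 byte/element instead of 2

x = torch.randn(1, 4096, dtype=torch.bfloat16)
y = x @ weight_fp8.to(torch.bfloat16).T  # upcast only at compute time

print(weight_fp8.element_size(), weight_bf16.element_size())  # 1 vs. 2 bytes per element
```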

For complete, production-ready pipeline implementations that combine these building blocks, see the [`ltx-pipelines`](../ltx-pipelines/) package.

---

# Architecture Overview

This section provides a deep dive into the internal architecture of the LTX-2 Audio-Video generation model.

## Table of Contents

1. [High-Level Architecture](#high-level-architecture)
2. [The Transformer](#the-transformer)
3. [Video VAE](#video-vae)
4. [Audio VAE](#audio-vae)
5. [Text Encoding (Gemma)](#text-encoding-gemma)
6. [Spatial Upsampler](#spatial-upsampler)
7. [Data Flow](#data-flow)

---

## High-Level Architecture

LTX-2 is an **asymmetric dual-stream diffusion transformer** that jointly models the text-conditioned distribution of video and audio signals, capturing true joint dependencies (unlike sequential T2V→V2A pipelines).

### Key Design Principles

- **Decoupled Latent Representations**: Separate modality-specific VAEs enable 3D RoPE (video) vs 1D RoPE (audio), independent compression optimization, and native V2A/A2V editing workflows
- **Asymmetric Dual-Stream**: 14B-parameter video stream (spatiotemporal dynamics) + 5B-parameter audio stream (1D temporal), sharing 48 transformer blocks but differing in width
- **Bidirectional Cross-Modal Attention**: 1D temporal RoPE enables sub-frame alignment, mapping visual cues to auditory events (lip-sync, foley, environmental acoustics)
- **Cross-Modality AdaLN**: Scaling/shift parameters conditioned on the other modality's hidden states for synchronization across differing diffusion timesteps/temporal resolutions

```text
┌─────────────────────────────────────────────────────────────┐
│                      INPUT PREPARATION                       │
│                                                               │
│   Video Pixels    →  Video VAE Encoder  →  Video Latents     │
│   Audio Waveform  →  Audio VAE Encoder  →  Audio Latents     │
│   Text Prompt     →  Gemma 3 Encoder    →  Text Embeddings   │
└─────────────────────────────────────────────────────────────┘
                               ↓
┌─────────────────────────────────────────────────────────────┐
│     LTX-2 ASYMMETRIC DUAL-STREAM TRANSFORMER (48 Blocks)      │
│                                                               │
│  ┌──────────────────────┐        ┌──────────────────────┐    │
│  │  Video Stream (14B)  │        │  Audio Stream (5B)   │    │
│  │                      │        │                      │    │
│  │  3D RoPE (x,y,t)     │        │  1D RoPE (temporal)  │    │
│  │                      │        │                      │    │
│  │  Self-Attn           │        │  Self-Attn           │    │
│  │  Text Cross-Attn     │        │  Text Cross-Attn     │    │
│  │                      │◄──────►│                      │    │
│  │  A↔V Cross-Attn      │        │  A↔V Cross-Attn      │    │
│  │  (1D temporal RoPE)  │        │  (1D temporal RoPE)  │    │
│  │  Cross-modality      │        │  Cross-modality      │    │
│  │  AdaLN               │        │  AdaLN               │    │
│  │  Feed-Forward        │        │  Feed-Forward        │    │
│  └──────────────────────┘        └──────────────────────┘    │
└─────────────────────────────────────────────────────────────┘
                               ↓
┌─────────────────────────────────────────────────────────────┐
│                       OUTPUT DECODING                        │
│                                                               │
│   Video Latents   →  Video VAE Decoder  →  Video Pixels      │
│   Audio Latents   →  Audio VAE Decoder  →  Mel Spectrogram   │
│   Mel Spectrogram →  Vocoder  →  Audio Waveform (24 kHz)     │
└─────────────────────────────────────────────────────────────┘
```

---

## The Transformer

The core of LTX-2 is an **asymmetric dual-stream diffusion transformer** with 48 layers that processes both video and audio tokens simultaneously. The architecture allocates 14B parameters to the video stream and 5B parameters to the audio stream, reflecting the different information densities of the two modalities.

### Model Structure

**Source**: [`src/ltx_core/model/transformer/model.py`](src/ltx_core/model/transformer/model.py)

The `LTXModel` class implements the transformer. It supports both video-only and audio-video generation modes. For actual usage, see the [`ltx-pipelines`](../ltx-pipelines/) package, which handles model loading and initialization.

### Transformer Block Architecture

**Source**: [`src/ltx_core/model/transformer/transformer.py`](src/ltx_core/model/transformer/transformer.py)

Each dual-stream block performs four operations sequentially:

1. **Self-Attention**: Within-modality attention for each stream
2. **Text Cross-Attention**: Textual prompt conditioning for both streams
3. **Audio-Visual Cross-Attention**: Bidirectional inter-modal exchange
4. **Feed-Forward Network (FFN)**: Feature refinement

```text
┌─────────────────────────────────────────────────────────────┐
│                      TRANSFORMER BLOCK                       │
│                                                               │
│  VIDEO (14B): Input → RMSNorm → AdaLN → Self-Attn →          │
│               RMSNorm → Text Cross-Attn →                    │
│               RMSNorm → AdaLN → A↔V Cross-Attn (1D RoPE) →   │
│               RMSNorm → AdaLN → FFN → Output                 │
│                                                               │
│  AUDIO (5B):  Input → RMSNorm → AdaLN → Self-Attn →          │
│               RMSNorm → Text Cross-Attn →                    │
│               RMSNorm → AdaLN → A↔V Cross-Attn (1D RoPE) →   │
│               RMSNorm → AdaLN → FFN → Output                 │
│                                                               │
│  RoPE: Video=3D (x,y,t), Audio=1D (t), Cross-Attn=1D (t)     │
│  AdaLN: Timestep-conditioned, cross-modality for A↔V CA      │
└─────────────────────────────────────────────────────────────┘
```

### Audio-Visual Cross-Attention Details

Bidirectional cross-attention enables tight temporal alignment: the video and audio streams exchange information in both directions using 1D temporal RoPE (synchronization only, no spatial alignment). AdaLN gates condition on each modality's timestep for cross-modal synchronization.

### Perturbations

The transformer supports [**perturbations**](src/ltx_core/guidance/perturbations.py) that selectively skip attention operations.

Perturbations allow you to disable specific attention mechanisms during inference, which is useful for guidance techniques like STG (Spatio-Temporal Guidance).

**Supported Perturbation Types**:

- `SKIP_VIDEO_SELF_ATTN`: Skip video self-attention
- `SKIP_AUDIO_SELF_ATTN`: Skip audio self-attention
- `SKIP_A2V_CROSS_ATTN`: Skip audio-to-video cross-attention
- `SKIP_V2A_CROSS_ATTN`: Skip video-to-audio cross-attention

Perturbations are used internally by guidance mechanisms such as STG. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
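
Conceptually, such a guidance technique runs the model twice (once normally and once with a perturbation applied) and pushes the prediction away from the degraded result. The snippet below is a toy sketch of that combination step only; `stg_scale` and the tensors are placeholders, not the `ltx-core` API:

```python
import torch

stg_scale = 1.0
pred_cond = torch.randn(1, 1024, 128)       # prediction from the normal forward pass
pred_perturbed = torch.randn(1, 1024, 128)  # prediction with e.g. SKIP_VIDEO_SELF_ATTN applied

# STG-style combination: move away from the perturbed (degraded) prediction.
pred_guided = pred_cond + stg_scale * (pred_cond - pred_perturbed)
```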

---

## Video VAE

The Video VAE ([`src/ltx_core/model/video_vae/`](src/ltx_core/model/video_vae/)) encodes video pixels into latent representations and decodes them back.

### Architecture

- **Encoder**: Compresses `[B, 3, F, H, W]` pixels → `[B, 128, F', H/32, W/32]` latents
  - Where `F' = 1 + (F-1)/8` (frame count must satisfy `(F-1) % 8 == 0`)
  - Example: `[B, 3, 33, 512, 512]` → `[B, 128, 5, 16, 16]`
- **Decoder**: Expands `[B, 128, F, H, W]` latents → `[B, 3, F', H*32, W*32]` pixels
  - Where `F' = 1 + (F-1)*8`
  - Example: `[B, 128, 5, 16, 16]` → `[B, 3, 33, 512, 512]`
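
The compression arithmetic above can be sanity-checked with a few lines of Python (a standalone sketch, not part of `ltx-core`):

```python
# Pixel shape [B, 3, F, H, W] -> latent shape [B, 128, 1 + (F-1)/8, H/32, W/32]
def video_latent_shape(batch: int, frames: int, height: int, width: int) -> tuple[int, ...]:
    assert (frames - 1) % 8 == 0, "frame count must satisfy (F - 1) % 8 == 0"
    return (batch, 128, 1 + (frames - 1) // 8, height // 32, width // 32)

print(video_latent_shape(1, 33, 512, 512))  # (1, 128, 5, 16, 16)
```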

The Video VAE is used internally by pipelines for encoding video pixels to latents and decoding latents back to pixels. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.

---

## Audio VAE

The Audio VAE ([`src/ltx_core/model/audio_vae/`](src/ltx_core/model/audio_vae/)) processes audio spectrograms.

### Audio VAE Architecture

Compact neural audio representation optimized for diffusion-based training. Natively supports stereo: processes two-channel mel-spectrograms (16 kHz input) with channel concatenation before encoding.

- **Encoder**: `[B, mel_bins, T]` → `[B, 8, T/4, 16]` latents (4× temporal downsampling, 8 channels, 16 mel bins in latent space, ~1/25s per token, 128-dim feature vector)
- **Decoder**: `[B, 8, T, 16]` → `[B, mel_bins, T*4]` mel spectrogram
- **Vocoder**: HiFi-GAN-based, modified for stereo synthesis and upsampling (16 kHz mel → 24 kHz waveform, doubled generator capacity for stereo)

**Downsampling**:

- Temporal: 4× (time steps)
- Frequency: Variable (input mel_bins → fixed 16 in latent space)
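
As a quick arithmetic check (illustrative only; the 100 mel-frames-per-second rate is an assumption consistent with the ~1/25 s-per-token figure above):

```python
mel_frames_per_second = 100                           # assumed mel hop rate for illustration
latent_steps_per_second = mel_frames_per_second / 4   # 4x temporal downsampling
print(latent_steps_per_second)                        # 25.0 -> ~1/25 s of audio per latent step
```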

The Audio VAE is used internally by pipelines for encoding mel spectrograms to latents and decoding latents back to mel spectrograms. The vocoder converts mel spectrograms to audio waveforms. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.

---

## Text Encoding (Gemma)

LTX-2 uses **Gemma 3** (Gemma 3-12B) as the multilingual text encoder backbone, located in [`src/ltx_core/text_encoders/gemma/`](src/ltx_core/text_encoders/gemma/). Advanced text understanding is critical not only for global language support but for the phonetic and semantic accuracy of generated speech.

### Text Encoder Architecture

The text conditioning pipeline consists of three stages:

1. **Gemma 3 Backbone**: Decoder-only LLM processes text tokens → embeddings across all layers `[B, T, D, L]`
2. **Multi-Layer Feature Extractor**: Aggregates features from all decoder layers (not just final layer), applies mean-centered scaling, flattens to `[B, T, D×L]`, and projects via learnable matrix W (jointly optimized with LTX-2, LLM weights frozen)
3. **Text Connector**: Bidirectional transformer blocks with learnable registers (replacing padded positions, also referred to as "thinking tokens" in the paper) for contextual mixing. Separate connectors for video and audio streams (`Embeddings1DConnector`)
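
The multi-layer feature extraction in stage 2 can be pictured at the shape level with dummy tensors (the dimensions and the `nn.Linear` below are illustrative stand-ins for the learned projection, not the real encoder):

```python
import torch

B, T, D, L = 1, 8, 64, 4                                    # toy sizes for illustration
per_layer = [torch.randn(B, T, D) for _ in range(L)]        # hidden states from every decoder layer

features = torch.stack(per_layer, dim=-1)                   # [B, T, D, L]
features = features - features.mean(dim=(1, 2, 3), keepdim=True)  # simplified mean-centering
features = features.flatten(start_dim=2)                    # [B, T, D*L]

proj = torch.nn.Linear(D * L, 128)                          # stands in for the learnable matrix W
context = proj(features)                                     # [B, T, 128] text context tokens
```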

**Encoders**:

- `AVGemmaTextEncoderModel`: Audio-video generation (two connectors → `AVGemmaEncoderOutput` with separate video/audio contexts)
- `VideoGemmaTextEncoderModel`: Video-only generation (single connector → `VideoGemmaEncoderOutput`)

### System Prompts

System prompts are also used to enhance users' prompts.

- **Text-to-Video**: [`gemma_t2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt)
- **Image-to-Video**: [`gemma_i2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_i2v_system_prompt.txt)

**Important**: Video and audio receive **different** context embeddings, even from the same prompt. This allows better modality-specific conditioning and enables the model to synthesize speech that is synchronized with visual lip movement while being natural in cadence, accent, and emotional tone.

**Output Format**:

- Video context: `[B, seq_len, 4096]` - Video-specific text embeddings
- Audio context: `[B, seq_len, 2048]` - Audio-specific text embeddings

The text encoder is used internally by pipelines. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.

---

## Spatial Upsampler

The spatial upsampler ([`src/ltx_core/model/upsampler/`](src/ltx_core/model/upsampler/)) upsamples latent representations for higher-resolution output.

It is used internally by two-stage pipelines (e.g., [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py), [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py)) to upsample low-resolution latents before final VAE decoding. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.

---

## Data Flow

### Complete Generation Pipeline

Here's how all the components work together conceptually ([`src/ltx_core/components/`](src/ltx_core/components/)):

**Pipeline Steps**:

1. **Text Encoding**: Text prompt → Gemma encoder → separate video/audio embeddings
2. **Latent Initialization**: Initialize noise latents in spatial format `[B, C, F, H, W]`
3. **Patchification**: Convert spatial latents to sequence format `[B, seq_len, dim]` for transformer
4. **Sigma Schedule**: Generate noise schedule (adapts to token count)
5. **Denoising Loop**: Iteratively denoise using transformer predictions (see the sketch after this list)
   - Create Modality inputs with per-token timesteps and RoPE positions
   - Forward pass through transformer (conditional and unconditional for CFG)
   - Apply guidance (CFG, STG, etc.)
   - Update latents using diffusion step (Euler, etc.)
6. **Unpatchification**: Convert sequence back to spatial format
7. **VAE Decoding**: Decode latents to pixel space (with optional upsampling for two-stage)
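
Steps 4-5 boil down to a standard flow-matching update. The sketch below is a framework-agnostic toy version in which `velocity_model` is a stand-in for the transformer forward pass (with guidance already applied); it is not the `ltx-pipelines` implementation:

```python
import torch

def velocity_model(latents: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    return -latents  # dummy prediction so the sketch runs end to end

sigmas = torch.linspace(1.0, 0.0, steps=9)   # simple linear sigma schedule (step 4)
latents = torch.randn(1, 1024, 128)          # [B, seq_len, dim] noise in sequence format

for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
    velocity = velocity_model(latents, sigma)
    latents = latents + (sigma_next - sigma) * velocity  # Euler step (step 5)
```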

These steps are implemented end-to-end by the pipelines in [`ltx-pipelines`](../ltx-pipelines/):

- [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py) - Two-stage text-to-video (recommended)
- [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py) - Video-to-video with IC-LoRA control
- [`DistilledPipeline`](../ltx-pipelines/src/ltx_pipelines/distilled.py) - Fast inference with distilled model
- [`KeyframeInterpolationPipeline`](../ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py) - Keyframe-based interpolation

See the [ltx-pipelines README](../ltx-pipelines/README.md) for usage examples.

## 🔗 Related Projects

- **[ltx-pipelines](../ltx-pipelines/)** - High-level pipeline implementations for text-to-video, image-to-video, and video-to-video
- **[ltx-trainer](../ltx-trainer/)** - Training and fine-tuning tools
packages/ltx-core/src/ltx_core/conditioning/types/__init__.py
ADDED
@@ -0,0 +1,13 @@
"""Conditioning type implementations."""

from ltx_core.conditioning.types.attention_strength_wrapper import ConditioningItemAttentionStrengthWrapper
from ltx_core.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex
from ltx_core.conditioning.types.latent_cond import VideoConditionByLatentIndex
from ltx_core.conditioning.types.reference_video_cond import VideoConditionByReferenceLatent

__all__ = [
    "ConditioningItemAttentionStrengthWrapper",
    "VideoConditionByKeyframeIndex",
    "VideoConditionByLatentIndex",
    "VideoConditionByReferenceLatent",
]
packages/ltx-core/src/ltx_core/conditioning/types/attention_strength_wrapper.py
ADDED
@@ -0,0 +1,71 @@
"""Wrapper conditioning item that adds attention masking to any inner conditioning."""

from dataclasses import replace

import torch

from ltx_core.conditioning.item import ConditioningItem
from ltx_core.conditioning.mask_utils import update_attention_mask
from ltx_core.tools import LatentTools
from ltx_core.types import LatentState


class ConditioningItemAttentionStrengthWrapper(ConditioningItem):
    """Wraps a conditioning item to add an attention mask for its tokens.

    Separates the *attention-masking* concern from the underlying conditioning
    logic (token layout, positional encoding, denoise strength). The inner
    conditioning item appends tokens to the latent sequence as usual, and this
    wrapper then builds or updates the self-attention mask so that the newly
    added tokens interact with the noisy tokens according to *attention_mask*.

    Args:
        conditioning: Any conditioning item that appends tokens to the latent.
        attention_mask: Per-token attention weight controlling how strongly the
            new conditioning tokens attend to/from noisy tokens. Can be a
            scalar (float) applied uniformly, or a tensor of shape ``(B, M)``
            for spatial control, where ``M = F * H * W`` is the number of
            patchified conditioning tokens. Values in ``[0, 1]``.

    Example::

        cond = ConditioningItemAttentionStrengthWrapper(
            VideoConditionByReferenceLatent(latent=ref, strength=1.0),
            attention_mask=0.5,
        )
        state = cond.apply_to(latent_state, latent_tools)
    """

    def __init__(
        self,
        conditioning: ConditioningItem,
        attention_mask: float | torch.Tensor,
    ):
        self.conditioning = conditioning
        self.attention_mask = attention_mask

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: LatentTools,
    ) -> LatentState:
        """Apply inner conditioning, then build the attention mask for its tokens."""
        # Snapshot the original state for mask building
        original_state = latent_state

        # Inner conditioning appends tokens (positions, denoise mask, etc.)
        new_state = self.conditioning.apply_to(latent_state, latent_tools)

        num_new_tokens = new_state.latent.shape[1] - original_state.latent.shape[1]
        if num_new_tokens == 0:
            return new_state

        # Build the attention mask using the *original* state as the reference
        # so that the block structure is computed correctly.
        new_attention_mask = update_attention_mask(
            latent_state=original_state,
            attention_mask=self.attention_mask,
            num_noisy_tokens=latent_tools.target_shape.token_count(),
            num_new_tokens=num_new_tokens,
            batch_size=new_state.latent.shape[0],
            device=new_state.latent.device,
            dtype=new_state.latent.dtype,
        )

        return replace(new_state, attention_mask=new_attention_mask)
packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py
ADDED
@@ -0,0 +1,70 @@
import torch

from ltx_core.components.patchifiers import get_pixel_coords
from ltx_core.conditioning.item import ConditioningItem
from ltx_core.conditioning.mask_utils import update_attention_mask
from ltx_core.tools import VideoLatentTools
from ltx_core.types import LatentState, VideoLatentShape


class VideoConditionByKeyframeIndex(ConditioningItem):
    """
    Conditions video generation on keyframe latents at a specific frame index.

    Appends keyframe tokens to the latent state with positions offset by frame_idx,
    and sets denoise strength according to the strength parameter.

    To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`.

    Args:
        keyframes: Keyframe latents [B, C, F, H, W].
        frame_idx: Frame index offset for positional encoding.
        strength: Conditioning strength (1.0 = clean, 0.0 = fully denoised).
    """

    def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float):
        self.keyframes = keyframes
        self.frame_idx = frame_idx
        self.strength = strength

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: VideoLatentTools,
    ) -> LatentState:
        tokens = latent_tools.patchifier.patchify(self.keyframes)
        latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape),
            device=self.keyframes.device,
        )
        positions = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=latent_tools.scale_factors,
            causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False,
        )

        positions[:, 0, ...] += self.frame_idx
        positions = positions.to(dtype=torch.float32)
        positions[:, 0, ...] /= latent_tools.fps

        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.keyframes.device,
            dtype=self.keyframes.dtype,
        )

        new_attention_mask = update_attention_mask(
            latent_state=latent_state,
            attention_mask=None,
            num_noisy_tokens=latent_tools.target_shape.token_count(),
            num_new_tokens=tokens.shape[1],
            batch_size=tokens.shape[0],
            device=self.keyframes.device,
            dtype=self.keyframes.dtype,
        )

        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
            attention_mask=new_attention_mask,
        )
packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py
ADDED
@@ -0,0 +1,44 @@
import torch

from ltx_core.conditioning.exceptions import ConditioningError
from ltx_core.conditioning.item import ConditioningItem
from ltx_core.tools import LatentTools
from ltx_core.types import LatentState


class VideoConditionByLatentIndex(ConditioningItem):
    """
    Conditions video generation by injecting latents at a specific latent frame index.

    Replaces tokens in the latent state at positions corresponding to latent_idx,
    and sets denoise strength according to the strength parameter.
    """

    def __init__(self, latent: torch.Tensor, strength: float, latent_idx: int):
        self.latent = latent
        self.strength = strength
        self.latent_idx = latent_idx

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        cond_batch, cond_channels, _, cond_height, cond_width = self.latent.shape
        tgt_batch, tgt_channels, tgt_frames, tgt_height, tgt_width = latent_tools.target_shape.to_torch_shape()

        if (cond_batch, cond_channels, cond_height, cond_width) != (tgt_batch, tgt_channels, tgt_height, tgt_width):
            raise ConditioningError(
                f"Can't apply image conditioning item to latent with shape {latent_tools.target_shape}, expected "
                f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_height}, {tgt_width}). Make sure "
                "the image and latent have the same spatial shape."
            )

        tokens = latent_tools.patchifier.patchify(self.latent)
        start_token = latent_tools.patchifier.get_token_count(
            latent_tools.target_shape._replace(frames=self.latent_idx)
        )
        stop_token = start_token + tokens.shape[1]

        latent_state = latent_state.clone()

        latent_state.latent[:, start_token:stop_token] = tokens
        latent_state.clean_latent[:, start_token:stop_token] = tokens
        latent_state.denoise_mask[:, start_token:stop_token] = 1.0 - self.strength

        return latent_state
packages/ltx-core/src/ltx_core/conditioning/types/reference_video_cond.py
ADDED
@@ -0,0 +1,91 @@
"""Reference video conditioning for IC-LoRA inference."""

import torch

from ltx_core.components.patchifiers import get_pixel_coords
from ltx_core.conditioning.item import ConditioningItem
from ltx_core.conditioning.mask_utils import update_attention_mask
from ltx_core.tools import VideoLatentTools
from ltx_core.types import LatentState, VideoLatentShape


class VideoConditionByReferenceLatent(ConditioningItem):
    """
    Conditions video generation on a reference video latent for IC-LoRA inference.

    IC-LoRAs are trained by concatenating reference (control signal) and target tokens,
    learning to attend across both. This class replicates that setup at inference by
    appending reference tokens to the latent sequence.

    IC-LoRAs can be trained with lower-resolution references than the target (e.g., 384px
    reference for 768px output) for efficiency and better generalization. The
    `downscale_factor` scales reference positions to match target coordinates, preserving
    the learned positional relationships. This must match the factor used during training
    (stored in LoRA metadata).

    To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`.

    Args:
        latent: Reference video latents [B, C, F, H, W]
        downscale_factor: Target/reference resolution ratio (e.g., 2 = half-resolution
            reference). Spatial positions are scaled by this factor.
        strength: Conditioning strength. 1.0 = full (reference kept clean),
            0.0 = none (reference denoised). Default 1.0.
    """

    def __init__(
        self,
        latent: torch.Tensor,
        downscale_factor: int = 1,
        strength: float = 1.0,
    ):
        self.latent = latent
        self.downscale_factor = downscale_factor
        self.strength = strength

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: VideoLatentTools,
    ) -> LatentState:
        """Append reference video tokens with scaled positions."""
        tokens = latent_tools.patchifier.patchify(self.latent)

        # Compute positions for the reference video's actual dimensions
        latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=VideoLatentShape.from_torch_shape(self.latent.shape),
            device=self.latent.device,
        )
        positions = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=latent_tools.scale_factors,
            causal_fix=latent_tools.causal_fix,
        )
        positions = positions.to(dtype=torch.float32)
        positions[:, 0, ...] /= latent_tools.fps

        # Scale spatial positions to match target coordinate space
        if self.downscale_factor != 1:
            positions[:, 1, ...] *= self.downscale_factor  # height axis
            positions[:, 2, ...] *= self.downscale_factor  # width axis

        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.latent.device,
            dtype=self.latent.dtype,
        )

        new_attention_mask = update_attention_mask(
            latent_state=latent_state,
            attention_mask=None,
            num_noisy_tokens=latent_tools.target_shape.token_count(),
            num_new_tokens=tokens.shape[1],
            batch_size=tokens.shape[0],
            device=self.latent.device,
            dtype=self.latent.dtype,
        )

        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
            attention_mask=new_attention_mask,
        )
packages/ltx-core/src/ltx_core/guidance/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (454 Bytes)

packages/ltx-core/src/ltx_core/guidance/__pycache__/perturbations.cpython-312.pyc
ADDED
Binary file (5.7 kB)
packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py
ADDED
@@ -0,0 +1,10 @@
from enum import Enum


class CausalityAxis(Enum):
    """Enum for specifying the causality axis in causal convolutions."""

    NONE = None
    WIDTH = "width"
    HEIGHT = "height"
    WIDTH_COMPATIBILITY = "width-compatibility"
packages/ltx-core/src/ltx_core/model/common/__init__.py
ADDED
@@ -0,0 +1,9 @@
"""Common model utilities."""

from ltx_core.model.common.normalization import NormType, PixelNorm, build_normalization_layer

__all__ = [
    "NormType",
    "PixelNorm",
    "build_normalization_layer",
]
packages/ltx-core/src/ltx_core/model/common/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (378 Bytes)

packages/ltx-core/src/ltx_core/model/common/__pycache__/normalization.cpython-312.pyc
ADDED
Binary file (3.18 kB)
packages/ltx-core/src/ltx_core/model/common/normalization.py
ADDED
@@ -0,0 +1,59 @@
from enum import Enum

import torch
from torch import nn


class NormType(Enum):
    """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""

    GROUP = "group"
    PIXEL = "pixel"


class PixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization layer.

    For each element along the chosen dimension, this layer normalizes the tensor
    by the root-mean-square of its values across that dimension:

        y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            dim: Dimension along which to compute the RMS (typically channels).
            eps: Small constant added for numerical stability.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization along the configured dimension.
        """
        # Compute mean of squared values along `dim`, keep dimensions for broadcasting.
        mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
        # Normalize by the root-mean-square (RMS).
        rms = torch.sqrt(mean_sq + self.eps)
        return x / rms


def build_normalization_layer(
    in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
    """
    Create a normalization layer based on the normalization type.

    Args:
        in_channels: Number of input channels
        num_groups: Number of groups for group normalization
        normtype: Type of normalization: "group" or "pixel"

    Returns:
        A normalization layer
    """
    if normtype == NormType.GROUP:
        return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    if normtype == NormType.PIXEL:
        return PixelNorm(dim=1, eps=1e-6)
    raise ValueError(f"Invalid normalization type: {normtype}")
packages/ltx-core/src/ltx_core/model/transformer/__init__.py
ADDED
@@ -0,0 +1,18 @@
"""Transformer model components."""

from ltx_core.model.transformer.modality import Modality
from ltx_core.model.transformer.model import LTXModel, X0Model
from ltx_core.model.transformer.model_configurator import (
    LTXV_MODEL_COMFY_RENAMING_MAP,
    LTXModelConfigurator,
    LTXVideoOnlyModelConfigurator,
)

__all__ = [
    "LTXV_MODEL_COMFY_RENAMING_MAP",
    "LTXModel",
    "LTXModelConfigurator",
    "LTXVideoOnlyModelConfigurator",
    "Modality",
    "X0Model",
]
packages/ltx-core/src/ltx_core/model/transformer/__pycache__/adaln.cpython-312.pyc
ADDED
Binary file (2.6 kB)

packages/ltx-core/src/ltx_core/model/transformer/__pycache__/transformer_args.cpython-312.pyc
ADDED
Binary file (13.6 kB)
packages/ltx-core/src/ltx_core/model/transformer/adaln.py
ADDED
@@ -0,0 +1,45 @@
from typing import Optional, Tuple

import torch

from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings

# Number of AdaLN modulation parameters per transformer block.
# Base: 2 params (shift + scale) x 3 norms (self-attn, feed-forward, output).
ADALN_NUM_BASE_PARAMS = 6
# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm.
ADALN_NUM_CROSS_ATTN_PARAMS = 3


def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int:
    """Total number of AdaLN parameters per block."""
    return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0)


class AdaLayerNormSingle(torch.nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, embedding_coefficient: int = 6):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
        )

        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
packages/ltx-core/src/ltx_core/model/transformer/attention.py
ADDED
@@ -0,0 +1,249 @@
from enum import Enum
|
| 2 |
+
from typing import Protocol
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.model.transformer.rope import LTXRopeType, apply_rotary_emb
|
| 7 |
+
|
| 8 |
+
memory_efficient_attention = None
|
| 9 |
+
flash_attn_interface = None
|
| 10 |
+
try:
|
| 11 |
+
from xformers.ops import memory_efficient_attention
|
| 12 |
+
except ImportError:
|
| 13 |
+
memory_efficient_attention = None
|
| 14 |
+
try:
|
| 15 |
+
# FlashAttention3 and XFormersAttention cannot be used together
|
| 16 |
+
if memory_efficient_attention is None:
|
| 17 |
+
import flash_attn_interface
|
| 18 |
+
except ImportError:
|
| 19 |
+
flash_attn_interface = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class AttentionCallable(Protocol):
|
| 23 |
+
def __call__(
|
| 24 |
+
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
|
| 25 |
+
) -> torch.Tensor: ...
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PytorchAttention(AttentionCallable):
|
| 29 |
+
def __call__(
|
| 30 |
+
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
|
| 31 |
+
) -> torch.Tensor:
|
| 32 |
+
b, _, dim_head = q.shape
|
| 33 |
+
dim_head //= heads
|
| 34 |
+
q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
|
| 35 |
+
|
| 36 |
+
if mask is not None:
|
| 37 |
+
# add a batch dimension if there isn't already one
|
| 38 |
+
if mask.ndim == 2:
|
| 39 |
+
mask = mask.unsqueeze(0)
|
| 40 |
+
# add a heads dimension if there isn't already one
|
| 41 |
+
if mask.ndim == 3:
|
| 42 |
+
mask = mask.unsqueeze(1)
|
| 43 |
+
|
| 44 |
+
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
| 45 |
+
out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
| 46 |
+
return out
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class XFormersAttention(AttentionCallable):
|
| 50 |
+
def __call__(
|
| 51 |
+
self,
|
| 52 |
+
q: torch.Tensor,
|
| 53 |
+
k: torch.Tensor,
|
| 54 |
+
v: torch.Tensor,
|
| 55 |
+
heads: int,
|
| 56 |
+
mask: torch.Tensor | None = None,
|
| 57 |
+
) -> torch.Tensor:
|
| 58 |
+
if memory_efficient_attention is None:
|
| 59 |
+
raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.")
|
| 60 |
+
|
| 61 |
+
b, _, dim_head = q.shape
|
| 62 |
+
dim_head //= heads
|
| 63 |
+
|
| 64 |
+
# xformers expects [B, M, H, K]
|
| 65 |
+
q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
|
| 66 |
+
|
| 67 |
+
if mask is not None:
|
| 68 |
+
# add a singleton batch dimension
|
| 69 |
+
if mask.ndim == 2:
|
| 70 |
+
mask = mask.unsqueeze(0)
|
| 71 |
+
# add a singleton heads dimension
|
| 72 |
+
if mask.ndim == 3:
|
| 73 |
+
mask = mask.unsqueeze(1)
|
| 74 |
+
# pad to a multiple of 8
|
| 75 |
+
pad = 8 - mask.shape[-1] % 8
|
| 76 |
+
# the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
|
| 77 |
+
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
|
| 78 |
+
# in flux, this matrix ends up being over 1GB
|
| 79 |
+
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
|
| 80 |
+
mask_out = torch.empty(
|
| 81 |
+
[mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
mask_out[..., : mask.shape[-1]] = mask
|
| 85 |
+
# doesn't this remove the padding again??
|
| 86 |
+
mask = mask_out[..., : mask.shape[-1]]
|
| 87 |
+
mask = mask.expand(b, heads, -1, -1)
|
| 88 |
+
|
| 89 |
+
out = memory_efficient_attention(q.to(v.dtype), k.to(v.dtype), v, attn_bias=mask, p=0.0)
|
| 90 |
+
out = out.reshape(b, -1, heads * dim_head)
|
| 91 |
+
return out
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class FlashAttention3(AttentionCallable):
|
| 95 |
+
def __call__(
|
| 96 |
+
self,
|
| 97 |
+
q: torch.Tensor,
|
| 98 |
+
k: torch.Tensor,
|
| 99 |
+
v: torch.Tensor,
|
| 100 |
+
heads: int,
|
| 101 |
+
mask: torch.Tensor | None = None,
|
| 102 |
+
) -> torch.Tensor:
|
| 103 |
+
if flash_attn_interface is None:
|
| 104 |
+
raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.")
|
| 105 |
+
|
| 106 |
+
b, _, dim_head = q.shape
|
| 107 |
+
dim_head //= heads
|
| 108 |
+
|
| 109 |
+
q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
|
| 110 |
+
|
| 111 |
+
if mask is not None:
|
| 112 |
+
raise NotImplementedError("Mask is not supported for FlashAttention3")
|
| 113 |
+
|
| 114 |
+
out = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v)
|
| 115 |
+
out = out.reshape(b, -1, heads * dim_head)
|
| 116 |
+
return out
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class AttentionFunction(Enum):
|
| 120 |
+
PYTORCH = "pytorch"
|
| 121 |
+
XFORMERS = "xformers"
|
| 122 |
+
FLASH_ATTENTION_3 = "flash_attention_3"
|
| 123 |
+
DEFAULT = "default"
|
| 124 |
+
|
| 125 |
+
def __call__(
|
| 126 |
+
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
|
| 127 |
+
) -> torch.Tensor:
|
| 128 |
+
if self is AttentionFunction.PYTORCH:
|
| 129 |
+
return PytorchAttention()(q, k, v, heads, mask)
|
| 130 |
+
elif self is AttentionFunction.XFORMERS:
|
| 131 |
+
return XFormersAttention()(q, k, v, heads, mask)
|
| 132 |
+
elif self is AttentionFunction.FLASH_ATTENTION_3:
|
| 133 |
+
return FlashAttention3()(q, k, v, heads, mask)
|
| 134 |
+
else:
|
| 135 |
+
# Default behavior: XFormers if installed else - PyTorch
|
| 136 |
+
return (
|
| 137 |
+
XFormersAttention()(q, k, v, heads, mask)
|
| 138 |
+
if memory_efficient_attention is not None
|
| 139 |
+
else PytorchAttention()(q, k, v, heads, mask)
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class Attention(torch.nn.Module):
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
query_dim: int,
|
| 147 |
+
context_dim: int | None = None,
|
| 148 |
+
heads: int = 8,
|
| 149 |
+
dim_head: int = 64,
|
| 150 |
+
norm_eps: float = 1e-6,
|
| 151 |
+
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
| 152 |
+
attention_function: AttentionCallable | AttentionFunction = AttentionFunction.DEFAULT,
|
| 153 |
+
apply_gated_attention: bool = False,
|
| 154 |
+
) -> None:
|
| 155 |
+
super().__init__()
|
| 156 |
+
self.rope_type = rope_type
|
| 157 |
+
self.attention_function = attention_function
|
| 158 |
+
|
| 159 |
+
inner_dim = dim_head * heads
|
| 160 |
+
context_dim = query_dim if context_dim is None else context_dim
|
| 161 |
+
|
| 162 |
+
self.heads = heads
|
| 163 |
+
self.dim_head = dim_head
|
| 164 |
+
|
| 165 |
+
self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)
|
| 166 |
+
self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps)
|
| 167 |
+
|
| 168 |
+
self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True)
|
| 169 |
+
self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True)
|
| 170 |
+
self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True)
|
| 171 |
+
|
| 172 |
+
# Optional per-head gating
|
| 173 |
+
if apply_gated_attention:
|
| 174 |
+
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
|
| 175 |
+
else:
|
| 176 |
+
self.to_gate_logits = None
|
| 177 |
+
|
| 178 |
+
self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity())
|
| 179 |
+
|
| 180 |
+
def forward(
|
| 181 |
+
self,
|
| 182 |
+
x: torch.Tensor,
|
| 183 |
+
context: torch.Tensor | None = None,
|
| 184 |
+
mask: torch.Tensor | None = None,
|
| 185 |
+
pe: torch.Tensor | None = None,
|
| 186 |
+
k_pe: torch.Tensor | None = None,
|
| 187 |
+
perturbation_mask: torch.Tensor | None = None,
|
| 188 |
+
all_perturbed: bool = False,
|
| 189 |
+
) -> torch.Tensor:
|
| 190 |
+
"""Multi-head attention with optional RoPE, perturbation masking, and per-head gating.
|
| 191 |
+
When ``perturbation_mask`` is all zeros, the expensive query/key path
|
| 192 |
+
(linear projections, RMSNorm, RoPE) is skipped entirely and only the
|
| 193 |
+
value projection is used as a pass-through.
|
| 194 |
+
Args:
|
| 195 |
+
x: Query input tensor of shape ``(B, T, query_dim)``.
|
| 196 |
+
context: Key/value context tensor of shape ``(B, S, context_dim)``.
|
| 197 |
+
Falls back to ``x`` (self-attention) when *None*.
|
| 198 |
+
mask: Optional attention mask. Interpretation depends on the attention
|
| 199 |
+
backend (additive bias for xformers/PyTorch SDPA).
|
| 200 |
+
pe: Rotary positional embeddings applied to both ``q`` and ``k``.
|
| 201 |
+
k_pe: Separate rotary positional embeddings for ``k`` only. When
|
| 202 |
+
*None*, ``pe`` is reused for keys.
|
| 203 |
+
perturbation_mask: Optional mask in ``[0, 1]`` that
|
| 204 |
+
blends the attention output with the raw value projection:
|
| 205 |
+
``out = attn_out * mask + v * (1 - mask)``.
|
| 206 |
+
**1** keeps the full attention output, **0** bypasses attention
|
| 207 |
+
and passes the value projection through unchanged.
|
| 208 |
+
*None* or all-ones means standard attention; all-zeros skips
|
| 209 |
+
the query/key path entirely for efficiency.
|
| 210 |
+
all_perturbed: Whether all perturbations are active for this block.
|
| 211 |
+
Returns:
|
| 212 |
+
Output tensor of shape ``(B, T, query_dim)``.
|
| 213 |
+
"""
|
| 214 |
+
context = x if context is None else context
|
| 215 |
+
use_attention = not all_perturbed
|
| 216 |
+
|
| 217 |
+
v = self.to_v(context)
|
| 218 |
+
|
| 219 |
+
if not use_attention:
|
| 220 |
+
out = v
|
| 221 |
+
else:
|
| 222 |
+
q = self.to_q(x)
|
| 223 |
+
k = self.to_k(context)
|
| 224 |
+
|
| 225 |
+
q = self.q_norm(q)
|
| 226 |
+
k = self.k_norm(k)
|
| 227 |
+
|
| 228 |
+
if pe is not None:
|
| 229 |
+
q = apply_rotary_emb(q, pe, self.rope_type)
|
| 230 |
+
k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)
|
| 231 |
+
|
| 232 |
+
out = self.attention_function(q, k, v, self.heads, mask) # (B, T, H*D)
|
| 233 |
+
|
| 234 |
+
if perturbation_mask is not None:
|
| 235 |
+
out = out * perturbation_mask + v * (1 - perturbation_mask)
|
| 236 |
+
|
| 237 |
+
# Apply per-head gating if enabled
|
| 238 |
+
if self.to_gate_logits is not None:
|
| 239 |
+
gate_logits = self.to_gate_logits(x) # (B, T, H)
|
| 240 |
+
b, t, _ = out.shape
|
| 241 |
+
# Reshape to (B, T, H, D) for per-head gating
|
| 242 |
+
out = out.view(b, t, self.heads, self.dim_head)
|
| 243 |
+
# Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0)
|
| 244 |
+
gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H)
|
| 245 |
+
out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1)
|
| 246 |
+
# Reshape back to (B, T, H*D)
|
| 247 |
+
out = out.view(b, t, self.heads * self.dim_head)
|
| 248 |
+
|
| 249 |
+
return self.to_out(out)
|
packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py
ADDED
@@ -0,0 +1,10 @@
import torch


class GELUApprox(torch.nn.Module):
    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
packages/ltx-core/src/ltx_core/model/transformer/modality.py
ADDED
@@ -0,0 +1,40 @@
from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class Modality:
    """
    Input data for a single modality (video or audio) in the transformer.

    Bundles the latent tokens, timestep embeddings, positional information,
    and text conditioning context for processing by the diffusion transformer.

    Attributes:
        latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is
            the batch size, *T* is the total number of tokens (noisy +
            conditioning), and *D* is the input dimension.
        sigma: Current sigma value, shape ``(B,)``, used for cross-attention
            timestep calculation.
        timesteps: Per-token timestep embeddings, shape ``(B, T)``.
        positions: Positional coordinates, shape ``(B, 3, T)`` for video
            (time, height, width) or ``(B, 1, T)`` for audio.
        context: Text conditioning embeddings from the prompt encoder.
        enabled: Whether this modality is active in the current forward pass.
        context_mask: Optional mask for the text context tokens.
        attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``.
            Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no
            attention. ``None`` means unrestricted (full) attention between
            all tokens. Built incrementally by conditioning items; see
            :class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`.
    """

    latent: torch.Tensor  # (B, T, D): B = batch size, T = number of tokens, D = input dimension
    sigma: torch.Tensor  # (B,): current sigma value, used for cross-attention timestep calculation
    timesteps: torch.Tensor  # (B, T): per-token timesteps
    positions: torch.Tensor  # (B, 3, T) for video: 3 coordinate axes, T tokens
    context: torch.Tensor
    enabled: bool = True
    context_mask: torch.Tensor | None = None
    attention_mask: torch.Tensor | None = None
packages/ltx-core/src/ltx_core/model/transformer/model.py
ADDED
|
@@ -0,0 +1,486 @@
from enum import Enum

import torch

from ltx_core.guidance.perturbations import BatchedPerturbationConfig
from ltx_core.model.transformer.adaln import AdaLayerNormSingle, adaln_embedding_coefficient
from ltx_core.model.transformer.attention import AttentionCallable, AttentionFunction
from ltx_core.model.transformer.modality import Modality
from ltx_core.model.transformer.rope import LTXRopeType
from ltx_core.model.transformer.transformer import BasicAVTransformerBlock, TransformerConfig
from ltx_core.model.transformer.transformer_args import (
    MultiModalTransformerArgsPreprocessor,
    TransformerArgs,
    TransformerArgsPreprocessor,
)
from ltx_core.utils import to_denoised


class LTXModelType(Enum):
    AudioVideo = "ltx av model"
    VideoOnly = "ltx video only model"
    AudioOnly = "ltx audio only model"

    def is_video_enabled(self) -> bool:
        return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)

    def is_audio_enabled(self) -> bool:
        return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)


class LTXModel(torch.nn.Module):
    """
    LTX model transformer implementation.
    This class implements the transformer blocks for the LTX model.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        model_type: LTXModelType = LTXModelType.AudioVideo,
        num_attention_heads: int = 32,
        attention_head_dim: int = 128,
        in_channels: int = 128,
        out_channels: int = 128,
        num_layers: int = 48,
        cross_attention_dim: int = 4096,
        norm_eps: float = 1e-06,
        attention_type: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT,
        positional_embedding_theta: float = 10000.0,
        positional_embedding_max_pos: list[int] | None = None,
        timestep_scale_multiplier: int = 1000,
        use_middle_indices_grid: bool = True,
        audio_num_attention_heads: int = 32,
        audio_attention_head_dim: int = 64,
        audio_in_channels: int = 128,
        audio_out_channels: int = 128,
        audio_cross_attention_dim: int = 2048,
        audio_positional_embedding_max_pos: list[int] | None = None,
        av_ca_timestep_scale_multiplier: int = 1,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        double_precision_rope: bool = False,
        apply_gated_attention: bool = False,
        caption_projection: torch.nn.Module | None = None,
        audio_caption_projection: torch.nn.Module | None = None,
        cross_attention_adaln: bool = False,
    ):
        super().__init__()
        self._enable_gradient_checkpointing = False
        self.cross_attention_adaln = cross_attention_adaln
        self.use_middle_indices_grid = use_middle_indices_grid
        self.rope_type = rope_type
        self.double_precision_rope = double_precision_rope
        self.timestep_scale_multiplier = timestep_scale_multiplier
        self.positional_embedding_theta = positional_embedding_theta
        self.model_type = model_type
        cross_pe_max_pos = None
        if model_type.is_video_enabled():
            if positional_embedding_max_pos is None:
                positional_embedding_max_pos = [20, 2048, 2048]
            self.positional_embedding_max_pos = positional_embedding_max_pos
            self.num_attention_heads = num_attention_heads
            self.inner_dim = num_attention_heads * attention_head_dim
            self._init_video(
                in_channels=in_channels,
                out_channels=out_channels,
                norm_eps=norm_eps,
                caption_projection=caption_projection,
            )

        if model_type.is_audio_enabled():
            if audio_positional_embedding_max_pos is None:
                audio_positional_embedding_max_pos = [20]
            self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
            self.audio_num_attention_heads = audio_num_attention_heads
            self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim
            self._init_audio(
                in_channels=audio_in_channels,
                out_channels=audio_out_channels,
                norm_eps=norm_eps,
                caption_projection=audio_caption_projection,
            )

        if model_type.is_video_enabled() and model_type.is_audio_enabled():
            cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
            self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
            self.audio_cross_attention_dim = audio_cross_attention_dim
            self._init_audio_video(num_scale_shift_values=4)

        self._init_preprocessors(cross_pe_max_pos)
        # Initialize transformer blocks
        self._init_transformer_blocks(
            num_layers=num_layers,
            attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0,
            cross_attention_dim=cross_attention_dim,
            audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
            audio_cross_attention_dim=audio_cross_attention_dim,
            norm_eps=norm_eps,
            attention_type=attention_type,
            apply_gated_attention=apply_gated_attention,
        )

    @property
    def _adaln_embedding_coefficient(self) -> int:
        return adaln_embedding_coefficient(self.cross_attention_adaln)

    def _init_video(
        self,
        in_channels: int,
        out_channels: int,
        norm_eps: float,
        caption_projection: torch.nn.Module | None = None,
    ) -> None:
        """Initialize video-specific components."""
        # Video input components
        self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)
        if caption_projection is not None:
            self.caption_projection = caption_projection

        self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient)

        self.prompt_adaln_single = (
            AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
        )

        # Video output components
        self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))
        self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps)
        self.proj_out = torch.nn.Linear(self.inner_dim, out_channels)

    def _init_audio(
        self,
        in_channels: int,
        out_channels: int,
        norm_eps: float,
        caption_projection: torch.nn.Module | None = None,
    ) -> None:
        """Initialize audio-specific components."""

        # Audio input components
        self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)
        if caption_projection is not None:
            self.audio_caption_projection = caption_projection

        self.audio_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=self._adaln_embedding_coefficient,
        )

        self.audio_prompt_adaln_single = (
            AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None
        )

        # Audio output components
        self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))
        self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps)
        self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels)

    def _init_audio_video(
        self,
        num_scale_shift_values: int,
    ) -> None:
        """Initialize audio-video cross-attention components."""
        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            embedding_coefficient=num_scale_shift_values,
        )

        self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=num_scale_shift_values,
        )

        self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            embedding_coefficient=1,
        )

        self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=1,
        )

    def _init_preprocessors(
        self,
        cross_pe_max_pos: int | None = None,
    ) -> None:
        """Initialize preprocessors for LTX."""

        if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
            self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
                inner_dim=self.inner_dim,
                max_pos=self.positional_embedding_max_pos,
                num_attention_heads=self.num_attention_heads,
                cross_pe_max_pos=cross_pe_max_pos,
                use_middle_indices_grid=self.use_middle_indices_grid,
                audio_cross_attention_dim=self.audio_cross_attention_dim,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
                caption_projection=getattr(self, "caption_projection", None),
                prompt_adaln=getattr(self, "prompt_adaln_single", None),
            )
            self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
                inner_dim=self.audio_inner_dim,
                max_pos=self.audio_positional_embedding_max_pos,
                num_attention_heads=self.audio_num_attention_heads,
                cross_pe_max_pos=cross_pe_max_pos,
                use_middle_indices_grid=self.use_middle_indices_grid,
                audio_cross_attention_dim=self.audio_cross_attention_dim,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
                caption_projection=getattr(self, "audio_caption_projection", None),
                prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
            )
        elif self.model_type.is_video_enabled():
            self.video_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                inner_dim=self.inner_dim,
                max_pos=self.positional_embedding_max_pos,
                num_attention_heads=self.num_attention_heads,
                use_middle_indices_grid=self.use_middle_indices_grid,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                caption_projection=getattr(self, "caption_projection", None),
                prompt_adaln=getattr(self, "prompt_adaln_single", None),
            )
        elif self.model_type.is_audio_enabled():
            self.audio_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                inner_dim=self.audio_inner_dim,
                max_pos=self.audio_positional_embedding_max_pos,
                num_attention_heads=self.audio_num_attention_heads,
                use_middle_indices_grid=self.use_middle_indices_grid,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                caption_projection=getattr(self, "audio_caption_projection", None),
                prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
            )

    def _init_transformer_blocks(
        self,
        num_layers: int,
        attention_head_dim: int,
        cross_attention_dim: int,
        audio_attention_head_dim: int,
        audio_cross_attention_dim: int,
        norm_eps: float,
        attention_type: AttentionFunction | AttentionCallable,
        apply_gated_attention: bool,
    ) -> None:
        """Initialize transformer blocks for LTX."""
        video_config = (
            TransformerConfig(
                dim=self.inner_dim,
                heads=self.num_attention_heads,
                d_head=attention_head_dim,
                context_dim=cross_attention_dim,
                apply_gated_attention=apply_gated_attention,
                cross_attention_adaln=self.cross_attention_adaln,
            )
            if self.model_type.is_video_enabled()
            else None
        )
        audio_config = (
            TransformerConfig(
                dim=self.audio_inner_dim,
                heads=self.audio_num_attention_heads,
                d_head=audio_attention_head_dim,
                context_dim=audio_cross_attention_dim,
                apply_gated_attention=apply_gated_attention,
                cross_attention_adaln=self.cross_attention_adaln,
            )
            if self.model_type.is_audio_enabled()
            else None
        )
        self.transformer_blocks = torch.nn.ModuleList(
            [
                BasicAVTransformerBlock(
                    idx=idx,
                    video=video_config,
                    audio=audio_config,
                    rope_type=self.rope_type,
                    norm_eps=norm_eps,
                    attention_function=attention_type,
                )
                for idx in range(num_layers)
            ]
        )

    def set_gradient_checkpointing(self, enable: bool) -> None:
        """Enable or disable gradient checkpointing for transformer blocks.
        Gradient checkpointing trades compute for memory by recomputing activations
        during the backward pass instead of storing them. This can significantly
        reduce memory usage at the cost of ~20-30% slower training.
        Args:
            enable: Whether to enable gradient checkpointing.
        """
        self._enable_gradient_checkpointing = enable

    def _process_transformer_blocks(
        self,
        video: TransformerArgs | None,
        audio: TransformerArgs | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[TransformerArgs, TransformerArgs]:
        """Process transformer blocks for LTXAV."""

        # Process transformer blocks
        for block in self.transformer_blocks:
            if self._enable_gradient_checkpointing and self.training:
                # Use gradient checkpointing to save memory during training.
                # With use_reentrant=False, we can pass dataclasses directly -
                # PyTorch will track all tensor leaves in the computation graph.
                video, audio = torch.utils.checkpoint.checkpoint(
                    block,
                    video,
                    audio,
                    perturbations,
                    use_reentrant=False,
                )
            else:
                video, audio = block(
                    video=video,
                    audio=audio,
                    perturbations=perturbations,
                )

        return video, audio

    def _process_output(
        self,
        scale_shift_table: torch.Tensor,
        norm_out: torch.nn.LayerNorm,
        proj_out: torch.nn.Linear,
        x: torch.Tensor,
        embedded_timestep: torch.Tensor,
    ) -> torch.Tensor:
        """Process output for LTXV."""
        # Apply scale-shift modulation
        scale_shift_values = (
            scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        x = norm_out(x)
        x = x * (1 + scale) + shift
        x = proj_out(x)
        return x

    def forward(
        self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for LTX models.
        Returns:
            Processed output tensors
        """
        if not self.model_type.is_video_enabled() and video is not None:
            raise ValueError("Video is not enabled for this model")
        if not self.model_type.is_audio_enabled() and audio is not None:
            raise ValueError("Audio is not enabled for this model")

        video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None
        audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None
        # Process transformer blocks
        video_out, audio_out = self._process_transformer_blocks(
            video=video_args,
            audio=audio_args,
            perturbations=perturbations,
        )

        # Process output
        vx = (
            self._process_output(
                self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep
            )
            if video_out is not None
            else None
        )
        ax = (
            self._process_output(
                self.audio_scale_shift_table,
                self.audio_norm_out,
                self.audio_proj_out,
                audio_out.x,
                audio_out.embedded_timestep,
            )
            if audio_out is not None
            else None
        )
        return vx, ax


class LegacyX0Model(torch.nn.Module):
    """
    Legacy X0 model implementation.
    Returns fully denoised output based on the velocities produced by the base model.
    """

    def __init__(self, velocity_model: LTXModel):
        super().__init__()
        self.velocity_model = velocity_model

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
        sigma: float,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Denoise the video and audio according to the sigma.
        Returns:
            Denoised video and audio
        """
        vx, ax = self.velocity_model(video, audio, perturbations)
        denoised_video = to_denoised(video.latent, vx, sigma) if vx is not None else None
        denoised_audio = to_denoised(audio.latent, ax, sigma) if ax is not None else None
        return denoised_video, denoised_audio


class X0Model(torch.nn.Module):
    """
    X0 model implementation.
    Returns fully denoised outputs based on the velocities produced by the base model.
    Applies scaled denoising to the video and audio according to the timesteps = sigma * denoising_mask.
    """

    def __init__(self, velocity_model: LTXModel):
        super().__init__()
        self.velocity_model = velocity_model

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Denoise the video and audio according to the sigma.
        Returns:
            Denoised video and audio
        """
        vx, ax = self.velocity_model(video, audio, perturbations)
        denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
        denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
        return denoised_video, denoised_audio
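A rough usage sketch of the classes above; the tiny layer count is an assumption for illustration, and the real pipeline builds the Modality inputs and BatchedPerturbationConfig through the conditioning machinery rather than by hand, so the forward call is left commented out.

from ltx_core.model.transformer.model import LTXModel, LTXModelType, X0Model

# Video-only velocity model with a deliberately small layer count (assumption, defaults are 48 layers).
model = LTXModel(model_type=LTXModelType.VideoOnly, num_layers=2)
model.set_gradient_checkpointing(True)  # trade ~20-30% training speed for lower memory

# X0Model wraps the velocity model and converts velocities into denoised latents via to_denoised.
x0 = X0Model(model)
# denoised_video, denoised_audio = x0(video=video_modality, audio=None, perturbations=perturbations)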
packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py
ADDED
@@ -0,0 +1,152 @@
import torch

from ltx_core.loader.sd_ops import SDOps
from ltx_core.model.model_protocol import ModelConfigurator
from ltx_core.model.transformer.attention import AttentionFunction
from ltx_core.model.transformer.model import LTXModel, LTXModelType
from ltx_core.model.transformer.rope import LTXRopeType
from ltx_core.model.transformer.text_projection import create_caption_projection
from ltx_core.utils import check_config_value


class LTXModelConfigurator(ModelConfigurator[LTXModel]):
    """
    Configurator for LTX model.
    Used to create an LTX model from a configuration dictionary.
    """

    @classmethod
    def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
        # Build caption projections for 19B models (projection handled in transformer).
        caption_projection, audio_caption_projection = _build_caption_projections(config, is_av=True)

        config = config.get("transformer", {})

        check_config_value(config, "dropout", 0.0)
        check_config_value(config, "attention_bias", True)
        check_config_value(config, "num_vector_embeds", None)
        check_config_value(config, "activation_fn", "gelu-approximate")
        check_config_value(config, "num_embeds_ada_norm", 1000)
        check_config_value(config, "use_linear_projection", False)
        check_config_value(config, "only_cross_attention", False)
        check_config_value(config, "cross_attention_norm", True)
        check_config_value(config, "double_self_attention", False)
        check_config_value(config, "upcast_attention", False)
        check_config_value(config, "standardization_norm", "rms_norm")
        check_config_value(config, "norm_elementwise_affine", False)
        check_config_value(config, "qk_norm", "rms_norm")
        check_config_value(config, "positional_embedding_type", "rope")
        check_config_value(config, "use_audio_video_cross_attention", True)
        check_config_value(config, "share_ff", False)
        check_config_value(config, "av_cross_ada_norm", True)
        check_config_value(config, "use_middle_indices_grid", True)

        return LTXModel(
            model_type=LTXModelType.AudioVideo,
            num_attention_heads=config.get("num_attention_heads", 32),
            attention_head_dim=config.get("attention_head_dim", 128),
            in_channels=config.get("in_channels", 128),
            out_channels=config.get("out_channels", 128),
            num_layers=config.get("num_layers", 48),
            cross_attention_dim=config.get("cross_attention_dim", 4096),
            norm_eps=config.get("norm_eps", 1e-06),
            attention_type=AttentionFunction(config.get("attention_type", "default")),
            positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
            positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
            timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
            use_middle_indices_grid=config.get("use_middle_indices_grid", True),
            audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
            audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
            audio_in_channels=config.get("audio_in_channels", 128),
            audio_out_channels=config.get("audio_out_channels", 128),
            audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
            audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
            av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
            rope_type=LTXRopeType(config.get("rope_type", "interleaved")),
            double_precision_rope=config.get("frequencies_precision", False) == "float64",
            apply_gated_attention=config.get("apply_gated_attention", False),
            caption_projection=caption_projection,
            audio_caption_projection=audio_caption_projection,
            cross_attention_adaln=config.get("cross_attention_adaln", False),
        )


class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]):
    """
    Configurator for LTX video only model.
    Used to create an LTX video only model from a configuration dictionary.
    """

    @classmethod
    def from_config(cls: type[LTXModel], config: dict) -> LTXModel:
        # Build caption projection for 19B model (projection handled in transformer).
        caption_projection, _ = _build_caption_projections(config, is_av=False)

        config = config.get("transformer", {})

        check_config_value(config, "dropout", 0.0)
        check_config_value(config, "attention_bias", True)
        check_config_value(config, "num_vector_embeds", None)
        check_config_value(config, "activation_fn", "gelu-approximate")
        check_config_value(config, "num_embeds_ada_norm", 1000)
        check_config_value(config, "use_linear_projection", False)
        check_config_value(config, "only_cross_attention", False)
        check_config_value(config, "cross_attention_norm", True)
        check_config_value(config, "double_self_attention", False)
        check_config_value(config, "upcast_attention", False)
        check_config_value(config, "standardization_norm", "rms_norm")
        check_config_value(config, "norm_elementwise_affine", False)
        check_config_value(config, "qk_norm", "rms_norm")
        check_config_value(config, "positional_embedding_type", "rope")
        check_config_value(config, "use_middle_indices_grid", True)

        return LTXModel(
            model_type=LTXModelType.VideoOnly,
            num_attention_heads=config.get("num_attention_heads", 32),
            attention_head_dim=config.get("attention_head_dim", 128),
            in_channels=config.get("in_channels", 128),
            out_channels=config.get("out_channels", 128),
            num_layers=config.get("num_layers", 48),
            cross_attention_dim=config.get("cross_attention_dim", 4096),
            norm_eps=config.get("norm_eps", 1e-06),
            attention_type=AttentionFunction(config.get("attention_type", "default")),
            positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
            positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
            timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
            use_middle_indices_grid=config.get("use_middle_indices_grid", True),
            rope_type=LTXRopeType(config.get("rope_type", "interleaved")),
            double_precision_rope=config.get("frequencies_precision", False) == "float64",
            apply_gated_attention=config.get("apply_gated_attention", False),
            caption_projection=caption_projection,
            cross_attention_adaln=config.get("cross_attention_adaln", False),
        )


def _build_caption_projections(
    config: dict,
    is_av: bool,
) -> tuple[torch.nn.Module | None, torch.nn.Module | None]:
    """Build caption projections for the transformer when projection is NOT in the text encoder.
    19B models: projection is in the transformer (caption_proj_before_connector=False).
    22B models: projection is in the text encoder, so no projections are created here.
    Args:
        config: Full model config dict (must contain "transformer" key).
        is_av: Whether this is an audio-video model. When False, audio projection is skipped.
    Returns:
        Tuple of (video_caption_projection, audio_caption_projection), both None for 22B models.
    """
    transformer_config = config.get("transformer", {})
    if transformer_config.get("caption_proj_before_connector", False):
        return None, None

    with torch.device("meta"):
        caption_projection = create_caption_projection(transformer_config)
        audio_caption_projection = create_caption_projection(transformer_config, audio=True) if is_av else None
    return caption_projection, audio_caption_projection


LTXV_MODEL_COMFY_RENAMING_MAP = (
    SDOps("LTXV_MODEL_COMFY_PREFIX_MAP")
    .with_matching(prefix="model.diffusion_model.")
    .with_replacement("model.diffusion_model.", "")
)
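The configurators read a nested dict whose "transformer" section carries the keys accessed above. A minimal illustrative shape is sketched below; the values simply echo the defaults shown in from_config, and the sketch assumes check_config_value tolerates keys that are absent from the dict (a shipped config provides the full set).

from ltx_core.model.transformer.model_configurator import LTXVideoOnlyModelConfigurator

config = {
    "transformer": {
        "num_attention_heads": 32,
        "attention_head_dim": 128,
        "num_layers": 48,
        "cross_attention_dim": 4096,
        "rope_type": "interleaved",
        "caption_proj_before_connector": True,  # 22B-style: no caption projection is built here
    }
}

# Assumption: check_config_value only validates keys that are present.
model = LTXVideoOnlyModelConfigurator.from_config(config)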
packages/ltx-core/src/ltx_core/model/transformer/rope.py
ADDED
@@ -0,0 +1,204 @@
import functools
import math
from enum import Enum
from typing import Callable, Tuple

import numpy as np
import torch
from einops import rearrange


class LTXRopeType(Enum):
    INTERLEAVED = "interleaved"
    SPLIT = "split"


def apply_rotary_emb(
    input_tensor: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
) -> torch.Tensor:
    if rope_type == LTXRopeType.INTERLEAVED:
        return apply_interleaved_rotary_emb(input_tensor, *freqs_cis)
    elif rope_type == LTXRopeType.SPLIT:
        return apply_split_rotary_emb(input_tensor, *freqs_cis)
    else:
        raise ValueError(f"Invalid rope type: {rope_type}")


def apply_interleaved_rotary_emb(
    input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor
) -> torch.Tensor:
    t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
    t1, t2 = t_dup.unbind(dim=-1)
    t_dup = torch.stack((-t2, t1), dim=-1)
    input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

    return out


def apply_split_rotary_emb(
    input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor
) -> torch.Tensor:
    needs_reshape = False
    if input_tensor.ndim != 4 and cos_freqs.ndim == 4:
        b, h, t, _ = cos_freqs.shape
        input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2)
        needs_reshape = True

    split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
    first_half_input = split_input[..., :1, :]
    second_half_input = split_input[..., 1:, :]

    output = split_input * cos_freqs.unsqueeze(-2)
    first_half_output = output[..., :1, :]
    second_half_output = output[..., 1:, :]

    first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input)
    second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input)

    output = rearrange(output, "... d r -> ... (d r)")
    if needs_reshape:
        output = output.swapaxes(1, 2).reshape(b, t, -1)

    return output


@functools.lru_cache(maxsize=5)
def generate_freq_grid_np(
    positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int
) -> torch.Tensor:
    theta = positional_embedding_theta
    start = 1
    end = theta

    n_elem = 2 * positional_embedding_max_pos_count
    pow_indices = np.power(
        theta,
        np.linspace(
            np.log(start) / np.log(theta),
            np.log(end) / np.log(theta),
            inner_dim // n_elem,
            dtype=np.float64,
        ),
    )
    return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)


@functools.lru_cache(maxsize=5)
def generate_freq_grid_pytorch(
    positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int
) -> torch.Tensor:
    theta = positional_embedding_theta
    start = 1
    end = theta
    n_elem = 2 * positional_embedding_max_pos_count

    indices = theta ** (
        torch.linspace(
            math.log(start, theta),
            math.log(end, theta),
            inner_dim // n_elem,
            dtype=torch.float32,
        )
    )
    indices = indices.to(dtype=torch.float32)

    indices = indices * math.pi / 2

    return indices


def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor:
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), (
        f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
    )
    fractional_positions = torch.stack(
        [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
        dim=-1,
    )
    return fractional_positions


def generate_freqs(
    indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool
) -> torch.Tensor:
    if use_middle_indices_grid:
        assert len(indices_grid.shape) == 4
        assert indices_grid.shape[-1] == 2
        indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
        indices_grid = (indices_grid_start + indices_grid_end) / 2.0
    elif len(indices_grid.shape) == 4:
        indices_grid = indices_grid[..., 0]

    fractional_positions = get_fractional_positions(indices_grid, max_pos)
    indices = indices.to(device=fractional_positions.device)

    freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
    return freqs


def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]:
    cos_freq = freqs.cos()
    sin_freq = freqs.sin()

    if pad_size != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
        sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])

        cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
        sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)

    # Reshape freqs to be compatible with multi-head attention
    b = cos_freq.shape[0]
    t = cos_freq.shape[1]

    cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
    sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)

    cos_freq = torch.swapaxes(cos_freq, 1, 2)  # (B,H,T,D//2)
    sin_freq = torch.swapaxes(sin_freq, 1, 2)  # (B,H,T,D//2)
    return cos_freq, sin_freq


def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if pad_size != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
        sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq, sin_freq


def precompute_freqs_cis(
    indices_grid: torch.Tensor,
    dim: int,
    out_dtype: torch.dtype,
    theta: float = 10000.0,
    max_pos: list[int] | None = None,
    use_middle_indices_grid: bool = False,
    num_attention_heads: int = 32,
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch,
) -> tuple[torch.Tensor, torch.Tensor]:
    if max_pos is None:
        max_pos = [20, 2048, 2048]

    indices = freq_grid_generator(theta, indices_grid.shape[1], dim)
    freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)

    if rope_type == LTXRopeType.SPLIT:
        expected_freqs = dim // 2
        current_freqs = freqs.shape[-1]
        pad_size = expected_freqs - current_freqs
        cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
    else:
        # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
        n_elem = 2 * indices_grid.shape[1]
        cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
packages/ltx-core/src/ltx_core/model/transformer/text_projection.py
ADDED
@@ -0,0 +1,38 @@
import torch


class PixArtAlphaTextProjection(torch.nn.Module):
    """
    Projects caption embeddings using dual linear layers.
    Flow: linear_1 → activation → linear_2
    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        if act_fn == "gelu_tanh":
            self.act_1 = torch.nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = torch.nn.SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


def create_caption_projection(transformer_config: dict, audio: bool = False) -> PixArtAlphaTextProjection:
    """Create a caption projection for the transformer (V1/19B only)."""
    caption_channels = transformer_config["caption_channels"]
    if audio:
        inner_dim = transformer_config["audio_num_attention_heads"] * transformer_config["audio_attention_head_dim"]
    else:
        inner_dim = transformer_config["num_attention_heads"] * transformer_config["attention_head_dim"]
    return PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
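As a quick shape check of the projection; the caption_channels value here is an assumed text-encoder width, not a documented constant.

import torch

from ltx_core.model.transformer.text_projection import create_caption_projection

transformer_config = {
    "caption_channels": 3840,        # assumption: width of the upstream text encoder
    "num_attention_heads": 32,
    "attention_head_dim": 128,
}
proj = create_caption_projection(transformer_config)
tokens = torch.randn(2, 256, 3840)   # (batch, text tokens, caption_channels)
print(proj(tokens).shape)            # torch.Size([2, 256, 4096]) = heads * head_dim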
packages/ltx-core/src/ltx_core/model/transformer/timestep_embedding.py
ADDED
@@ -0,0 +1,143 @@
import math

import torch


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions.
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings.
    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class TimestepEmbedding(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        out_dim: int | None = None,
        post_act_fn: str | None = None,
        cond_proj_dim: int | None = None,
        sample_proj_bias: bool = True,
    ):
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = torch.nn.SiLU()
        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim

        self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None

    def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor:
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(torch.nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb


class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module):
    """
    For PixArt-Alpha.
    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(
        self,
        embedding_dim: int,
        size_emb_dim: int,
    ):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(
        self,
        timestep: torch.Tensor,
        hidden_dtype: torch.dtype,
    ) -> torch.Tensor:
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
        return timesteps_emb
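A short sketch of the embedding layout; the dimensions and timestep values below are arbitrary illustrations.

import torch

from ltx_core.model.transformer.timestep_embedding import (
    PixArtAlphaCombinedTimestepSizeEmbeddings,
    get_timestep_embedding,
)

# Raw sinusoidal embedding: cos half followed by sin half when flip_sin_to_cos=True.
emb = get_timestep_embedding(
    torch.tensor([0.0, 250.0, 999.0]), embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
)
print(emb.shape)  # torch.Size([3, 256])

# Sinusoidal projection followed by the two-layer MLP, as used for AdaLN conditioning.
combined = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=1024, size_emb_dim=0)
print(combined(torch.tensor([500.0]), hidden_dtype=torch.float32).shape)  # torch.Size([1, 1024])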
packages/ltx-core/src/ltx_core/model/transformer/transformer.py
ADDED
@@ -0,0 +1,398 @@
from dataclasses import dataclass, replace

import torch

from ltx_core.guidance.perturbations import BatchedPerturbationConfig, PerturbationType
from ltx_core.model.transformer.adaln import adaln_embedding_coefficient
from ltx_core.model.transformer.attention import Attention, AttentionCallable, AttentionFunction
from ltx_core.model.transformer.feed_forward import FeedForward
from ltx_core.model.transformer.rope import LTXRopeType
from ltx_core.model.transformer.transformer_args import TransformerArgs
from ltx_core.utils import rms_norm


@dataclass
class TransformerConfig:
    dim: int
    heads: int
    d_head: int
    context_dim: int
    apply_gated_attention: bool = False
    cross_attention_adaln: bool = False


class BasicAVTransformerBlock(torch.nn.Module):
    def __init__(
        self,
        idx: int,
        video: TransformerConfig | None = None,
        audio: TransformerConfig | None = None,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        norm_eps: float = 1e-6,
        attention_function: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT,
    ):
        super().__init__()

        self.idx = idx
        if video is not None:
            self.attn1 = Attention(
                query_dim=video.dim,
                heads=video.heads,
                dim_head=video.d_head,
                context_dim=None,
                rope_type=rope_type,
                norm_eps=norm_eps,
                attention_function=attention_function,
                apply_gated_attention=video.apply_gated_attention,
            )
            self.attn2 = Attention(
                query_dim=video.dim,
                context_dim=video.context_dim,
                heads=video.heads,
                dim_head=video.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                attention_function=attention_function,
                apply_gated_attention=video.apply_gated_attention,
            )
            self.ff = FeedForward(video.dim, dim_out=video.dim)
            video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln)
            self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim))

        if audio is not None:
            self.audio_attn1 = Attention(
                query_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                context_dim=None,
                rope_type=rope_type,
                norm_eps=norm_eps,
                attention_function=attention_function,
                apply_gated_attention=audio.apply_gated_attention,
            )
            self.audio_attn2 = Attention(
                query_dim=audio.dim,
                context_dim=audio.context_dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                attention_function=attention_function,
                apply_gated_attention=audio.apply_gated_attention,
            )
            self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
            audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln)
            self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim))

        if audio is not None and video is not None:
            # Q: Video, K,V: Audio
            self.audio_to_video_attn = Attention(
                query_dim=video.dim,
                context_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                attention_function=attention_function,
                apply_gated_attention=video.apply_gated_attention,
            )

            # Q: Audio, K,V: Video
            self.video_to_audio_attn = Attention(
                query_dim=audio.dim,
                context_dim=video.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                attention_function=attention_function,
                apply_gated_attention=audio.apply_gated_attention,
            )

            self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim))
            self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim))

        self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or (
            audio is not None and audio.cross_attention_adaln
        )

        if self.cross_attention_adaln and video is not None:
            self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim))
        if self.cross_attention_adaln and audio is not None:
            self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim))

        self.norm_eps = norm_eps

    def get_ada_values(
        self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice
    ) -> tuple[torch.Tensor, ...]:
        num_ada_params = scale_shift_table.shape[0]

        ada_values = (
            scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
            + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
        ).unbind(dim=2)
        return ada_values

    def get_av_ca_ada_values(
        self,
        scale_shift_table: torch.Tensor,
        batch_size: int,
        scale_shift_timestep: torch.Tensor,
        gate_timestep: torch.Tensor,
        scale_shift_indices: slice,
        num_scale_shift_values: int = 4,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scale_shift_ada_values = self.get_ada_values(
            scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices
        )
        gate_ada_values = self.get_ada_values(
            scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None)
        )

        scale, shift = (t.squeeze(2) for t in scale_shift_ada_values)
        (gate,) = (t.squeeze(2) for t in gate_ada_values)

        return scale, shift, gate

    def _apply_text_cross_attention(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        attn: AttentionCallable,
        scale_shift_table: torch.Tensor,
        prompt_scale_shift_table: torch.Tensor | None,
        timestep: torch.Tensor,
        prompt_timestep: torch.Tensor | None,
        context_mask: torch.Tensor | None,
|
| 168 |
+
cross_attention_adaln: bool = False,
|
| 169 |
+
) -> torch.Tensor:
|
| 170 |
+
"""Apply text cross-attention, with optional AdaLN modulation."""
|
| 171 |
+
if cross_attention_adaln:
|
| 172 |
+
shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9))
|
| 173 |
+
return apply_cross_attention_adaln(
|
| 174 |
+
x,
|
| 175 |
+
context,
|
| 176 |
+
attn,
|
| 177 |
+
shift_q,
|
| 178 |
+
scale_q,
|
| 179 |
+
gate,
|
| 180 |
+
prompt_scale_shift_table,
|
| 181 |
+
prompt_timestep,
|
| 182 |
+
context_mask,
|
| 183 |
+
self.norm_eps,
|
| 184 |
+
)
|
| 185 |
+
return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask)
|
| 186 |
+
|
| 187 |
+
def forward( # noqa: PLR0915
|
| 188 |
+
self,
|
| 189 |
+
video: TransformerArgs | None,
|
| 190 |
+
audio: TransformerArgs | None,
|
| 191 |
+
perturbations: BatchedPerturbationConfig | None = None,
|
| 192 |
+
) -> tuple[TransformerArgs | None, TransformerArgs | None]:
|
| 193 |
+
if video is None and audio is None:
|
| 194 |
+
raise ValueError("At least one of video or audio must be provided")
|
| 195 |
+
|
| 196 |
+
batch_size = (video or audio).x.shape[0]
|
| 197 |
+
|
| 198 |
+
if perturbations is None:
|
| 199 |
+
perturbations = BatchedPerturbationConfig.empty(batch_size)
|
| 200 |
+
|
| 201 |
+
vx = video.x if video is not None else None
|
| 202 |
+
ax = audio.x if audio is not None else None
|
| 203 |
+
|
| 204 |
+
run_vx = video is not None and video.enabled and vx.numel() > 0
|
| 205 |
+
run_ax = audio is not None and audio.enabled and ax.numel() > 0
|
| 206 |
+
|
| 207 |
+
run_a2v = run_vx and (audio is not None and ax.numel() > 0)
|
| 208 |
+
run_v2a = run_ax and (video is not None and vx.numel() > 0)
|
| 209 |
+
|
| 210 |
+
if run_vx:
|
| 211 |
+
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
|
| 212 |
+
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
|
| 213 |
+
)
|
| 214 |
+
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
| 215 |
+
del vshift_msa, vscale_msa
|
| 216 |
+
|
| 217 |
+
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
|
| 218 |
+
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx)
|
| 219 |
+
v_mask = (
|
| 220 |
+
perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx)
|
| 221 |
+
if not all_perturbed and not none_perturbed
|
| 222 |
+
else None
|
| 223 |
+
)
|
| 224 |
+
vx = (
|
| 225 |
+
vx
|
| 226 |
+
+ self.attn1(
|
| 227 |
+
norm_vx,
|
| 228 |
+
pe=video.positional_embeddings,
|
| 229 |
+
mask=video.self_attention_mask,
|
| 230 |
+
perturbation_mask=v_mask,
|
| 231 |
+
all_perturbed=all_perturbed,
|
| 232 |
+
)
|
| 233 |
+
* vgate_msa
|
| 234 |
+
)
|
| 235 |
+
del vgate_msa, norm_vx, v_mask
|
| 236 |
+
vx = vx + self._apply_text_cross_attention(
|
| 237 |
+
vx,
|
| 238 |
+
video.context,
|
| 239 |
+
self.attn2,
|
| 240 |
+
self.scale_shift_table,
|
| 241 |
+
getattr(self, "prompt_scale_shift_table", None),
|
| 242 |
+
video.timesteps,
|
| 243 |
+
video.prompt_timestep,
|
| 244 |
+
video.context_mask,
|
| 245 |
+
cross_attention_adaln=self.cross_attention_adaln,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if run_ax:
|
| 249 |
+
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
|
| 250 |
+
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
| 254 |
+
del ashift_msa, ascale_msa
|
| 255 |
+
all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
|
| 256 |
+
none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx)
|
| 257 |
+
a_mask = (
|
| 258 |
+
perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax)
|
| 259 |
+
if not all_perturbed and not none_perturbed
|
| 260 |
+
else None
|
| 261 |
+
)
|
| 262 |
+
ax = (
|
| 263 |
+
ax
|
| 264 |
+
+ self.audio_attn1(
|
| 265 |
+
norm_ax,
|
| 266 |
+
pe=audio.positional_embeddings,
|
| 267 |
+
mask=audio.self_attention_mask,
|
| 268 |
+
perturbation_mask=a_mask,
|
| 269 |
+
all_perturbed=all_perturbed,
|
| 270 |
+
)
|
| 271 |
+
* agate_msa
|
| 272 |
+
)
|
| 273 |
+
del agate_msa, norm_ax, a_mask
|
| 274 |
+
ax = ax + self._apply_text_cross_attention(
|
| 275 |
+
ax,
|
| 276 |
+
audio.context,
|
| 277 |
+
self.audio_attn2,
|
| 278 |
+
self.audio_scale_shift_table,
|
| 279 |
+
getattr(self, "audio_prompt_scale_shift_table", None),
|
| 280 |
+
audio.timesteps,
|
| 281 |
+
audio.prompt_timestep,
|
| 282 |
+
audio.context_mask,
|
| 283 |
+
cross_attention_adaln=self.cross_attention_adaln,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Audio - Video cross attention.
|
| 287 |
+
if run_a2v or run_v2a:
|
| 288 |
+
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
|
| 289 |
+
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
|
| 290 |
+
|
| 291 |
+
if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx):
|
| 292 |
+
scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values(
|
| 293 |
+
self.scale_shift_table_a2v_ca_video,
|
| 294 |
+
vx.shape[0],
|
| 295 |
+
video.cross_scale_shift_timestep,
|
| 296 |
+
video.cross_gate_timestep,
|
| 297 |
+
slice(0, 2),
|
| 298 |
+
)
|
| 299 |
+
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
|
| 300 |
+
del scale_ca_video_a2v, shift_ca_video_a2v
|
| 301 |
+
|
| 302 |
+
scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values(
|
| 303 |
+
self.scale_shift_table_a2v_ca_audio,
|
| 304 |
+
ax.shape[0],
|
| 305 |
+
audio.cross_scale_shift_timestep,
|
| 306 |
+
audio.cross_gate_timestep,
|
| 307 |
+
slice(0, 2),
|
| 308 |
+
)
|
| 309 |
+
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
|
| 310 |
+
del scale_ca_audio_a2v, shift_ca_audio_a2v
|
| 311 |
+
a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx)
|
| 312 |
+
vx = vx + (
|
| 313 |
+
self.audio_to_video_attn(
|
| 314 |
+
vx_scaled,
|
| 315 |
+
context=ax_scaled,
|
| 316 |
+
pe=video.cross_positional_embeddings,
|
| 317 |
+
k_pe=audio.cross_positional_embeddings,
|
| 318 |
+
)
|
| 319 |
+
* gate_out_a2v
|
| 320 |
+
* a2v_mask
|
| 321 |
+
)
|
| 322 |
+
del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled
|
| 323 |
+
|
| 324 |
+
if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx):
|
| 325 |
+
scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values(
|
| 326 |
+
self.scale_shift_table_a2v_ca_audio,
|
| 327 |
+
ax.shape[0],
|
| 328 |
+
audio.cross_scale_shift_timestep,
|
| 329 |
+
audio.cross_gate_timestep,
|
| 330 |
+
slice(2, 4),
|
| 331 |
+
)
|
| 332 |
+
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
|
| 333 |
+
del scale_ca_audio_v2a, shift_ca_audio_v2a
|
| 334 |
+
scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values(
|
| 335 |
+
self.scale_shift_table_a2v_ca_video,
|
| 336 |
+
vx.shape[0],
|
| 337 |
+
video.cross_scale_shift_timestep,
|
| 338 |
+
video.cross_gate_timestep,
|
| 339 |
+
slice(2, 4),
|
| 340 |
+
)
|
| 341 |
+
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
|
| 342 |
+
del scale_ca_video_v2a, shift_ca_video_v2a
|
| 343 |
+
v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax)
|
| 344 |
+
ax = ax + (
|
| 345 |
+
self.video_to_audio_attn(
|
| 346 |
+
ax_scaled,
|
| 347 |
+
context=vx_scaled,
|
| 348 |
+
pe=audio.cross_positional_embeddings,
|
| 349 |
+
k_pe=video.cross_positional_embeddings,
|
| 350 |
+
)
|
| 351 |
+
* gate_out_v2a
|
| 352 |
+
* v2a_mask
|
| 353 |
+
)
|
| 354 |
+
del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled
|
| 355 |
+
|
| 356 |
+
del vx_norm3, ax_norm3
|
| 357 |
+
|
| 358 |
+
if run_vx:
|
| 359 |
+
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
| 360 |
+
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
|
| 361 |
+
)
|
| 362 |
+
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
| 363 |
+
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
| 364 |
+
|
| 365 |
+
del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled
|
| 366 |
+
|
| 367 |
+
if run_ax:
|
| 368 |
+
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
| 369 |
+
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
|
| 370 |
+
)
|
| 371 |
+
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
| 372 |
+
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
| 373 |
+
|
| 374 |
+
del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled
|
| 375 |
+
|
| 376 |
+
return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def apply_cross_attention_adaln(
|
| 380 |
+
x: torch.Tensor,
|
| 381 |
+
context: torch.Tensor,
|
| 382 |
+
attn: AttentionCallable,
|
| 383 |
+
q_shift: torch.Tensor,
|
| 384 |
+
q_scale: torch.Tensor,
|
| 385 |
+
q_gate: torch.Tensor,
|
| 386 |
+
prompt_scale_shift_table: torch.Tensor,
|
| 387 |
+
prompt_timestep: torch.Tensor,
|
| 388 |
+
context_mask: torch.Tensor | None = None,
|
| 389 |
+
norm_eps: float = 1e-6,
|
| 390 |
+
) -> torch.Tensor:
|
| 391 |
+
batch_size = x.shape[0]
|
| 392 |
+
shift_kv, scale_kv = (
|
| 393 |
+
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
|
| 394 |
+
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
|
| 395 |
+
).unbind(dim=2)
|
| 396 |
+
attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift
|
| 397 |
+
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
|
| 398 |
+
return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate
|
packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, replace
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.transformer.adaln import AdaLayerNormSingle
|
| 6 |
+
from ltx_core.model.transformer.modality import Modality
|
| 7 |
+
from ltx_core.model.transformer.rope import (
|
| 8 |
+
LTXRopeType,
|
| 9 |
+
generate_freq_grid_np,
|
| 10 |
+
generate_freq_grid_pytorch,
|
| 11 |
+
precompute_freqs_cis,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass(frozen=True)
|
| 16 |
+
class TransformerArgs:
|
| 17 |
+
x: torch.Tensor
|
| 18 |
+
context: torch.Tensor
|
| 19 |
+
context_mask: torch.Tensor
|
| 20 |
+
timesteps: torch.Tensor
|
| 21 |
+
embedded_timestep: torch.Tensor
|
| 22 |
+
positional_embeddings: torch.Tensor
|
| 23 |
+
cross_positional_embeddings: torch.Tensor | None
|
| 24 |
+
cross_scale_shift_timestep: torch.Tensor | None
|
| 25 |
+
cross_gate_timestep: torch.Tensor | None
|
| 26 |
+
enabled: bool
|
| 27 |
+
prompt_timestep: torch.Tensor | None = None
|
| 28 |
+
self_attention_mask: torch.Tensor | None = (
|
| 29 |
+
None # Additive log-space self-attention bias (B, 1, T, T), None = full attention
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class TransformerArgsPreprocessor:
|
| 34 |
+
def __init__( # noqa: PLR0913
|
| 35 |
+
self,
|
| 36 |
+
patchify_proj: torch.nn.Linear,
|
| 37 |
+
adaln: AdaLayerNormSingle,
|
| 38 |
+
inner_dim: int,
|
| 39 |
+
max_pos: list[int],
|
| 40 |
+
num_attention_heads: int,
|
| 41 |
+
use_middle_indices_grid: bool,
|
| 42 |
+
timestep_scale_multiplier: int,
|
| 43 |
+
double_precision_rope: bool,
|
| 44 |
+
positional_embedding_theta: float,
|
| 45 |
+
rope_type: LTXRopeType,
|
| 46 |
+
caption_projection: torch.nn.Module | None = None,
|
| 47 |
+
prompt_adaln: AdaLayerNormSingle | None = None,
|
| 48 |
+
) -> None:
|
| 49 |
+
self.patchify_proj = patchify_proj
|
| 50 |
+
self.adaln = adaln
|
| 51 |
+
self.inner_dim = inner_dim
|
| 52 |
+
self.max_pos = max_pos
|
| 53 |
+
self.num_attention_heads = num_attention_heads
|
| 54 |
+
self.use_middle_indices_grid = use_middle_indices_grid
|
| 55 |
+
self.timestep_scale_multiplier = timestep_scale_multiplier
|
| 56 |
+
self.double_precision_rope = double_precision_rope
|
| 57 |
+
self.positional_embedding_theta = positional_embedding_theta
|
| 58 |
+
self.rope_type = rope_type
|
| 59 |
+
self.caption_projection = caption_projection
|
| 60 |
+
self.prompt_adaln = prompt_adaln
|
| 61 |
+
|
| 62 |
+
def _prepare_timestep(
|
| 63 |
+
self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype
|
| 64 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 65 |
+
"""Prepare timestep embeddings."""
|
| 66 |
+
timestep_scaled = timestep * self.timestep_scale_multiplier
|
| 67 |
+
timestep, embedded_timestep = adaln(
|
| 68 |
+
timestep_scaled.flatten(),
|
| 69 |
+
hidden_dtype=hidden_dtype,
|
| 70 |
+
)
|
| 71 |
+
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
| 72 |
+
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
| 73 |
+
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
| 74 |
+
|
| 75 |
+
return timestep, embedded_timestep
|
| 76 |
+
|
| 77 |
+
def _prepare_context(
|
| 78 |
+
self,
|
| 79 |
+
context: torch.Tensor,
|
| 80 |
+
x: torch.Tensor,
|
| 81 |
+
) -> torch.Tensor:
|
| 82 |
+
"""Prepare context for transformer blocks."""
|
| 83 |
+
if self.caption_projection is not None:
|
| 84 |
+
context = self.caption_projection(context)
|
| 85 |
+
batch_size = x.shape[0]
|
| 86 |
+
return context.view(batch_size, -1, x.shape[-1])
|
| 87 |
+
|
| 88 |
+
def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None:
|
| 89 |
+
"""Prepare attention mask."""
|
| 90 |
+
if attention_mask is None or torch.is_floating_point(attention_mask):
|
| 91 |
+
return attention_mask
|
| 92 |
+
|
| 93 |
+
return (attention_mask - 1).to(x_dtype).reshape(
|
| 94 |
+
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
| 95 |
+
) * torch.finfo(x_dtype).max
|
| 96 |
+
|
| 97 |
+
def _prepare_self_attention_mask(
|
| 98 |
+
self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype
|
| 99 |
+
) -> torch.Tensor | None:
|
| 100 |
+
"""Prepare self-attention mask by converting [0,1] values to additive log-space bias.
|
| 101 |
+
Input shape: (B, T, T) with values in [0, 1].
|
| 102 |
+
Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value
|
| 103 |
+
for masked positions.
|
| 104 |
+
Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum
|
| 105 |
+
representable value). Strictly positive entries are converted via log-space for
|
| 106 |
+
smooth attenuation, with small values clamped for numerical stability.
|
| 107 |
+
Returns None if input is None (no masking).
|
| 108 |
+
"""
|
| 109 |
+
if attention_mask is None:
|
| 110 |
+
return None
|
| 111 |
+
|
| 112 |
+
# Convert [0, 1] attention mask to additive log-space bias:
|
| 113 |
+
# 1.0 -> log(1.0) = 0.0 (no bias, full attention)
|
| 114 |
+
# 0.0 -> finfo.min (fully masked)
|
| 115 |
+
finfo = torch.finfo(x_dtype)
|
| 116 |
+
eps = finfo.tiny
|
| 117 |
+
|
| 118 |
+
bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype)
|
| 119 |
+
positive = attention_mask > 0
|
| 120 |
+
if positive.any():
|
| 121 |
+
bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype)
|
| 122 |
+
|
| 123 |
+
return bias.unsqueeze(1) # (B, 1, T, T) for head broadcast
|
| 124 |
+
|
| 125 |
+
def _prepare_positional_embeddings(
|
| 126 |
+
self,
|
| 127 |
+
positions: torch.Tensor,
|
| 128 |
+
inner_dim: int,
|
| 129 |
+
max_pos: list[int],
|
| 130 |
+
use_middle_indices_grid: bool,
|
| 131 |
+
num_attention_heads: int,
|
| 132 |
+
x_dtype: torch.dtype,
|
| 133 |
+
) -> torch.Tensor:
|
| 134 |
+
"""Prepare positional embeddings."""
|
| 135 |
+
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
|
| 136 |
+
pe = precompute_freqs_cis(
|
| 137 |
+
positions,
|
| 138 |
+
dim=inner_dim,
|
| 139 |
+
out_dtype=x_dtype,
|
| 140 |
+
theta=self.positional_embedding_theta,
|
| 141 |
+
max_pos=max_pos,
|
| 142 |
+
use_middle_indices_grid=use_middle_indices_grid,
|
| 143 |
+
num_attention_heads=num_attention_heads,
|
| 144 |
+
rope_type=self.rope_type,
|
| 145 |
+
freq_grid_generator=freq_grid_generator,
|
| 146 |
+
)
|
| 147 |
+
return pe
|
| 148 |
+
|
| 149 |
+
def prepare(
|
| 150 |
+
self,
|
| 151 |
+
modality: Modality,
|
| 152 |
+
cross_modality: Modality | None = None, # noqa: ARG002
|
| 153 |
+
) -> TransformerArgs:
|
| 154 |
+
x = self.patchify_proj(modality.latent)
|
| 155 |
+
batch_size = x.shape[0]
|
| 156 |
+
timestep, embedded_timestep = self._prepare_timestep(
|
| 157 |
+
modality.timesteps, self.adaln, batch_size, modality.latent.dtype
|
| 158 |
+
)
|
| 159 |
+
prompt_timestep = None
|
| 160 |
+
if self.prompt_adaln is not None:
|
| 161 |
+
prompt_timestep, _ = self._prepare_timestep(
|
| 162 |
+
modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype
|
| 163 |
+
)
|
| 164 |
+
context = self._prepare_context(modality.context, x)
|
| 165 |
+
attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype)
|
| 166 |
+
pe = self._prepare_positional_embeddings(
|
| 167 |
+
positions=modality.positions,
|
| 168 |
+
inner_dim=self.inner_dim,
|
| 169 |
+
max_pos=self.max_pos,
|
| 170 |
+
use_middle_indices_grid=self.use_middle_indices_grid,
|
| 171 |
+
num_attention_heads=self.num_attention_heads,
|
| 172 |
+
x_dtype=modality.latent.dtype,
|
| 173 |
+
)
|
| 174 |
+
self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype)
|
| 175 |
+
return TransformerArgs(
|
| 176 |
+
x=x,
|
| 177 |
+
context=context,
|
| 178 |
+
context_mask=attention_mask,
|
| 179 |
+
timesteps=timestep,
|
| 180 |
+
embedded_timestep=embedded_timestep,
|
| 181 |
+
positional_embeddings=pe,
|
| 182 |
+
cross_positional_embeddings=None,
|
| 183 |
+
cross_scale_shift_timestep=None,
|
| 184 |
+
cross_gate_timestep=None,
|
| 185 |
+
enabled=modality.enabled,
|
| 186 |
+
prompt_timestep=prompt_timestep,
|
| 187 |
+
self_attention_mask=self_attention_mask,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class MultiModalTransformerArgsPreprocessor:
|
| 192 |
+
def __init__( # noqa: PLR0913
|
| 193 |
+
self,
|
| 194 |
+
patchify_proj: torch.nn.Linear,
|
| 195 |
+
adaln: AdaLayerNormSingle,
|
| 196 |
+
cross_scale_shift_adaln: AdaLayerNormSingle,
|
| 197 |
+
cross_gate_adaln: AdaLayerNormSingle,
|
| 198 |
+
inner_dim: int,
|
| 199 |
+
max_pos: list[int],
|
| 200 |
+
num_attention_heads: int,
|
| 201 |
+
cross_pe_max_pos: int,
|
| 202 |
+
use_middle_indices_grid: bool,
|
| 203 |
+
audio_cross_attention_dim: int,
|
| 204 |
+
timestep_scale_multiplier: int,
|
| 205 |
+
double_precision_rope: bool,
|
| 206 |
+
positional_embedding_theta: float,
|
| 207 |
+
rope_type: LTXRopeType,
|
| 208 |
+
av_ca_timestep_scale_multiplier: int,
|
| 209 |
+
caption_projection: torch.nn.Module | None = None,
|
| 210 |
+
prompt_adaln: AdaLayerNormSingle | None = None,
|
| 211 |
+
) -> None:
|
| 212 |
+
self.simple_preprocessor = TransformerArgsPreprocessor(
|
| 213 |
+
patchify_proj=patchify_proj,
|
| 214 |
+
adaln=adaln,
|
| 215 |
+
inner_dim=inner_dim,
|
| 216 |
+
max_pos=max_pos,
|
| 217 |
+
num_attention_heads=num_attention_heads,
|
| 218 |
+
use_middle_indices_grid=use_middle_indices_grid,
|
| 219 |
+
timestep_scale_multiplier=timestep_scale_multiplier,
|
| 220 |
+
double_precision_rope=double_precision_rope,
|
| 221 |
+
positional_embedding_theta=positional_embedding_theta,
|
| 222 |
+
rope_type=rope_type,
|
| 223 |
+
caption_projection=caption_projection,
|
| 224 |
+
prompt_adaln=prompt_adaln,
|
| 225 |
+
)
|
| 226 |
+
self.cross_scale_shift_adaln = cross_scale_shift_adaln
|
| 227 |
+
self.cross_gate_adaln = cross_gate_adaln
|
| 228 |
+
self.cross_pe_max_pos = cross_pe_max_pos
|
| 229 |
+
self.audio_cross_attention_dim = audio_cross_attention_dim
|
| 230 |
+
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
| 231 |
+
|
| 232 |
+
def prepare(
|
| 233 |
+
self,
|
| 234 |
+
modality: Modality,
|
| 235 |
+
cross_modality: Modality | None = None,
|
| 236 |
+
) -> TransformerArgs:
|
| 237 |
+
transformer_args = self.simple_preprocessor.prepare(modality)
|
| 238 |
+
if cross_modality is None:
|
| 239 |
+
return transformer_args
|
| 240 |
+
|
| 241 |
+
if cross_modality.sigma.numel() > 1:
|
| 242 |
+
if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]:
|
| 243 |
+
raise ValueError("Cross modality sigma must have the same batch size as the modality")
|
| 244 |
+
if cross_modality.sigma.ndim != 1:
|
| 245 |
+
raise ValueError("Cross modality sigma must be a 1D tensor")
|
| 246 |
+
|
| 247 |
+
cross_timestep = cross_modality.sigma.view(
|
| 248 |
+
modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:])
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
|
| 252 |
+
positions=modality.positions[:, 0:1, :],
|
| 253 |
+
inner_dim=self.audio_cross_attention_dim,
|
| 254 |
+
max_pos=[self.cross_pe_max_pos],
|
| 255 |
+
use_middle_indices_grid=True,
|
| 256 |
+
num_attention_heads=self.simple_preprocessor.num_attention_heads,
|
| 257 |
+
x_dtype=modality.latent.dtype,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
|
| 261 |
+
timestep=cross_timestep,
|
| 262 |
+
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
| 263 |
+
batch_size=transformer_args.x.shape[0],
|
| 264 |
+
hidden_dtype=modality.latent.dtype,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
return replace(
|
| 268 |
+
transformer_args,
|
| 269 |
+
cross_positional_embeddings=cross_pe,
|
| 270 |
+
cross_scale_shift_timestep=cross_scale_shift_timestep,
|
| 271 |
+
cross_gate_timestep=cross_gate_timestep,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
def _prepare_cross_attention_timestep(
|
| 275 |
+
self,
|
| 276 |
+
timestep: torch.Tensor | None,
|
| 277 |
+
timestep_scale_multiplier: int,
|
| 278 |
+
batch_size: int,
|
| 279 |
+
hidden_dtype: torch.dtype,
|
| 280 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 281 |
+
"""Prepare cross attention timestep embeddings."""
|
| 282 |
+
timestep = timestep * timestep_scale_multiplier
|
| 283 |
+
|
| 284 |
+
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
| 285 |
+
|
| 286 |
+
scale_shift_timestep, _ = self.cross_scale_shift_adaln(
|
| 287 |
+
timestep.flatten(),
|
| 288 |
+
hidden_dtype=hidden_dtype,
|
| 289 |
+
)
|
| 290 |
+
scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1])
|
| 291 |
+
gate_noise_timestep, _ = self.cross_gate_adaln(
|
| 292 |
+
timestep.flatten() * av_ca_factor,
|
| 293 |
+
hidden_dtype=hidden_dtype,
|
| 294 |
+
)
|
| 295 |
+
gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1])
|
| 296 |
+
|
| 297 |
+
return scale_shift_timestep, gate_noise_timestep
|
packages/ltx-core/src/ltx_core/model/video_vae/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Video VAE package."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.video_vae.model_configurator import (
|
| 4 |
+
VAE_DECODER_COMFY_KEYS_FILTER,
|
| 5 |
+
VAE_ENCODER_COMFY_KEYS_FILTER,
|
| 6 |
+
VideoDecoderConfigurator,
|
| 7 |
+
VideoEncoderConfigurator,
|
| 8 |
+
)
|
| 9 |
+
from ltx_core.model.video_vae.tiling import SpatialTilingConfig, TemporalTilingConfig, TilingConfig
|
| 10 |
+
from ltx_core.model.video_vae.video_vae import VideoDecoder, VideoEncoder, decode_video, get_video_chunks_number
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"VAE_DECODER_COMFY_KEYS_FILTER",
|
| 14 |
+
"VAE_ENCODER_COMFY_KEYS_FILTER",
|
| 15 |
+
"SpatialTilingConfig",
|
| 16 |
+
"TemporalTilingConfig",
|
| 17 |
+
"TilingConfig",
|
| 18 |
+
"VideoDecoder",
|
| 19 |
+
"VideoDecoderConfigurator",
|
| 20 |
+
"VideoEncoder",
|
| 21 |
+
"VideoEncoderConfigurator",
|
| 22 |
+
"decode_video",
|
| 23 |
+
"get_video_chunks_number",
|
| 24 |
+
]
|
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (811 Bytes). View file
|
|
|
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/convolution.cpython-312.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/enums.cpython-312.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/model_configurator.cpython-312.pyc
ADDED
|
Binary file (4.24 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/ops.cpython-312.pyc
ADDED
|
Binary file (5.01 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/resnet.cpython-312.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/sampling.cpython-312.pyc
ADDED
|
Binary file (4.96 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/tiling.cpython-312.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/video_vae/__pycache__/video_vae.cpython-312.pyc
ADDED
|
Binary file (44.5 kB). View file
|
|
|
packages/ltx-core/src/ltx_core/model/video_vae/convolution.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
from ltx_core.model.video_vae.enums import PaddingModeType
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def make_conv_nd( # noqa: PLR0913
|
| 12 |
+
dims: Union[int, Tuple[int, int]],
|
| 13 |
+
in_channels: int,
|
| 14 |
+
out_channels: int,
|
| 15 |
+
kernel_size: int,
|
| 16 |
+
stride: int = 1,
|
| 17 |
+
padding: int = 0,
|
| 18 |
+
dilation: int = 1,
|
| 19 |
+
groups: int = 1,
|
| 20 |
+
bias: bool = True,
|
| 21 |
+
causal: bool = False,
|
| 22 |
+
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
| 23 |
+
temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
| 24 |
+
) -> nn.Module:
|
| 25 |
+
if not (spatial_padding_mode == temporal_padding_mode or causal):
|
| 26 |
+
raise NotImplementedError("spatial and temporal padding modes must be equal")
|
| 27 |
+
if dims == 2:
|
| 28 |
+
return nn.Conv2d(
|
| 29 |
+
in_channels=in_channels,
|
| 30 |
+
out_channels=out_channels,
|
| 31 |
+
kernel_size=kernel_size,
|
| 32 |
+
stride=stride,
|
| 33 |
+
padding=padding,
|
| 34 |
+
dilation=dilation,
|
| 35 |
+
groups=groups,
|
| 36 |
+
bias=bias,
|
| 37 |
+
padding_mode=spatial_padding_mode.value,
|
| 38 |
+
)
|
| 39 |
+
elif dims == 3:
|
| 40 |
+
if causal:
|
| 41 |
+
return CausalConv3d(
|
| 42 |
+
in_channels=in_channels,
|
| 43 |
+
out_channels=out_channels,
|
| 44 |
+
kernel_size=kernel_size,
|
| 45 |
+
stride=stride,
|
| 46 |
+
dilation=dilation,
|
| 47 |
+
groups=groups,
|
| 48 |
+
bias=bias,
|
| 49 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 50 |
+
)
|
| 51 |
+
return nn.Conv3d(
|
| 52 |
+
in_channels=in_channels,
|
| 53 |
+
out_channels=out_channels,
|
| 54 |
+
kernel_size=kernel_size,
|
| 55 |
+
stride=stride,
|
| 56 |
+
padding=padding,
|
| 57 |
+
dilation=dilation,
|
| 58 |
+
groups=groups,
|
| 59 |
+
bias=bias,
|
| 60 |
+
padding_mode=spatial_padding_mode.value,
|
| 61 |
+
)
|
| 62 |
+
elif dims == (2, 1):
|
| 63 |
+
return DualConv3d(
|
| 64 |
+
in_channels=in_channels,
|
| 65 |
+
out_channels=out_channels,
|
| 66 |
+
kernel_size=kernel_size,
|
| 67 |
+
stride=stride,
|
| 68 |
+
padding=padding,
|
| 69 |
+
bias=bias,
|
| 70 |
+
padding_mode=spatial_padding_mode.value,
|
| 71 |
+
)
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def make_linear_nd(
|
| 77 |
+
dims: int,
|
| 78 |
+
in_channels: int,
|
| 79 |
+
out_channels: int,
|
| 80 |
+
bias: bool = True,
|
| 81 |
+
) -> nn.Module:
|
| 82 |
+
if dims == 2:
|
| 83 |
+
return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
|
| 84 |
+
elif dims in (3, (2, 1)):
|
| 85 |
+
return nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
|
| 86 |
+
else:
|
| 87 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class DualConv3d(nn.Module):
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
in_channels: int,
|
| 94 |
+
out_channels: int,
|
| 95 |
+
kernel_size: int,
|
| 96 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
| 97 |
+
padding: Union[int, Tuple[int, int, int]] = 0,
|
| 98 |
+
dilation: Union[int, Tuple[int, int, int]] = 1,
|
| 99 |
+
groups: int = 1,
|
| 100 |
+
bias: bool = True,
|
| 101 |
+
padding_mode: str = "zeros",
|
| 102 |
+
) -> None:
|
| 103 |
+
super(DualConv3d, self).__init__()
|
| 104 |
+
|
| 105 |
+
self.in_channels = in_channels
|
| 106 |
+
self.out_channels = out_channels
|
| 107 |
+
self.padding_mode = padding_mode
|
| 108 |
+
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
| 109 |
+
if isinstance(kernel_size, int):
|
| 110 |
+
kernel_size = (kernel_size, kernel_size, kernel_size)
|
| 111 |
+
if kernel_size == (1, 1, 1):
|
| 112 |
+
raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.")
|
| 113 |
+
if isinstance(stride, int):
|
| 114 |
+
stride = (stride, stride, stride)
|
| 115 |
+
if isinstance(padding, int):
|
| 116 |
+
padding = (padding, padding, padding)
|
| 117 |
+
if isinstance(dilation, int):
|
| 118 |
+
dilation = (dilation, dilation, dilation)
|
| 119 |
+
|
| 120 |
+
# Set parameters for convolutions
|
| 121 |
+
self.groups = groups
|
| 122 |
+
self.bias = bias
|
| 123 |
+
|
| 124 |
+
# Define the size of the channels after the first convolution
|
| 125 |
+
intermediate_channels = out_channels if in_channels < out_channels else in_channels
|
| 126 |
+
|
| 127 |
+
# Define parameters for the first convolution
|
| 128 |
+
self.weight1 = nn.Parameter(
|
| 129 |
+
torch.Tensor(
|
| 130 |
+
intermediate_channels,
|
| 131 |
+
in_channels // groups,
|
| 132 |
+
1,
|
| 133 |
+
kernel_size[1],
|
| 134 |
+
kernel_size[2],
|
| 135 |
+
)
|
| 136 |
+
)
|
| 137 |
+
self.stride1 = (1, stride[1], stride[2])
|
| 138 |
+
self.padding1 = (0, padding[1], padding[2])
|
| 139 |
+
self.dilation1 = (1, dilation[1], dilation[2])
|
| 140 |
+
if bias:
|
| 141 |
+
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
|
| 142 |
+
else:
|
| 143 |
+
self.register_parameter("bias1", None)
|
| 144 |
+
|
| 145 |
+
# Define parameters for the second convolution
|
| 146 |
+
self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1))
|
| 147 |
+
self.stride2 = (stride[0], 1, 1)
|
| 148 |
+
self.padding2 = (padding[0], 0, 0)
|
| 149 |
+
self.dilation2 = (dilation[0], 1, 1)
|
| 150 |
+
if bias:
|
| 151 |
+
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
|
| 152 |
+
else:
|
| 153 |
+
self.register_parameter("bias2", None)
|
| 154 |
+
|
| 155 |
+
# Initialize weights and biases
|
| 156 |
+
self.reset_parameters()
|
| 157 |
+
|
| 158 |
+
def reset_parameters(self) -> None:
|
| 159 |
+
nn.init.kaiming_uniform_(self.weight1, a=torch.sqrt(5))
|
| 160 |
+
nn.init.kaiming_uniform_(self.weight2, a=torch.sqrt(5))
|
| 161 |
+
if self.bias:
|
| 162 |
+
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
|
| 163 |
+
bound1 = 1 / torch.sqrt(fan_in1)
|
| 164 |
+
nn.init.uniform_(self.bias1, -bound1, bound1)
|
| 165 |
+
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
|
| 166 |
+
bound2 = 1 / torch.sqrt(fan_in2)
|
| 167 |
+
nn.init.uniform_(self.bias2, -bound2, bound2)
|
| 168 |
+
|
| 169 |
+
def forward(
|
| 170 |
+
self,
|
| 171 |
+
x: torch.Tensor,
|
| 172 |
+
use_conv3d: bool = False,
|
| 173 |
+
skip_time_conv: bool = False,
|
| 174 |
+
) -> torch.Tensor:
|
| 175 |
+
if use_conv3d:
|
| 176 |
+
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
|
| 177 |
+
else:
|
| 178 |
+
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
|
| 179 |
+
|
| 180 |
+
def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:
|
| 181 |
+
# First convolution
|
| 182 |
+
x = F.conv3d(
|
| 183 |
+
x,
|
| 184 |
+
self.weight1,
|
| 185 |
+
self.bias1,
|
| 186 |
+
self.stride1,
|
| 187 |
+
self.padding1,
|
| 188 |
+
self.dilation1,
|
| 189 |
+
self.groups,
|
| 190 |
+
padding_mode=self.padding_mode,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
if skip_time_conv:
|
| 194 |
+
return x
|
| 195 |
+
|
| 196 |
+
# Second convolution
|
| 197 |
+
x = F.conv3d(
|
| 198 |
+
x,
|
| 199 |
+
self.weight2,
|
| 200 |
+
self.bias2,
|
| 201 |
+
self.stride2,
|
| 202 |
+
self.padding2,
|
| 203 |
+
self.dilation2,
|
| 204 |
+
self.groups,
|
| 205 |
+
padding_mode=self.padding_mode,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
return x
|
| 209 |
+
|
| 210 |
+
def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor:
|
| 211 |
+
b, _, _, h, w = x.shape
|
| 212 |
+
|
| 213 |
+
# First 2D convolution
|
| 214 |
+
x = rearrange(x, "b c d h w -> (b d) c h w")
|
| 215 |
+
# Squeeze the depth dimension out of weight1 since it's 1
|
| 216 |
+
weight1 = self.weight1.squeeze(2)
|
| 217 |
+
# Select stride, padding, and dilation for the 2D convolution
|
| 218 |
+
stride1 = (self.stride1[1], self.stride1[2])
|
| 219 |
+
padding1 = (self.padding1[1], self.padding1[2])
|
| 220 |
+
dilation1 = (self.dilation1[1], self.dilation1[2])
|
| 221 |
+
x = F.conv2d(
|
| 222 |
+
x,
|
| 223 |
+
weight1,
|
| 224 |
+
self.bias1,
|
| 225 |
+
stride1,
|
| 226 |
+
padding1,
|
| 227 |
+
dilation1,
|
| 228 |
+
self.groups,
|
| 229 |
+
padding_mode=self.padding_mode,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
_, _, h, w = x.shape
|
| 233 |
+
|
| 234 |
+
if skip_time_conv:
|
| 235 |
+
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
|
| 236 |
+
return x
|
| 237 |
+
|
| 238 |
+
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
|
| 239 |
+
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
|
| 240 |
+
|
| 241 |
+
# Reshape weight2 to match the expected dimensions for conv1d
|
| 242 |
+
weight2 = self.weight2.squeeze(-1).squeeze(-1)
|
| 243 |
+
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
|
| 244 |
+
stride2 = self.stride2[0]
|
| 245 |
+
padding2 = self.padding2[0]
|
| 246 |
+
dilation2 = self.dilation2[0]
|
| 247 |
+
x = F.conv1d(
|
| 248 |
+
x,
|
| 249 |
+
weight2,
|
| 250 |
+
self.bias2,
|
| 251 |
+
stride2,
|
| 252 |
+
padding2,
|
| 253 |
+
dilation2,
|
| 254 |
+
self.groups,
|
| 255 |
+
padding_mode=self.padding_mode,
|
| 256 |
+
)
|
| 257 |
+
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
| 258 |
+
|
| 259 |
+
return x
|
| 260 |
+
|
| 261 |
+
@property
|
| 262 |
+
def weight(self) -> torch.Tensor:
|
| 263 |
+
return self.weight2
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class CausalConv3d(nn.Module):
|
| 267 |
+
def __init__(
|
| 268 |
+
self,
|
| 269 |
+
in_channels: int,
|
| 270 |
+
out_channels: int,
|
| 271 |
+
kernel_size: int = 3,
|
| 272 |
+
stride: Union[int, Tuple[int]] = 1,
|
| 273 |
+
dilation: int = 1,
|
| 274 |
+
groups: int = 1,
|
| 275 |
+
bias: bool = True,
|
| 276 |
+
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
| 277 |
+
) -> None:
|
| 278 |
+
super().__init__()
|
| 279 |
+
|
| 280 |
+
self.in_channels = in_channels
|
| 281 |
+
self.out_channels = out_channels
|
| 282 |
+
|
| 283 |
+
kernel_size = (kernel_size, kernel_size, kernel_size)
|
| 284 |
+
self.time_kernel_size = kernel_size[0]
|
| 285 |
+
|
| 286 |
+
dilation = (dilation, 1, 1)
|
| 287 |
+
|
| 288 |
+
height_pad = kernel_size[1] // 2
|
| 289 |
+
width_pad = kernel_size[2] // 2
|
| 290 |
+
padding = (0, height_pad, width_pad)
|
| 291 |
+
|
| 292 |
+
self.conv = nn.Conv3d(
|
| 293 |
+
in_channels,
|
| 294 |
+
out_channels,
|
| 295 |
+
kernel_size,
|
| 296 |
+
stride=stride,
|
| 297 |
+
dilation=dilation,
|
| 298 |
+
padding=padding,
|
| 299 |
+
padding_mode=spatial_padding_mode.value,
|
| 300 |
+
groups=groups,
|
| 301 |
+
bias=bias,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor:
|
| 305 |
+
if causal:
|
| 306 |
+
first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))
|
| 307 |
+
x = torch.concatenate((first_frame_pad, x), dim=2)
|
| 308 |
+
else:
|
| 309 |
+
first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
|
| 310 |
+
last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
|
| 311 |
+
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
| 312 |
+
x = self.conv(x)
|
| 313 |
+
return x
|
| 314 |
+
|
| 315 |
+
@property
|
| 316 |
+
def weight(self) -> torch.Tensor:
|
| 317 |
+
return self.conv.weight
|
packages/ltx-core/src/ltx_core/model/video_vae/model_configurator.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 2 |
+
from ltx_core.model.model_protocol import ModelConfigurator
|
| 3 |
+
from ltx_core.model.video_vae.enums import LogVarianceType, NormLayerType, PaddingModeType
|
| 4 |
+
from ltx_core.model.video_vae.video_vae import VideoDecoder, VideoEncoder
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class VideoEncoderConfigurator(ModelConfigurator[VideoEncoder]):
|
| 8 |
+
"""Configurator for creating a video VAE Encoder from a configuration dictionary."""
|
| 9 |
+
|
| 10 |
+
@classmethod
|
| 11 |
+
def from_config(cls: type[VideoEncoder], config: dict) -> VideoEncoder:
|
| 12 |
+
config = config.get("vae", {})
|
| 13 |
+
convolution_dimensions = config.get("dims", 3)
|
| 14 |
+
in_channels = config.get("in_channels", 3)
|
| 15 |
+
latent_channels = config.get("latent_channels", 128)
|
| 16 |
+
spatial_padding_mode = PaddingModeType(config.get("spatial_padding_mode", "zeros"))
|
| 17 |
+
encoder_blocks = config.get("encoder_blocks", [])
|
| 18 |
+
patch_size = config.get("patch_size", 4)
|
| 19 |
+
norm_layer_str = config.get("norm_layer", "pixel_norm")
|
| 20 |
+
latent_log_var_str = config.get("latent_log_var", "uniform")
|
| 21 |
+
|
| 22 |
+
return VideoEncoder(
|
| 23 |
+
convolution_dimensions=convolution_dimensions,
|
| 24 |
+
in_channels=in_channels,
|
| 25 |
+
out_channels=latent_channels,
|
| 26 |
+
encoder_blocks=encoder_blocks,
|
| 27 |
+
patch_size=patch_size,
|
| 28 |
+
norm_layer=NormLayerType(norm_layer_str),
|
| 29 |
+
latent_log_var=LogVarianceType(latent_log_var_str),
|
| 30 |
+
encoder_spatial_padding_mode=spatial_padding_mode,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class VideoDecoderConfigurator(ModelConfigurator[VideoDecoder]):
|
| 35 |
+
"""Configurator for creating a video VAE Decoder from a configuration dictionary."""
|
| 36 |
+
|
| 37 |
+
@classmethod
|
| 38 |
+
def from_config(cls: type[VideoDecoder], config: dict) -> VideoDecoder:
|
| 39 |
+
config = config.get("vae", {})
|
| 40 |
+
convolution_dimensions = config.get("dims", 3)
|
| 41 |
+
latent_channels = config.get("latent_channels", 128)
|
| 42 |
+
spatial_padding_mode = PaddingModeType(config.get("spatial_padding_mode", "reflect"))
|
| 43 |
+
out_channels = config.get("out_channels", 3)
|
| 44 |
+
decoder_blocks = config.get("decoder_blocks", [])
|
| 45 |
+
patch_size = config.get("patch_size", 4)
|
| 46 |
+
norm_layer_str = config.get("norm_layer", "pixel_norm")
|
| 47 |
+
causal = config.get("causal_decoder", False)
|
| 48 |
+
timestep_conditioning = config.get("timestep_conditioning", True)
|
| 49 |
+
base_channels = config.get("decoder_base_channels", 128)
|
| 50 |
+
|
| 51 |
+
return VideoDecoder(
|
| 52 |
+
convolution_dimensions=convolution_dimensions,
|
| 53 |
+
in_channels=latent_channels,
|
| 54 |
+
out_channels=out_channels,
|
| 55 |
+
decoder_blocks=decoder_blocks,
|
| 56 |
+
patch_size=patch_size,
|
| 57 |
+
norm_layer=NormLayerType(norm_layer_str),
|
| 58 |
+
causal=causal,
|
| 59 |
+
timestep_conditioning=timestep_conditioning,
|
| 60 |
+
decoder_spatial_padding_mode=spatial_padding_mode,
|
| 61 |
+
base_channels=base_channels,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
VAE_DECODER_COMFY_KEYS_FILTER = (
|
| 66 |
+
SDOps("VAE_DECODER_COMFY_KEYS_FILTER")
|
| 67 |
+
.with_matching(prefix="vae.decoder.")
|
| 68 |
+
.with_matching(prefix="vae.per_channel_statistics.")
|
| 69 |
+
.with_replacement("vae.decoder.", "")
|
| 70 |
+
.with_replacement("vae.per_channel_statistics.", "per_channel_statistics.")
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
VAE_ENCODER_COMFY_KEYS_FILTER = (
|
| 74 |
+
SDOps("VAE_ENCODER_COMFY_KEYS_FILTER")
|
| 75 |
+
.with_matching(prefix="vae.encoder.")
|
| 76 |
+
.with_matching(prefix="vae.per_channel_statistics.")
|
| 77 |
+
.with_replacement("vae.encoder.", "")
|
| 78 |
+
.with_replacement("vae.per_channel_statistics.", "per_channel_statistics.")
|
| 79 |
+
)
|
packages/ltx-core/src/ltx_core/model/video_vae/resnet.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from ltx_core.model.common.normalization import PixelNorm
|
| 7 |
+
from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings
|
| 8 |
+
from ltx_core.model.video_vae.convolution import make_conv_nd, make_linear_nd
|
| 9 |
+
from ltx_core.model.video_vae.enums import NormLayerType, PaddingModeType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ResnetBlock3D(nn.Module):
|
| 13 |
+
r"""
|
| 14 |
+
A Resnet block.
|
| 15 |
+
Parameters:
|
| 16 |
+
in_channels (`int`): The number of channels in the input.
|
| 17 |
+
out_channels (`int`, *optional*, default to be `None`):
|
| 18 |
+
The number of output channels for the first conv layer. If None, same as `in_channels`.
|
| 19 |
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
| 20 |
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
| 21 |
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
dims: Union[int, Tuple[int, int]],
|
| 27 |
+
in_channels: int,
|
| 28 |
+
out_channels: Optional[int] = None,
|
| 29 |
+
dropout: float = 0.0,
|
| 30 |
+
groups: int = 32,
|
| 31 |
+
eps: float = 1e-6,
|
| 32 |
+
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
| 33 |
+
inject_noise: bool = False,
|
| 34 |
+
timestep_conditioning: bool = False,
|
| 35 |
+
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
| 36 |
+
):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.in_channels = in_channels
|
| 39 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 40 |
+
self.out_channels = out_channels
|
| 41 |
+
self.inject_noise = inject_noise
|
| 42 |
+
|
| 43 |
+
if norm_layer == NormLayerType.GROUP_NORM:
|
| 44 |
+
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
| 45 |
+
elif norm_layer == NormLayerType.PIXEL_NORM:
|
| 46 |
+
self.norm1 = PixelNorm()
|
| 47 |
+
|
| 48 |
+
self.non_linearity = nn.SiLU()
|
| 49 |
+
|
| 50 |
+
self.conv1 = make_conv_nd(
|
| 51 |
+
dims,
|
| 52 |
+
in_channels,
|
| 53 |
+
out_channels,
|
| 54 |
+
kernel_size=3,
|
| 55 |
+
stride=1,
|
| 56 |
+
padding=1,
|
| 57 |
+
causal=True,
|
| 58 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
if inject_noise:
|
| 62 |
+
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
| 63 |
+
|
| 64 |
+
if norm_layer == NormLayerType.GROUP_NORM:
|
| 65 |
+
            self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.norm2 = PixelNorm()

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))

        self.conv_shortcut = (
            make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
            if in_channels != out_channels
            else nn.Identity()
        )

        # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout
        # avoiding the need for dimension rearrangement used in standard nn.LayerNorm
        self.norm3 = (
            nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels))

    def _feed_spatial_noise(
        self,
        hidden_states: torch.Tensor,
        per_channel_scale: torch.Tensor,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        spatial_shape = hidden_states.shape[-2:]
        device = hidden_states.device
        dtype = hidden_states.dtype

        # similar to the "explicit noise inputs" method in style-gan
        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None]
        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
        hidden_states = hidden_states + scaled_noise

        return hidden_states

    def forward(
        self,
        input_tensor: torch.Tensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        hidden_states = input_tensor
        batch_size = hidden_states.shape[0]

        hidden_states = self.norm1(hidden_states)
        if self.timestep_conditioning:
            if timestep is None:
                raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True")
            ada_values = self.scale_shift_table[None, ..., None, None, None].to(
                device=hidden_states.device, dtype=hidden_states.dtype
            ) + timestep.reshape(
                batch_size,
                4,
                -1,
                timestep.shape[-3],
                timestep.shape[-2],
                timestep.shape[-1],
            )
            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)

            hidden_states = hidden_states * (1 + scale1) + shift1

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.conv1(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states,
                self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype),
                generator=generator,
            )

        hidden_states = self.norm2(hidden_states)

        if self.timestep_conditioning:
            hidden_states = hidden_states * (1 + scale2) + shift2

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states,
                self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype),
                generator=generator,
            )

        input_tensor = self.norm3(input_tensor)

        batch_size = input_tensor.shape[0]

        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor


class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
    Args:
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to inject noise into the hidden states.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the hidden states on the timestep.
    Returns:
        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: NormLayerType = NormLayerType.GROUP_NORM,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                embedding_dim=in_channels * 4, size_emb_dim=0
            )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        timestep_embed = None
        if self.timestep_conditioning:
            if timestep is None:
                raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True")
            batch_size = hidden_states.shape[0]
            timestep_embed = self.time_embedder(
                timestep=timestep.flatten(),
                hidden_dtype=hidden_states.dtype,
            )
            timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1)

        for resnet in self.res_blocks:
            hidden_states = resnet(
                hidden_states,
                causal=causal,
                timestep=timestep_embed,
                generator=generator,
            )

        return hidden_states
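Note: a minimal usage sketch of the mid-block above (illustrative only; the channel count, tensor shapes, and an installed ltx_core package are assumptions, not part of this diff):

import torch
from ltx_core.model.video_vae.resnet import UNetMidBlock3D

# Two residual blocks over a small (B, C, F, H, W) feature map. in_channels should be
# divisible by the default 32 norm groups; the output keeps the input shape.
block = UNetMidBlock3D(dims=3, in_channels=64, num_layers=2)
features = torch.randn(1, 64, 5, 16, 16)
out = block(features, causal=True)  # -> (1, 64, 5, 16, 16)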
packages/ltx-core/src/ltx_core/model/video_vae/tiling.py
ADDED
@@ -0,0 +1,291 @@
import itertools
from dataclasses import dataclass
from typing import Callable, List, NamedTuple, Tuple

import torch


def compute_trapezoidal_mask_1d(
    length: int,
    ramp_left: int,
    ramp_right: int,
    left_starts_from_0: bool = False,
) -> torch.Tensor:
    """
    Generate a 1D trapezoidal blending mask with linear ramps.
    Args:
        length: Output length of the mask.
        ramp_left: Fade-in length on the left.
        ramp_right: Fade-out length on the right.
        left_starts_from_0: Whether the ramp starts from 0 or the first non-zero value.
            Useful for temporal tiles where the first tile is causal.
    Returns:
        A 1D tensor of shape `(length,)` with values in [0, 1].
    """
    if length <= 0:
        raise ValueError("Mask length must be positive.")

    ramp_left = max(0, min(ramp_left, length))
    ramp_right = max(0, min(ramp_right, length))

    mask = torch.ones(length)

    if ramp_left > 0:
        interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2
        fade_in = torch.linspace(0.0, 1.0, interval_length)[:-1]
        if not left_starts_from_0:
            fade_in = fade_in[1:]
        mask[:ramp_left] *= fade_in

    if ramp_right > 0:
        fade_out = torch.linspace(1.0, 0.0, steps=ramp_right + 2)[1:-1]
        mask[-ramp_right:] *= fade_out

    return mask.clamp_(0, 1)


def compute_rectangular_mask_1d(
    length: int,
    left_ramp: int,
    right_ramp: int,
) -> torch.Tensor:
    """
    Generate a 1D rectangular (pulse) mask.
    Args:
        length: Output length of the mask.
        left_ramp: Number of elements at the start of the mask to set to 0.
        right_ramp: Number of elements at the end of the mask to set to 0.
    Returns:
        A 1D tensor of shape `(length,)` with values 0 or 1.
    """
    if length <= 0:
        raise ValueError("Mask length must be positive.")

    mask = torch.ones(length)
    if left_ramp > 0:
        mask[:left_ramp] = 0
    if right_ramp > 0:
        mask[-right_ramp:] = 0
    return mask


@dataclass(frozen=True)
class SpatialTilingConfig:
    """Configuration for dividing each frame into spatial tiles with optional overlap.
    Args:
        tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32.
        tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32. Defaults to 0.
    """

    tile_size_in_pixels: int
    tile_overlap_in_pixels: int = 0

    def __post_init__(self) -> None:
        if self.tile_size_in_pixels < 64:
            raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
        if self.tile_size_in_pixels % 32 != 0:
            raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
        if self.tile_overlap_in_pixels % 32 != 0:
            raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
        if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
            )


@dataclass(frozen=True)
class TemporalTilingConfig:
    """Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap.
    Args:
        tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8.
        tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles.
            Must be divisible by 8. Defaults to 0.
    """

    tile_size_in_frames: int
    tile_overlap_in_frames: int = 0

    def __post_init__(self) -> None:
        if self.tile_size_in_frames < 16:
            raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
        if self.tile_size_in_frames % 8 != 0:
            raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
        if self.tile_overlap_in_frames % 8 != 0:
            raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
        if self.tile_overlap_in_frames >= self.tile_size_in_frames:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
            )


@dataclass(frozen=True)
class TilingConfig:
    """Configuration for splitting video into tiles with optional overlap.
    Attributes:
        spatial_config: Configuration for splitting spatial dimensions into tiles.
        temporal_config: Configuration for splitting the temporal dimension into tiles.
    """

    spatial_config: SpatialTilingConfig | None = None
    temporal_config: TemporalTilingConfig | None = None

    @classmethod
    def default(cls) -> "TilingConfig":
        return cls(
            spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
            temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
        )


@dataclass(frozen=True)
class DimensionIntervals:
    """Defines how a single dimension is split into overlapping intervals (tiles).
    Each list has length N where N is the number of intervals. The i-th element
    of each list describes the i-th interval.
    Attributes:
        starts: Start index of each interval (inclusive).
        ends: End index of each interval (exclusive).
        left_ramps: Length of the left blend ramp for each interval.
            Used to create masks that fade in from 0 to 1.
        right_ramps: Length of the right blend ramp for each interval.
            Used to create masks that fade out from 1 to 0.
    """

    starts: List[int]
    ends: List[int]
    left_ramps: List[int]
    right_ramps: List[int]


@dataclass(frozen=True)
class TensorTilingSpec:
    """Specifies how a tensor of a given shape is split into intervals (tiles) along each dimension.
    Attributes:
        original_shape: Shape of the tensor being tiled.
        dimension_intervals: Per-dimension intervals (starts, ends, ramps) for each axis.
    """

    original_shape: torch.Size
    dimension_intervals: Tuple[DimensionIntervals, ...]


# Operation to split a single dimension of the tensor into intervals based on the length along the dimension.
SplitOperation = Callable[[int], DimensionIntervals]
# Operation to map the intervals in an input dimension to slices and masks along a corresponding output dimension.
MappingOperation = Callable[[DimensionIntervals], tuple[list[slice], list[torch.Tensor | None]]]


def default_split_operation(length: int) -> DimensionIntervals:
    return DimensionIntervals(starts=[0], ends=[length], left_ramps=[0], right_ramps=[0])


DEFAULT_SPLIT_OPERATION: SplitOperation = default_split_operation


def default_mapping_operation(
    _intervals: DimensionIntervals,
) -> tuple[list[slice], list[torch.Tensor | None]]:
    return [slice(0, None)], [None]


DEFAULT_MAPPING_OPERATION: MappingOperation = default_mapping_operation


class Tile(NamedTuple):
    """
    Represents a single tile.
    Attributes:
        in_coords:
            Tuple of slices specifying where to cut the tile from the INPUT tensor.
        out_coords:
            Tuple of slices specifying where this tile's OUTPUT should be placed in the reconstructed OUTPUT tensor.
        masks_1d:
            Per-dimension masks in OUTPUT units.
            These are used to create the all-dimensional blending mask.
    Methods:
        blend_mask:
            Create a single N-D mask from the per-dimension masks.
    """

    in_coords: Tuple[slice, ...]
    out_coords: Tuple[slice, ...]
    masks_1d: Tuple[Tuple[torch.Tensor, ...]]

    @property
    def blend_mask(self) -> torch.Tensor:
        num_dims = len(self.out_coords)
        per_dimension_masks: List[torch.Tensor] = []

        for dim_idx in range(num_dims):
            mask_1d = self.masks_1d[dim_idx]
            view_shape = [1] * num_dims
            if mask_1d is None:
                # Broadcast mask along this dimension (length 1).
                one = torch.ones(1)

                view_shape[dim_idx] = 1
                per_dimension_masks.append(one.view(*view_shape))
                continue

            # Reshape (L,) -> (1, ..., L, ..., 1) so masks across dimensions broadcast-multiply.
            view_shape[dim_idx] = mask_1d.shape[0]
            per_dimension_masks.append(mask_1d.view(*view_shape))

        # Multiply per-dimension masks to form the full N-D mask (separable blending window).
        combined_mask = per_dimension_masks[0]
        for mask in per_dimension_masks[1:]:
            combined_mask = combined_mask * mask

        return combined_mask


def create_tiles_from_intervals_and_mappers(
    intervals: TensorTilingSpec,
    mappers: List[MappingOperation],
) -> List[Tile]:
    full_dim_input_slices = []
    full_dim_output_slices = []
    full_dim_masks_1d = []
    for axis_index in range(len(intervals.original_shape)):
        dimension_intervals = intervals.dimension_intervals[axis_index]
        starts = dimension_intervals.starts
        ends = dimension_intervals.ends
        input_slices = [slice(s, e) for s, e in zip(starts, ends, strict=True)]
        output_slices, masks_1d = mappers[axis_index](dimension_intervals)
        full_dim_input_slices.append(input_slices)
        full_dim_output_slices.append(output_slices)
        full_dim_masks_1d.append(masks_1d)

    tiles = []
    tile_in_coords = list(itertools.product(*full_dim_input_slices))
    tile_out_coords = list(itertools.product(*full_dim_output_slices))
    tile_mask_1ds = list(itertools.product(*full_dim_masks_1d))
    for in_coord, out_coord, mask_1d in zip(tile_in_coords, tile_out_coords, tile_mask_1ds, strict=True):
        tiles.append(
            Tile(
                in_coords=in_coord,
                out_coords=out_coord,
                masks_1d=mask_1d,
            )
        )
    return tiles


def create_tiles(
    tensor_shape: torch.Size,
    splitters: List[SplitOperation],
    mappers: List[MappingOperation],
) -> List[Tile]:
    if len(splitters) != len(tensor_shape):
        raise ValueError(
            f"Number of splitters must be equal to number of dimensions in tensor shape, "
            f"got {len(splitters)} and {len(tensor_shape)}"
        )
    if len(mappers) != len(tensor_shape):
        raise ValueError(
            f"Number of mappers must be equal to number of dimensions in tensor shape, "
            f"got {len(mappers)} and {len(tensor_shape)}"
        )
    intervals = [splitter(length) for splitter, length in zip(splitters, tensor_shape, strict=True)]
    tiling_spec = TensorTilingSpec(original_shape=tensor_shape, dimension_intervals=tuple(intervals))
    return create_tiles_from_intervals_and_mappers(tiling_spec, mappers)
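Note: a small sketch of how the 1D masks above combine into the separable blend window used by Tile.blend_mask (illustrative; the 16/4 sizes are arbitrary assumptions):

import torch
from ltx_core.model.video_vae.tiling import compute_trapezoidal_mask_1d

# Two trapezoidal masks broadcast-multiplied give a 2D blending window for overlapping
# tiles: 1.0 in the interior, fading linearly to 0 over the ramps at each edge.
h_mask = compute_trapezoidal_mask_1d(length=16, ramp_left=4, ramp_right=4)  # (16,)
w_mask = compute_trapezoidal_mask_1d(length=16, ramp_left=4, ramp_right=4)  # (16,)
window = h_mask.view(16, 1) * w_mask.view(1, 16)  # (16, 16), values in [0, 1]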
packages/ltx-core/src/ltx_core/model/video_vae/video_vae.py
ADDED
@@ -0,0 +1,1219 @@
import logging
from dataclasses import replace
from typing import Any, Callable, Iterator, List, Tuple

import torch
from einops import rearrange
from torch import nn

from ltx_core.model.common.normalization import PixelNorm
from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings
from ltx_core.model.video_vae.convolution import make_conv_nd
from ltx_core.model.video_vae.enums import LogVarianceType, NormLayerType, PaddingModeType
from ltx_core.model.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from ltx_core.model.video_vae.resnet import ResnetBlock3D, UNetMidBlock3D
from ltx_core.model.video_vae.sampling import DepthToSpaceUpsample, SpaceToDepthDownsample
from ltx_core.model.video_vae.tiling import (
    DEFAULT_MAPPING_OPERATION,
    DEFAULT_SPLIT_OPERATION,
    DimensionIntervals,
    MappingOperation,
    SplitOperation,
    Tile,
    TilingConfig,
    compute_rectangular_mask_1d,
    compute_trapezoidal_mask_1d,
    create_tiles,
)
from ltx_core.types import VIDEO_SCALE_FACTORS, SpatioTemporalScaleFactors, VideoLatentShape

logger: logging.Logger = logging.getLogger(__name__)


def _make_encoder_block(
    block_name: str,
    block_config: dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    out_channels = in_channels

    if block_name == "res_x":
        block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "res_x_y":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_time":
        block = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 1, 1),
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_space":
        block = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(1, 2, 2),
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_all":
        block = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 2, 2),
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_all_x_y":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=(2, 2, 2),
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_all_res":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=(2, 2, 2),
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_space_res":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=(1, 2, 2),
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_time_res":
        out_channels = in_channels * block_config.get("multiplier", 2)
        block = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            stride=(2, 1, 1),
            spatial_padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unknown block: {block_name}")

    return block, out_channels

class VideoEncoder(nn.Module):
    _DEFAULT_NORM_NUM_GROUPS = 32
    """
    Variational Autoencoder Encoder. Encodes video frames into a latent representation.
    The encoder compresses the input video through a series of downsampling operations controlled by
    patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W').
    Compression Behavior:
        The total compression is determined by:
        1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4)
        2. Sequential compression through encoder_blocks based on their stride patterns
        Compression blocks apply 2x compression in specified dimensions:
        - "compress_time" / "compress_time_res": temporal only
        - "compress_space" / "compress_space_res": spatial only (H and W)
        - "compress_all" / "compress_all_res": all dimensions (F, H, W)
        - "res_x" / "res_x_y": no compression
    Standard LTX Video configuration:
        - patch_size=4
        - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res
        - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32
        - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16)
        - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...)
    Args:
        convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D).
        in_channels: The number of input channels. For RGB images, this is 3.
        out_channels: The number of output channels (latent channels). For latent channels, this is 128.
        encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params)
            where params is either an int (num_layers) or a dict with configuration.
        patch_size: The patch size for initial spatial compression. Should be a power of 2.
        norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`.
    """

    def __init__(
        self,
        convolution_dimensions: int = 3,
        in_channels: int = 3,
        out_channels: int = 128,
        encoder_blocks: List[Tuple[str, int]] | List[Tuple[str, dict[str, Any]]] = [],  # noqa: B006
        patch_size: int = 4,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
        encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()

        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS

        # Per-channel statistics for normalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)

        in_channels = in_channels * patch_size**2
        feature_channels = out_channels

        self.conv_in = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in encoder_blocks:
            # Convert int to dict format for uniform handling
            block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params

            block, feature_channels = _make_encoder_block(
                block_name=block_name,
                block_config=block_config,
                in_channels=feature_channels,
                convolution_dimensions=convolution_dimensions,
                norm_layer=norm_layer,
                norm_num_groups=self._norm_num_groups,
                spatial_padding_mode=encoder_spatial_padding_mode,
            )

            self.down_blocks.append(block)

        # out
        if norm_layer == NormLayerType.GROUP_NORM:
            self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6)
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.conv_norm_out = PixelNorm()

        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == LogVarianceType.PER_CHANNEL:
            conv_out_channels *= 2
        elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
            conv_out_channels += 1
        elif latent_log_var != LogVarianceType.NONE:
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")

        self.conv_out = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=feature_channels,
            out_channels=conv_out_channels,
            kernel_size=3,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        r"""
        Encode video frames into a normalized latent representation.
        Args:
            sample: Input video (B, C, F, H, W). F should be 1 + 8*k (e.g., 1, 9, 17, 25, 33...).
                If not, the encoder crops the last frames to the nearest valid length.
        Returns:
            Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32.
            Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16).
        """
        # Validate frame count (crop to nearest valid length if needed)
        frames_count = sample.shape[2]
        if ((frames_count - 1) % 8) != 0:
            frames_to_crop = (frames_count - 1) % 8
            logger.warning(
                "Invalid number of frames %s for encode; cropping last %s frames to satisfy 1 + 8*k.",
                frames_count,
                frames_to_crop,
            )
            sample = sample[:, :, :-frames_to_crop, ...]

        # Initial spatial compression: trade spatial resolution for channel depth
        # This reduces H,W by patch_size and increases channels, making convolutions more efficient
        # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4
        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        for down_block in self.down_blocks:
            sample = down_block(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == LogVarianceType.UNIFORM:
            # Uniform Variance: model outputs N means and 1 shared log-variance channel.
            # We need to expand the single logvar to match the number of means channels
            # to create a format compatible with PER_CHANNEL (means + logvar, each with N channels).
            # Sample shape: (B, N+1, ...) where N = latent_channels (e.g., 128 means + 1 logvar = 129)
            # Target shape: (B, 2*N, ...) where first N are means, last N are logvar

            if sample.shape[1] < 2:
                raise ValueError(
                    f"Invalid channel count for UNIFORM mode: expected at least 2 channels "
                    f"(N means + 1 logvar), got {sample.shape[1]}"
                )

            # Extract means (first N channels) and logvar (last 1 channel)
            means = sample[:, :-1, ...]  # (B, N, ...)
            logvar = sample[:, -1:, ...]  # (B, 1, ...)

            # Repeat logvar N times to match means channels
            # Use expand/repeat pattern that works for both 4D and 5D tensors
            num_channels = means.shape[1]
            repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2)
            repeated_logvar = logvar.repeat(*repeat_shape)  # (B, N, ...)

            # Concatenate to create (B, 2*N, ...) format: [means, repeated_logvar]
            sample = torch.cat([means, repeated_logvar], dim=1)
        elif self.latent_log_var == LogVarianceType.CONSTANT:
            sample = sample[:, :-1, ...]
            approx_ln_0 = -30  # this is the minimal clamp value in DiagonalGaussianDistribution objects
            sample = torch.cat(
                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
                dim=1,
            )

        # Split into means and logvar, then normalize means
        means, _ = torch.chunk(sample, 2, dim=1)
        return self.per_channel_statistics.normalize(means)

    def tiled_encode(
        self,
        video: torch.Tensor,
        tiling_config: TilingConfig | None = None,
    ) -> torch.Tensor:
        """Encode video to latent using tiled processing of the given video tensor.
        Device Handling:
            - Input video can be on CPU or GPU
            - Accumulation buffers are created on the model's device
            - Each tile is automatically moved to the model's device before encoding
            - Output latent is returned on the model's device
        Args:
            video: Input video tensor (B, 3, F, H, W) in range [-1, 1]
            tiling_config: Tiling configuration for the video tensor
        Returns:
            Latent tensor (B, 128, F', H', W') on the model's device,
            where F' = 1 + (F-1)/8, H' = H/32, W' = W/32
        """
        # Detect model device and dtype
        model_device = next(self.parameters()).device
        model_dtype = next(self.parameters()).dtype

        # Extract shape components
        batch, _, frames, height, width = video.shape

        # Check frame count and crop if needed
        if (frames - 1) % VIDEO_SCALE_FACTORS.time != 0:
            frames_to_crop = (frames - 1) % VIDEO_SCALE_FACTORS.time
            logger.warning(
                f"Number of frames {frames} of input video is not ({VIDEO_SCALE_FACTORS.time} * k + 1), "
                f"last {frames_to_crop} frames will be cropped"
            )
            video = video[:, :, :-frames_to_crop, ...]
            # Update frames after cropping
            frames = video.shape[2]

        # Calculate output latent shape (inverse of upscale)
        latent_shape = VideoLatentShape(
            batch=batch,
            channels=self.latent_channels,  # 128 for standard VAE
            frames=(frames - 1) // VIDEO_SCALE_FACTORS.time + 1,
            height=height // VIDEO_SCALE_FACTORS.height,
            width=width // VIDEO_SCALE_FACTORS.width,
        )

        # Prepare tiles (operates on VIDEO dimensions)
        tiles = prepare_tiles_for_encoding(video, tiling_config)

        # Initialize accumulation buffers on model device
        latent_buffer = torch.zeros(
            latent_shape.to_torch_shape(),
            device=model_device,
            dtype=model_dtype,
        )
        weights_buffer = torch.zeros_like(latent_buffer)

        # Process each tile
        for tile in tiles:
            # Extract video tile from input (may be on CPU)
            video_tile = video[tile.in_coords]

            # Move tile to model device if needed
            if video_tile.device != model_device or video_tile.dtype != model_dtype:
                video_tile = video_tile.to(device=model_device, dtype=model_dtype)

            # Encode tile to latent (output on model device)
            latent_tile = self.forward(video_tile)

            # Move blend mask to model device
            mask = tile.blend_mask.to(
                device=model_device,
                dtype=model_dtype,
            )

            # Weighted accumulation in latent space
            latent_buffer[tile.out_coords] += latent_tile * mask
            weights_buffer[tile.out_coords] += mask

            del latent_tile, mask, video_tile

        # Normalize by accumulated weights
        weights_buffer = weights_buffer.clamp(min=1e-8)
        return latent_buffer / weights_buffer

def prepare_tiles_for_encoding(
    video: torch.Tensor,
    tiling_config: TilingConfig | None = None,
) -> List[Tile]:
    """Prepare tiles for VAE encoding.
    Args:
        video: Input video tensor (B, 3, F, H, W) in range [-1, 1]
        tiling_config: Tiling configuration for the video tensor
    Returns:
        List of tiles for the video tensor
    """

    splitters = [DEFAULT_SPLIT_OPERATION] * len(video.shape)
    mappers = [DEFAULT_MAPPING_OPERATION] * len(video.shape)
    minimum_spatial_overlap_px = 64
    minimum_temporal_overlap_frames = 16

    if tiling_config is not None and tiling_config.spatial_config is not None:
        cfg = tiling_config.spatial_config

        tile_size_px = cfg.tile_size_in_pixels
        overlap_px = cfg.tile_overlap_in_pixels

        # Set minimum spatial overlap to 64 pixels in order to allow cutting padding from
        # the front and back of the tiles and concatenate tiles without artifacts.
        # The encoder uses symmetric padding (pad=1) in H and W at each conv layer. At tile
        # boundaries, convs see padding (zeros/reflect) instead of real neighbor pixels, causing
        # incorrect context near edges.
        # For each overlap we discard 1 latent per edge (32px at scale 32) and concatenate tiles at a
        # shared region with the next tile.
        if overlap_px < minimum_spatial_overlap_px:
            logger.warning(
                f"Overlap pixels {overlap_px} in spatial tiling is less than "
                f"{minimum_spatial_overlap_px}, setting to minimum required {minimum_spatial_overlap_px}"
            )
            overlap_px = minimum_spatial_overlap_px

        # Define split and map operations for the spatial dimensions

        # Height axis (H)
        splitters[3] = split_with_symmetric_overlaps(tile_size_px, overlap_px)
        mappers[3] = make_mapping_operation(map_spatial_interval_to_latent, scale=VIDEO_SCALE_FACTORS.height)

        # Width axis (W)
        splitters[4] = split_with_symmetric_overlaps(tile_size_px, overlap_px)
        mappers[4] = make_mapping_operation(map_spatial_interval_to_latent, scale=VIDEO_SCALE_FACTORS.width)

    if tiling_config is not None and tiling_config.temporal_config is not None:
        cfg = tiling_config.temporal_config
        tile_size_frames = cfg.tile_size_in_frames
        overlap_frames = cfg.tile_overlap_in_frames

        if overlap_frames < minimum_temporal_overlap_frames:
            logger.warning(f"Overlap frames {overlap_frames} is less than 16, setting to minimum required 16")
            overlap_frames = minimum_temporal_overlap_frames

        splitters[2] = split_temporal_frames(tile_size_frames, overlap_frames)
        mappers[2] = make_mapping_operation(map_temporal_interval_to_latent, scale=VIDEO_SCALE_FACTORS.time)

    return create_tiles(video.shape, splitters, mappers)

def _make_decoder_block(
    block_name: str,
    block_config: dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    timestep_conditioning: bool,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    out_channels = in_channels
    if block_name == "res_x":
        block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=timestep_conditioning,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "attn_res_x":
        block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=timestep_conditioning,
            attention_head_dim=block_config["attention_head_dim"],
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "res_x_y":
        out_channels = in_channels // block_config.get("multiplier", 2)
        block = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=out_channels,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=False,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_time":
        out_channels = in_channels // block_config.get("multiplier", 1)
        block = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=(2, 1, 1),
            out_channels_reduction_factor=block_config.get("multiplier", 1),
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_space":
        out_channels = in_channels // block_config.get("multiplier", 1)
        block = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=(1, 2, 2),
            out_channels_reduction_factor=block_config.get("multiplier", 1),
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name == "compress_all":
        out_channels = in_channels // block_config.get("multiplier", 1)
        block = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=(2, 2, 2),
            residual=block_config.get("residual", False),
            out_channels_reduction_factor=block_config.get("multiplier", 1),
            spatial_padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unknown layer: {block_name}")

    return block, out_channels


class VideoDecoder(nn.Module):
    _DEFAULT_NORM_NUM_GROUPS = 32
    """
    Variational Autoencoder Decoder. Decodes latent representation into video frames.
    The decoder upsamples latents through a series of upsampling operations (inverse of encoder).
    Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for the standard LTX Video configuration.
    Upsampling blocks expand dimensions by 2x in specified dimensions:
    - "compress_time": temporal only
    - "compress_space": spatial only (H and W)
    - "compress_all": all dimensions (F, H, W)
    - "res_x" / "res_x_y" / "attn_res_x": no upsampling
    Causal Mode:
        causal=False (standard): Symmetric padding, allows future frame dependencies.
        causal=True: Causal padding, each frame depends only on past/current frames.
        First frame removed after temporal upsampling in both modes. Output shape unchanged.
        Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes.
    Args:
        convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D).
        in_channels: The number of input channels (latent channels). Default is 128.
        out_channels: The number of output channels. For RGB images, this is 3.
        decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params)
            where params is either an int (num_layers) or a dict with configuration.
        patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion:
            H -> Hx4, W -> Wx4. Should be a power of 2.
        norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding.
            When True, uses causal padding (past/current frames only).
        timestep_conditioning: Whether to condition the decoder on timestep for denoising.
    """

    def __init__(
        self,
        convolution_dimensions: int = 3,
        in_channels: int = 128,
        out_channels: int = 3,
        decoder_blocks: List[Tuple[str, int | dict]] = [],  # noqa: B006
        patch_size: int = 4,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        causal: bool = False,
        timestep_conditioning: bool = False,
        decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
        base_channels: int = 128,
    ):
        super().__init__()

        # Spatiotemporal downscaling between decoded video space and VAE latents.
        # According to the LTXV paper, the standard configuration downsamples
        # video inputs by a factor of 8 in the temporal dimension and 32 in
        # each spatial dimension (height and width). This parameter determines how
        # many video frames and pixels correspond to a single latent cell.
        self.video_downscale_factors = SpatioTemporalScaleFactors(
            time=8,
            width=32,
            height=32,
        )

        self.patch_size = patch_size
        out_channels = out_channels * patch_size**2
        self.causal = causal
        self.timestep_conditioning = timestep_conditioning
        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS

        # Per-channel statistics for denormalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)

        # Noise and timestep parameters for decoder conditioning
        self.decode_noise_scale = 0.025
        self.decode_timestep = 0.05

        # The LTX VAE decoder architecture uses 3 upsampler blocks with a multiplier equal to 2.
        # Hence the total feature_channels is multiplied by 8 (2^3).
        feature_channels = base_channels * 8

        self.conv_in = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=decoder_spatial_padding_mode,
        )

        self.up_blocks = nn.ModuleList([])

        for block_name, block_params in list(reversed(decoder_blocks)):
            # Convert int to dict format for uniform handling
            block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params

            block, feature_channels = _make_decoder_block(
                block_name=block_name,
                block_config=block_config,
                in_channels=feature_channels,
                convolution_dimensions=convolution_dimensions,
                norm_layer=norm_layer,
                timestep_conditioning=timestep_conditioning,
                norm_num_groups=self._norm_num_groups,
                spatial_padding_mode=decoder_spatial_padding_mode,
            )

            self.up_blocks.append(block)

        if norm_layer == NormLayerType.GROUP_NORM:
            self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6)
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.conv_norm_out = PixelNorm()

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(
            dims=convolution_dimensions,
            in_channels=feature_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            causal=True,
            spatial_padding_mode=decoder_spatial_padding_mode,
        )

        if timestep_conditioning:
            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0))
            self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                embedding_dim=feature_channels * 2, size_emb_dim=0
            )
            self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels))

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        r"""
        Decode latent representation into video frames.
        Args:
            sample: Latent tensor (B, 128, F', H', W').
            timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None.
            generator: Random generator for deterministic noise injection (if inject_noise=True in blocks).
        Returns:
            Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'.
            Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512).
            Note: The first frame is removed after temporal upsampling regardless of causal mode.
            When causal=False, future frame dependencies are allowed in convolutions, but the output shape is the same.
        """
        batch_size = sample.shape[0]

        # Add noise if timestep conditioning is enabled
        if self.timestep_conditioning:
            noise = (
                torch.randn(
                    sample.size(),
                    generator=generator,
                    dtype=sample.dtype,
                    device=sample.device,
                )
                * self.decode_noise_scale
            )

            sample = noise + (1.0 - self.decode_noise_scale) * sample

        # Denormalize latents
        sample = self.per_channel_statistics.un_normalize(sample)

        # Use default decode_timestep if timestep not provided
        if timestep is None and self.timestep_conditioning:
            timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype)

        sample = self.conv_in(sample, causal=self.causal)

        scaled_timestep = None
        if self.timestep_conditioning:
| 722 |
+
if timestep is None:
|
| 723 |
+
raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True")
|
| 724 |
+
scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample)
|
| 725 |
+
|
| 726 |
+
for up_block in self.up_blocks:
|
| 727 |
+
if isinstance(up_block, UNetMidBlock3D):
|
| 728 |
+
block_kwargs = {
|
| 729 |
+
"causal": self.causal,
|
| 730 |
+
"timestep": scaled_timestep if self.timestep_conditioning else None,
|
| 731 |
+
"generator": generator,
|
| 732 |
+
}
|
| 733 |
+
sample = up_block(sample, **block_kwargs)
|
| 734 |
+
elif isinstance(up_block, ResnetBlock3D):
|
| 735 |
+
sample = up_block(sample, causal=self.causal, generator=generator)
|
| 736 |
+
else:
|
| 737 |
+
sample = up_block(sample, causal=self.causal)
|
| 738 |
+
|
| 739 |
+
sample = self.conv_norm_out(sample)
|
| 740 |
+
|
| 741 |
+
if self.timestep_conditioning:
|
| 742 |
+
embedded_timestep = self.last_time_embedder(
|
| 743 |
+
timestep=scaled_timestep.flatten(),
|
| 744 |
+
hidden_dtype=sample.dtype,
|
| 745 |
+
)
|
| 746 |
+
embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1)
|
| 747 |
+
ada_values = self.last_scale_shift_table[None, ..., None, None, None].to(
|
| 748 |
+
device=sample.device, dtype=sample.dtype
|
| 749 |
+
) + embedded_timestep.reshape(
|
| 750 |
+
batch_size,
|
| 751 |
+
2,
|
| 752 |
+
-1,
|
| 753 |
+
embedded_timestep.shape[-3],
|
| 754 |
+
embedded_timestep.shape[-2],
|
| 755 |
+
embedded_timestep.shape[-1],
|
| 756 |
+
)
|
| 757 |
+
shift, scale = ada_values.unbind(dim=1)
|
| 758 |
+
sample = sample * (1 + scale) + shift
|
| 759 |
+
|
| 760 |
+
sample = self.conv_act(sample)
|
| 761 |
+
sample = self.conv_out(sample, causal=self.causal)
|
| 762 |
+
|
| 763 |
+
# Final spatial expansion: reverse the initial patchify from encoder
|
| 764 |
+
# Moves pixels from channels back to spatial dimensions
|
| 765 |
+
# Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4
|
| 766 |
+
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
| 767 |
+
|
| 768 |
+
return sample
|
| 769 |
+
|
| 770 |
+
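    # Illustrative shape walk-through for forward() (a sketch, not executed here; the numbers
    # come from the docstring above): a latent of shape (1, 128, 5, 16, 16) decodes to
    # F = 8 * (5 - 1) + 1 = 33 frames and H = W = 32 * 16 = 512 pixels, i.e. (1, 3, 33, 512, 512).
    #
    #   latent = torch.randn(1, 128, 5, 16, 16)
    #   video = decoder(latent)  # `decoder` is assumed to be an instance of this class
    #   assert video.shape == (1, 3, 33, 512, 512)
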
    def _prepare_tiles(
        self,
        latent: torch.Tensor,
        tiling_config: TilingConfig | None = None,
    ) -> List[Tile]:
        splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape)
        mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape)
        if tiling_config is not None and tiling_config.spatial_config is not None:
            cfg = tiling_config.spatial_config
            long_side = max(latent.shape[3], latent.shape[4])

            def enable_on_axis(axis_idx: int, factor: int) -> None:
                size = cfg.tile_size_in_pixels // factor
                overlap = cfg.tile_overlap_in_pixels // factor
                axis_length = latent.shape[axis_idx]
                lower_threshold = max(2, overlap + 1)
                tile_size = max(lower_threshold, round(size * axis_length / long_side))
                splitters[axis_idx] = split_with_symmetric_overlaps(tile_size, overlap)
                mappers[axis_idx] = make_mapping_operation(map_spatial_interval_to_pixel, scale=factor)

            enable_on_axis(3, self.video_downscale_factors.height)
            enable_on_axis(4, self.video_downscale_factors.width)

        if tiling_config is not None and tiling_config.temporal_config is not None:
            cfg = tiling_config.temporal_config
            tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time
            overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time
            splitters[2] = split_temporal_latents(tile_size, overlap)
            mappers[2] = make_mapping_operation(map_temporal_interval_to_frame, scale=self.video_downscale_factors.time)

        return create_tiles(latent.shape, splitters, mappers)

    def tiled_decode(
        self,
        latent: torch.Tensor,
        tiling_config: TilingConfig | None = None,
        timestep: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
    ) -> Iterator[torch.Tensor]:
        """
        Decode a latent tensor into video frames using tiled processing.

        Splits the latent tensor into tiles, decodes each tile individually,
        and yields video chunks as they become available.

        Args:
            latent: Input latent tensor (B, C, F', H', W').
            tiling_config: Tiling configuration for the latent tensor.
            timestep: Optional timestep for decoder conditioning.
            generator: Optional random generator for deterministic decoding.

        Yields:
            Video chunks (B, C, T, H, W), one per temporal slice.
        """

        # Calculate full video shape from latent shape to get spatial dimensions
        full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors)
        tiles = self._prepare_tiles(latent, tiling_config)

        temporal_groups = self._group_tiles_by_temporal_slice(tiles)

        # State for temporal overlap handling
        previous_chunk = None
        previous_weights = None
        previous_temporal_slice = None

        for temporal_group_tiles in temporal_groups:
            curr_temporal_slice = temporal_group_tiles[0].out_coords[2]

            # Calculate the shape of the temporal buffer for this group of tiles.
            # The temporal length depends on whether this is the first tile (starts at 0) or not.
            # - First tile: (frames - 1) * scale + 1
            # - Subsequent tiles: frames * scale
            # This logic is handled by TemporalAxisMapping and reflected in out_coords.
            temporal_tile_buffer_shape = full_video_shape._replace(
                frames=curr_temporal_slice.stop - curr_temporal_slice.start,
            )

            buffer = torch.zeros(
                temporal_tile_buffer_shape.to_torch_shape(),
                device=latent.device,
                dtype=latent.dtype,
            )

            curr_weights = self._accumulate_temporal_group_into_buffer(
                group_tiles=temporal_group_tiles,
                buffer=buffer,
                latent=latent,
                timestep=timestep,
                generator=generator,
            )

            # Blend with previous temporal chunk if it exists
            if previous_chunk is not None:
                # Check if current temporal slice overlaps with previous temporal slice
                if previous_temporal_slice.stop > curr_temporal_slice.start:
                    overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start
                    temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None)

                    # The overlap is already masked before it reaches this step. Each tile is accumulated into buffer
                    # with its trapezoidal mask, and curr_weights accumulates the same mask. In the overlap blend we add
                    # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the
                    # previous buffers, then later normalize by weights.
                    previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :]
                    previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[
                        :, :, slice(0, overlap_len), :, :
                    ]

                    buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :]
                    curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[
                        :, :, temporal_overlap_slice, :, :
                    ]

                # Yield the non-overlapping part of the previous chunk
                previous_weights = previous_weights.clamp(min=1e-8)
                yield_len = curr_temporal_slice.start - previous_temporal_slice.start
                yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :]

            # Update state for next iteration
            previous_chunk = buffer
            previous_weights = curr_weights
            previous_temporal_slice = curr_temporal_slice

        # Yield any remaining chunk
        if previous_chunk is not None:
            previous_weights = previous_weights.clamp(min=1e-8)
            yield previous_chunk / previous_weights

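    # Minimal usage sketch for tiled_decode() (illustrative only; it assumes a TilingConfig
    # instance built elsewhere and `decoder` as an instance of this class):
    #
    #   chunks = list(decoder.tiled_decode(latent, tiling_config))
    #   video = torch.cat(chunks, dim=2)  # chunks are yielded along the temporal axis
    #
    # Since each yielded chunk covers a disjoint temporal slice, concatenating along dim=2
    # (time) reconstructs the full decoded video.
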
    def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]:
        """Group tiles by their temporal output slice."""
        if not tiles:
            return []

        groups = []
        current_slice = tiles[0].out_coords[2]
        current_group = []

        for tile in tiles:
            tile_slice = tile.out_coords[2]
            if tile_slice == current_slice:
                current_group.append(tile)
            else:
                groups.append(current_group)
                current_slice = tile_slice
                current_group = [tile]

        # Add the final group
        if current_group:
            groups.append(current_group)

        return groups

    def _accumulate_temporal_group_into_buffer(
        self,
        group_tiles: List[Tile],
        buffer: torch.Tensor,
        latent: torch.Tensor,
        timestep: torch.Tensor | None,
        generator: torch.Generator | None,
    ) -> torch.Tensor:
        """
        Decode and accumulate all tiles of a temporal group into a local buffer.

        The buffer is local to the group and always starts at time 0; temporal coordinates
        are rebased by subtracting temporal_slice.start.
        """
        temporal_slice = group_tiles[0].out_coords[2]

        weights = torch.zeros_like(buffer)

        for tile in group_tiles:
            decoded_tile = self.forward(latent[tile.in_coords], timestep, generator)
            mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype)
            temporal_offset = tile.out_coords[2].start - temporal_slice.start
            # Use the tile's output coordinate length, not the decoded tile's length,
            # as the decoder may produce a different number of frames than expected
            expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start
            decoded_temporal_len = decoded_tile.shape[2]

            # Ensure we don't exceed the buffer or decoded tile bounds
            actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset)

            chunk_coords = (
                slice(None),  # batch
                slice(None),  # channels
                slice(temporal_offset, temporal_offset + actual_temporal_len),
                tile.out_coords[3],  # height
                tile.out_coords[4],  # width
            )

            # Slice decoded_tile and mask to match the actual length we're writing
            decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :]
            mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask

            buffer[chunk_coords] += decoded_slice * mask_slice
            weights[chunk_coords] += mask_slice

        return weights


def decode_video(
    latent: torch.Tensor,
    video_decoder: VideoDecoder,
    tiling_config: TilingConfig | None = None,
    generator: torch.Generator | None = None,
) -> Iterator[torch.Tensor]:
    """
    Decode a video latent tensor with the given decoder.

    Args:
        latent: Latent tensor [b, c, f, h, w].
        video_decoder: Decoder module.
        tiling_config: Optional tiling settings.
        generator: Optional random generator for deterministic decoding.

    Yields:
        Decoded chunks [f, h, w, c], uint8 in [0, 255].
    """

    def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor:
        frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8)
        frames = rearrange(frames[0], "c f h w -> f h w c")
        return frames

    if tiling_config is not None:
        for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator):
            yield convert_to_uint8(frames)
    else:
        decoded_video = video_decoder(latent, generator=generator)
        yield convert_to_uint8(decoded_video)

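# Example driver for decode_video() (illustrative sketch; `decoder`, `latent` and
# `tiling_config` are assumed to exist and are not defined in this module):
#
#   frames_out = []
#   for chunk in decode_video(latent, decoder, tiling_config):
#       frames_out.append(chunk)          # each chunk is [f, h, w, c], uint8
#   video = torch.cat(frames_out, dim=0)  # concatenate chunks along the frame axis
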
def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int:
    """
    Get the number of video chunks for a given number of frames and tiling configuration.

    Args:
        num_frames: Number of frames in the video.
        tiling_config: Tiling configuration.

    Returns:
        Number of video chunks.
    """
    if not tiling_config or not tiling_config.temporal_config:
        return 1
    cfg = tiling_config.temporal_config
    frame_stride = cfg.tile_size_in_frames - cfg.tile_overlap_in_frames
    return (num_frames - 1 + frame_stride - 1) // frame_stride

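# Worked example for get_video_chunks_number() (illustrative values, not a required config):
# with num_frames=121, tile_size_in_frames=80 and tile_overlap_in_frames=24 the stride is
# 80 - 24 = 56, so the chunk count is ceil((121 - 1) / 56) = (120 + 55) // 56 = 3.
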
def split_with_symmetric_overlaps(size: int, overlap: int) -> SplitOperation:
    def split(dimension_size: int) -> DimensionIntervals:
        if dimension_size <= size:
            return DEFAULT_SPLIT_OPERATION(dimension_size)
        amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
        starts = [i * (size - overlap) for i in range(amount)]
        ends = [start + size for start in starts]
        ends[-1] = dimension_size
        left_ramps = [0] + [overlap] * (amount - 1)
        right_ramps = [overlap] * (amount - 1) + [0]
        return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)

    return split

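# Worked example for split_with_symmetric_overlaps() (illustrative values): size=24, overlap=8
# and dimension_size=56 give amount = (56 + 24 - 16 - 1) // 16 = 3, so the intervals are
# starts=[0, 16, 32], ends=[24, 40, 56], left_ramps=[0, 8, 8], right_ramps=[8, 8, 0] --
# the same non-causal tiling shown in the split_temporal_latents docstring below.
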
def split_temporal_latents(size: int, overlap: int) -> SplitOperation:
    """Split a temporal axis into overlapping tiles with causal handling.

    Example with size=24, overlap=8 (units are whatever axis you split):

    Non-causal split would produce:
        Tile 0: [0, 24),  left_ramp=0, right_ramp=8
        Tile 1: [16, 40), left_ramp=8, right_ramp=8
        Tile 2: [32, 56), left_ramp=8, right_ramp=0

    Causal split produces:
        Tile 0: [0, 24),  left_ramp=0, right_ramp=8 (unchanged - starts at anchor)
        Tile 1: [15, 40), left_ramp=9, right_ramp=8 (shifted back 1, ramp +1)
        Tile 2: [31, 56), left_ramp=9, right_ramp=0 (shifted back 1, ramp +1)

    This ensures each tile can causally depend on frames from previous tiles while maintaining
    proper temporal continuity through the blend ramps.

    Args:
        size: Tile size in *axis units* (latent steps for LTX time tiling)
        overlap: Overlap between tiles in the same units

    Returns:
        Split operation that divides the temporal dimension with causal handling
    """
    non_causal_split = split_with_symmetric_overlaps(size, overlap)

    def split(dimension_size: int) -> DimensionIntervals:
        if dimension_size <= size:
            return DEFAULT_SPLIT_OPERATION(dimension_size)
        intervals = non_causal_split(dimension_size)

        # Shift non-first tiles back by one step so each tile can depend on the previous tile's context
        starts = intervals.starts
        starts[1:] = [s - 1 for s in starts[1:]]

        # Extend blend ramps by 1 for non-first tiles to blend over the extra frame
        left_ramps = intervals.left_ramps
        left_ramps[1:] = [r + 1 for r in left_ramps[1:]]

        return replace(intervals, starts=starts, left_ramps=left_ramps)

    return split


def split_temporal_frames(tile_size_frames: int, overlap_frames: int) -> SplitOperation:
    """Split a temporal axis in video frame space into overlapping tiles.

    Args:
        tile_size_frames: Tile length in frames.
        overlap_frames: Overlap between consecutive tiles in frames.

    Returns:
        Split operation that takes a frame count and returns DimensionIntervals in frame indices.
    """
    non_causal_split = split_with_symmetric_overlaps(tile_size_frames, overlap_frames)

    def split(dimension_size: int) -> DimensionIntervals:
        if dimension_size <= tile_size_frames:
            return DEFAULT_SPLIT_OPERATION(dimension_size)
        intervals = non_causal_split(dimension_size)
        ends = intervals.ends
        ends[:-1] = [e + 1 for e in ends[:-1]]
        right_ramps = [0] * len(intervals.right_ramps)
        return replace(intervals, ends=ends, right_ramps=right_ramps)

    return split

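# Worked example for split_temporal_frames() (illustrative values): tile_size_frames=24,
# overlap_frames=8 and dimension_size=56 start from the non-causal intervals above, then
# extend all but the last tile by one frame and zero the right ramps, giving
# starts=[0, 16, 32], ends=[25, 41, 56], left_ramps=[0, 8, 8], right_ramps=[0, 0, 0].
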
def make_mapping_operation(
    map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor | None]],
    scale: int,
) -> MappingOperation:
    """Create a mapping operation over a set of tiling intervals.

    The given mapping function is applied to each interval in the input dimension. The resulting function
    is used for creating tiles in the output dimension.

    Args:
        map_func: Mapping function to create the mapping operation from
        scale: Scale factor for the transformation, used as an argument for the mapping function

    Returns:
        Mapping operation that takes a set of tiling intervals and returns a set of slices and masks in the output
        dimension.
    """

    def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]:
        output_slices: list[slice] = []
        masks_1d: list[torch.Tensor | None] = []
        number_of_slices = len(intervals.starts)
        for i in range(number_of_slices):
            start = intervals.starts[i]
            end = intervals.ends[i]
            left_ramp = intervals.left_ramps[i]
            right_ramp = intervals.right_ramps[i]
            output_slice, mask_1d = map_func(start, end, left_ramp, right_ramp, scale)
            output_slices.append(output_slice)
            masks_1d.append(mask_1d)
        return output_slices, masks_1d

    return map_op


def map_temporal_interval_to_frame(
    begin: int,
    end: int,
    left_ramp: int,
    right_ramp: int,
    scale: int,
) -> Tuple[slice, torch.Tensor]:
    """Map a temporal interval in latent space to video frame space.

    Args:
        begin: Start position in latent space
        end: End position in latent space
        left_ramp: Left ramp size in latent space
        right_ramp: Right ramp size in latent space
        scale: Scale factor for transformation

    Returns:
        Tuple of (output_slice, blend_mask)
    """
    start = begin * scale
    stop = 1 + (end - 1) * scale

    left_ramp_frames = 0 if left_ramp == 0 else 1 + (left_ramp - 1) * scale
    right_ramp_frames = right_ramp * scale

    mask_1d = compute_trapezoidal_mask_1d(stop - start, left_ramp_frames, right_ramp_frames, True)
    return slice(start, stop), mask_1d

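# Worked example for map_temporal_interval_to_frame() (illustrative values; scale=8 matches the
# standard temporal downscale factor): begin=2, end=5, left_ramp=1, right_ramp=0 give
# start = 16, stop = 1 + 4 * 8 = 33, left_ramp_frames = 1 + 0 * 8 = 1, right_ramp_frames = 0,
# i.e. the latent interval [2, 5) maps to the frame slice [16, 33).
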
def map_temporal_interval_to_latent(
    begin: int, end: int, left_ramp: int, right_ramp: int | None = None, scale: int = 1
) -> Tuple[slice, torch.Tensor]:
    """
    Map a temporal interval in video frame space to latent space.

    Args:
        begin: Start position in video frame space
        end: End position in video frame space
        left_ramp: Left ramp size in video frame space
        right_ramp: Right ramp size in video frame space
        scale: Scale factor for transformation

    Returns:
        Tuple of (output_slice, blend_mask)
    """
    start = begin // scale
    stop = (end - 1) // scale + 1

    left_ramp_latents = 0 if left_ramp == 0 else 1 + (left_ramp - 1) // scale
    right_ramp_latents = right_ramp // scale

    if right_ramp_latents != 0:
        raise ValueError("For tiled encoding, temporal tiles are expected to have a right ramp equal to 0")

    mask_1d = compute_rectangular_mask_1d(stop - start, left_ramp_latents, right_ramp_latents)

    return slice(start, stop), mask_1d


def map_spatial_interval_to_pixel(
    begin: int,
    end: int,
    left_ramp: int,
    right_ramp: int,
    scale: int,
) -> Tuple[slice, torch.Tensor]:
    """Map a spatial interval in latent space to pixel space.

    Args:
        begin: Start position in latent space
        end: End position in latent space
        left_ramp: Left ramp size in latent space
        right_ramp: Right ramp size in latent space
        scale: Scale factor for transformation

    Returns:
        Tuple of (output_slice, blend_mask)
    """
    start = begin * scale
    stop = end * scale
    mask_1d = compute_trapezoidal_mask_1d(stop - start, left_ramp * scale, right_ramp * scale, False)
    return slice(start, stop), mask_1d

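# Worked example for map_spatial_interval_to_pixel() (illustrative values; scale=32 matches the
# standard spatial downscale factor): begin=0, end=16, left_ramp=0, right_ramp=4 map the
# latent interval [0, 16) to the pixel slice [0, 512), with a trapezoidal mask whose right
# ramp spans 4 * 32 = 128 pixels.
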
def map_spatial_interval_to_latent(
    begin: int,
    end: int,
    left_ramp: int,
    right_ramp: int,
    scale: int,
) -> Tuple[slice, torch.Tensor]:
    """Map a spatial interval in pixel space to latent space.

    Args:
        begin: Start position in pixel space
        end: End position in pixel space
        left_ramp: Left ramp size in pixel space
        right_ramp: Right ramp size in pixel space
        scale: Scale factor for transformation

    Returns:
        Tuple of (output_slice, blend_mask)
    """
    start = begin // scale
    stop = end // scale
    left_ramp = max(0, left_ramp // scale - 1)

    right_ramp = 0 if right_ramp == 0 else 1

    mask_1d = compute_rectangular_mask_1d(stop - start, left_ramp, right_ramp)
    return slice(start, stop), mask_1d

- packages/ltx-core/src/ltx_core/quantization/__pycache__/__init__.cpython-312.pyc ADDED (binary file, 597 Bytes)
- packages/ltx-core/src/ltx_core/quantization/__pycache__/fp8_cast.cpython-312.pyc ADDED (binary file, 7.76 kB)
- packages/ltx-core/src/ltx_core/quantization/__pycache__/fp8_scaled_mm.cpython-312.pyc ADDED (binary file, 10.3 kB)
- packages/ltx-core/src/ltx_core/quantization/__pycache__/policy.cpython-312.pyc ADDED (binary file, 2.02 kB)